Python __call__ 魔法方法详解
什么是 __call__?
__call__ 是 Python 的魔法方法(Magic Method),它让对象实例可以像函数一样被调用。
基本示例
class Adder:
def __init__(self, n):
self.n = n
def __call__(self, x):
return x + self.n
# 创建对象
add_5 = Adder(5)
# 像函数一样调用对象!
result = add_5(10) # 调用 __call__(10)
print(result) # 15
# 检查是否可调用
print(callable(add_5)) # True
当你调用 add_5(10) 时,Python 实际上调用的是 add_5.__call__(10)。
为什么使用 __call__?
使用 __call__ 的主要优势:
- 有状态的函数:对象可以保存状态,而普通函数需要使用全局变量或闭包
- 更清晰的接口:构造函数提供了清晰的参数配置接口
- 面向对象设计:可以利用继承和多态
- 框架集成:许多框架(如 PyTorch)使用这种模式
PyTorch 为什么总用 net(x) 而不是 net.forward(x)?
这是 __call__ 最重要的应用场景之一。在 PyTorch 中,你总是看到这样的代码:
import torch
import torch.nn as nn
class MyNetwork(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
net = MyNetwork()
x = torch.randn(2, 10)
# ✅ 推荐:使用 net(x)
y = net(x)
# ❌ 不推荐:直接调用 forward
# y = net.forward(x)
nn.Module 的 __call__ 做了什么?
让我们看看 PyTorch 源码中 nn.Module 的简化版实现:
class Module:
def __call__(self, *args, **kwargs):
# 1. 调用前向钩子(pre-forward hooks)
for hook in self._forward_pre_hooks.values():
result = hook(self, args)
if result is not None:
args = result
# 2. 执行 forward 方法
result = self.forward(*args, **kwargs)
# 3. 调用后向钩子(forward hooks)
for hook in self._forward_hooks.values():
hook_result = hook(self, args, result)
if hook_result is not None:
result = hook_result
# 4. 返回结果
return result
def forward(self, *args, **kwargs):
raise NotImplementedError
为什么必须用 net(x) 而不是 net.forward(x)?
使用 net(x) 而不是 net.forward(x) 的原因:
1. 钩子函数(Hooks)会被跳过
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
net = Net()
x = torch.randn(2, 10)
# 注册一个钩子函数
def print_output(module, input, output):
print(f"Output shape: {output.shape}")
net.register_forward_hook(print_output)
# ✅ 使用 net(x):钩子会被调用
y = net(x) # 输出:Output shape: torch.Size([2, 5])
# ❌ 使用 forward(x):钩子被跳过
y = net.forward(x) # 什么都不输出!
2. 梯度记录可能出问题
# 某些情况下,直接调用 forward 可能影响梯度计算
# PyTorch 内部依赖 __call__ 来正确处理梯度
3. 训练/评估模式切换
# 虽然 train()/eval() 切换是通过设置 self.training 实现的
# 但 __call__ 确保了所有子模块都能正确响应模式切换
net.train()
y = net(x) # __call__ 确保所有层都知道当前是训练模式
net.eval()
y = net(x) # __call__ 确保所有层都知道当前是评估模式
4. 调试和性能分析
PyTorch 的 profiler 和调试工具依赖 __call__ 来追踪网络执行:
import torch.profiler as profiler
with profiler.profile() as prof:
y = net(x) # ✅ 可以被正确追踪
# y = net.forward(x) # ❌ 追踪信息不完整
print(prof.key_averages().table())
实际例子:钩子的威力
import torch
import torch.nn as nn
class MyNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 5)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
net = MyNet()
# 注册钩子来查看中间层输出
activations = {}
def get_activation(name):
def hook(module, input, output):
activations[name] = output.detach()
return hook
# 为每一层注册钩子
net.fc1.register_forward_hook(get_activation('fc1'))
net.fc2.register_forward_hook(get_activation('fc2'))
x = torch.randn(2, 10)
# ✅ 使用 net(x):钩子正常工作
y = net(x)
print("FC1 output shape:", activations['fc1'].shape) # torch.Size([2, 20])
print("FC2 output shape:", activations['fc2'].shape) # torch.Size([2, 5])
# ❌ 使用 forward(x):activations 字典是空的
activations.clear()
y = net.forward(x)
print("Activations after forward():", activations) # {}