手写带掩码的自注意力机制
GPT 等大模型之所以强大,核心在于它们如何"理解"上下文。而这一切的基石,就藏在不到 20 行的 PyTorch 代码中。
在这篇文章中,我们将逐行拆解 Transformer 的带掩码自注意力机制(Masked Self-Attention)。我们将揭示两个秘密:
- 模型如何计算词与词之间的关联?
- 为何模型不需要循环,就能一次性学会"看前2个词预测第3个"和"看前10个词预测第11个"?
1. 数据的准备:定义时空
首先,我们需要构建输入数据的"舞台"。
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337) # 保证结果可复现
B, T, C = 4, 8, 32 # Batch, Time, Channels
x = torch.randn(B, T, C)
- B (Batch Size) = 4: 我们一次并行处理 4 个独立的句子。
- T (Time Steps) = 8: 每个句子包含 8 个 token(比如 8 个字)。
- C (Channels) = 32: 每个字被编码为一个 32 维的向量。
- x: 输入张量,代表了这 4 句话中每个字的原始特征。
2. 变换:寻找信息的 Q, K, V
注意力机制的第一步,是将输入特征映射到三个不同的空间。这就像数据库查询:
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)
v = value(x) # (B, T, 16) -> 实际携带信息的向量
- Query (q): 当前字发出的"查询请求"(我想找什么样的信息?)。
- Key (k): 当前字展示的"标签特征"(我包含什么样的信息?)。
- Value (v): 如果匹配成功,当前字实际传递的内容。