Skip to main content

手写带掩码的自注意力机制

GPT 等大模型之所以强大,核心在于它们如何"理解"上下文。而这一切的基石,就藏在不到 20 行的 PyTorch 代码中。

在这篇文章中,我们将逐行拆解 Transformer 的带掩码自注意力机制(Masked Self-Attention)。我们将揭示两个秘密:

  1. 模型如何计算词与词之间的关联?
  2. 为何模型不需要循环,就能一次性学会"看前2个词预测第3个"和"看前10个词预测第11个"?

1. 数据的准备:定义时空

首先,我们需要构建输入数据的"舞台"。

import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337) # 保证结果可复现
B, T, C = 4, 8, 32 # Batch, Time, Channels
x = torch.randn(B, T, C)
  • B (Batch Size) = 4: 我们一次并行处理 4 个独立的句子。
  • T (Time Steps) = 8: 每个句子包含 8 个 token(比如 8 个字)。
  • C (Channels) = 32: 每个字被编码为一个 32 维的向量。
  • x: 输入张量,代表了这 4 句话中每个字的原始特征。

2. 变换:寻找信息的 Q, K, V

注意力机制的第一步,是将输入特征映射到三个不同的空间。这就像数据库查询:

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)
v = value(x) # (B, T, 16) -> 实际携带信息的向量
  • Query (q): 当前字发出的"查询请求"(我想找什么样的信息?)。
  • Key (k): 当前字展示的"标签特征"(我包含什么样的信息?)。
  • Value (v): 如果匹配成功,当前字实际传递的内容。

3. 注意力分数:计算相关性

接下来,我们要弄清楚句子中谁和谁有关联

wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) -> (B, T, T)

这里进行了一次矩阵乘法。结果 wei 是一个 (B,T,T)(B, T, T) 的矩阵。 在 T×TT \times T (8x8) 的方阵中,wei[row, col] 代表了第 row 个字对第 col 个字的关注程度。数值越大,相关性越高。


4. 核心魔法:掩码 (Masking) 与并行训练

这是 GPT (Decoder-only) 架构最精妙的地方。

在训练时,我们手里有完整的句子(比如"我爱吃苹果")。但是,我们希望模型学会预测

  • 当模型读到"爱"时,它只能基于"我、爱"的信息去预测"吃"。
  • 绝对不能偷看后面的"苹果"。

我们通过一个下三角掩码来实现这种"因果限制",同时利用矩阵运算实现了训练并行化

tril = torch.tril(torch.ones(T, T))
# wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

为什么说这是"一次运算,全量训练"?

你可能会问:"不需要写 For 循环吗?不需要先算第一个词,再算第二个词吗?" 不需要。 所有的预测任务都在这个 T×TT \times T 矩阵中一次性完成了。

让我们看看 wei 矩阵(以及随后的输出 out)的每一行代表什么:

  • 第 0 行 (对应第 1 个字):

    • Mask 强制让它只能看见第 0 列。
    • 含义: 模型只利用了第 1 个字的信息。
    • 任务: 它的输出向量用于预测第 2 个字
  • 第 2 行 (对应第 3 个字):

    • Mask 允许它看见第 0, 1, 2 列。
    • 含义: 模型融合了前 3 个字的信息。
    • 任务: 它的输出向量用于预测第 4 个字
  • 第 10 行 (如果有的话):

    • Mask 允许它看见 0~10 列。
    • 含义: 模型融合了前 11 个字的信息。
    • 任务: 它的输出向量用于预测第 12 个字

这就是 Transformer 训练效率极高的秘密: 它像开了"上帝视角",在一个矩阵运算中,同时模拟了时间轴上的所有时刻。第 ii 行就是在模拟"如果我现在站在第 ii 个字的位置,回顾过去,我该如何预测下一个字"。

Softmax 操作后,被 Mask 的位置变成了 00,合法的位置变成了概率权重。


5. 聚合:生成新表示

最后,根据计算出的权重,聚合 Value 向量。

out = wei @ v # (B, T, T) @ (B, T, 16) -> (B, T, 16)
  • out 的形状是 (B,T,16)(B, T, 16)
  • out 中的每一个向量,不再是原本单一的词向量,而是融合了它之前所有相关词信息的、上下文感知的深度表示

总结

这段简短的代码完美展示了现代大语言模型的底座:

  1. Q @ K.T: 算出词与词的相似度。
  2. Masked Fill: 强行切断通往"未来"的连接,确保因果性。
  3. Softmax @ V: 加权融合历史信息。
  4. 矩阵并行: 利用 T×TT \times T 矩阵的行独立性,一次前向传播就完成了序列中所有位置的"预测下一个词"的训练任务。