跳到主要内容

Transformer 训练与推理的并行性差异

这是 Transformer 架构中最反直觉、也最核心的问题:既然 Transformer 是一个 token 一个 token 预测的,怎么能说它是并行的?

答案是:Transformer 在推理(生成)阶段确实是串行的,但在训练阶段是高度并行的。 这正是它能够取代 RNN/LSTM 成为主流架构的关键原因。


1. 训练阶段:并行的魔法

在训练时,我们手里已经有了完整的"标准答案"(Ground Truth)。比如要训练这句话:"我爱人工智能"

RNN/LSTM 的做法(被迫串行)

RNN 的结构决定了它必须按时间步走:

  1. 输入"我",算出隐藏状态 h1h_1,预测"爱"
  2. 必须等 h1h_1 算完,才能把 h1h_1 和"爱"传进去,算出 h2h_2,预测"人工"
  3. 必须等 h2h_2 算完,才能算出 h3h_3...

这一步卡一步,就像接力赛跑,前一棒没跑完,后一棒不能动。GPU 这种擅长大规模矩阵运算的硬件,此时大部分算力都在空转等待。

Transformer 的做法(极致并行)

在训练 Transformer 时,使用了 Teacher Forcing(教师强制) 策略。因为已经知道整句话是"我爱人工智能",可以一次性把所有 token 全部扔进模型。

通过自注意力机制位置编码,模型可以在同一时刻计算:

  • 在位置 1:看到"我",预测"爱"
  • 在位置 2:看到"我爱",预测"人"
  • 在位置 3:看到"我爱人",预测"工"
  • ...

关键点: 虽然逻辑上位置 2 依赖位置 1 的信息,但在数学上,这只是一个巨大的矩阵乘法。GPU 不需要等位置 1 算完再算位置 2,它是一次性把整个矩阵算出来的。

训练阶段的"并行":把整个序列的时间维度(Time Step)给并行化了。 这让训练速度比 RNN 快了成百上千倍。


2. 推理/生成阶段:回归串行

当真正用 ChatGPT 聊天时(推理阶段),模型不知道未来会说什么,也不知道下一个词会是什么:

  1. 输入:"你好"
  2. 模型算出:"吗"(把它加到输入里)
  3. 输入:"你好吗"
  4. 模型算出:"?"

这确实是一个 token 一个 token 计算的(Autoregressive)。在这个阶段,Transformer 无法在时间上并行。


3. Mask:防止"剧透"

你可能会问:如果在训练时是一次性并行输入的,那模型计算"我"的时候,岂不是偷看到了后面的"爱"?

为了解决这个问题,Transformer 在训练时使用了 Causal Mask(因果掩码)。它是一个下三角矩阵,强行把"未来"的信息遮住:

  • 算第 1 个词时,把第 2, 3, 4... 个词的数据遮住
  • 算第 2 个词时,把第 3, 4... 个词的数据遮住

虽然数据是一起输进去并行的,但通过 Mask,保证了逻辑上依然遵守"只能看过去,不能看未来"的规则。


4. 本质区别:计算依赖的"变量"不同

要理解为什么 Transformer 能并行而 RNN 不能,关键在于:计算第 n 个词的时候,需不需要等第 n-1 个词算完?

RNN:致命的"隐藏状态"依赖

RNN 的计算公式:

h10=f(x10,h9)h_{10} = f(x_{10}, h_9)

  • x10x_{10}:第 10 个词的输入。可以直接给。
  • h9h_9:第 9 步算出的"隐藏状态"。必须等第 9 步算完!

h9h_9 依赖 h8h_8h8h_8 依赖 h7h_7... 这就像叠罗汉,即使把 10 个人都叫到现场,第 10 个也没法直接站上去,必须等前面的人一个个站好。

Transformer:只依赖原始输入

Transformer 的注意力机制彻底抛弃了递归传递:

Output10=Attention(x10,all previous x)Output_{10} = Attention(x_{10}, \text{all previous } x)

它说:"我不需要上一步的'记忆状态'。我只要回头看一眼原始的 x1x_1x9x_9 是什么,自己现场算一下就行。"

因为 x1x_1x10x_{10} 都是原始数据,都在矩阵里躺着。所以 GPU 可以直接根据矩阵乘法,独立地、同时地算出每个位置的结果。


5. 形象类比

类比一:写日记 vs 批改卷子

场景模式说明
RNN(训练 & 推理)写日记必须写完今天的,才能写明天的
Transformer(推理)写日记写完一个字,才能想下一个字
Transformer(训练)批改卷子老师手里有完整答案,可以同时批改所有题目

类比二:接力跑 vs 大合唱

  • RNN:接力跑。前一棒没跑完,后一棒绝对不能动。
  • Transformer:大合唱。指挥棒一挥,大家看着谱子同时张嘴唱。第 10 个人不需要等第 9 个人唱完。

类比三:贪吃蛇 vs 拼图

  • RNN:贪吃蛇。身体一节一节长出来,头必须连着脖子,脖子连着身子。
  • Transformer:拼图。手里有所有拼图块,可以同时拼左上角和右下角。

6. Batch Size vs Sequence Length 并行

需要区分两种并行:

并行类型RNNTransformer
Batch 并行(同时训练多个句子)可以可以
Sequence 并行(同一句话内部多个位置同时算)不能可以

Transformer 的优势在于它两种并行都能做,这就是双重加速。


7. 总结

阶段RNNTransformer
训练串行(隐藏状态依赖)并行(只依赖原始输入)
推理串行串行

Transformer 在生成时是串行的,但它之所以伟大,是因为它让训练过程变成了并行

正是因为训练能并行,我们才能训练出 GPT-4 这种拥有万亿参数、吞噬了整个互联网数据的庞然大物。如果是 RNN,可能练到下个世纪都练不完。


延伸阅读

既然知道了推理是串行的,可以进一步了解:

  • KV Cache(键值缓存):推理阶段专门用来加速的技巧,虽然不能改变串行本质,但能避免重复计算
  • Speculative Decoding(投机解码):通过小模型预测多个 token,大模型并行验证,加速推理