Skip to main content

FP8 训练原理深度解析:为什么它能训练大模型?

常见误解

很多人会有这样的理解:

"LLM 中权重基本在 0–1 范围内,在这个范围内通过缩放基本能用 FP8 表示所有数值,这就是 FP8 能训练的原因。"

这个理解抓到了一半关键点,但还差几个很重要的补充。让我们深入分析一下。


1. 你理解对的那一半

✅ LLM 权重分布确实相对集中

  1. 是的,大部分 LLM 权重都集中在一个相对小的范围

    • 一般是以 0 为中心,标准差比较小(很多在 [-0.1, 0.1] 这种级别)
    • 经过合理的初始化 + 归一化,训练过程中也不会跑飞太多
    • 这就意味着: 👉 你不需要一个超精细、超大范围的格式来存权重,FP8 这种粗一点的小浮点是有机会够用的
  2. 在这个范围内,FP8 确实能"凑合表示"这些值

    • 不管是 E4M3 还是 E5M2,在 [-1,1] 附近都有一些离散点
    • 这些点不是精确覆盖所有实数,但量化误差通常没大到把网络直接搞崩

所以:

✅ "权重范围不大 + FP8 有能力在这个范围内给出一个还行的近似" 确实是 FP8 能参与训练的 其中一个前提条件


2. 但这不是全部原因

① 训练不只看"权重":激活 + 梯度范围大得多

  • 权重大多在一个比较窄的区间没错

  • 但训练过程中还有:

    • 中间层的 激活(尤其是 attention/MLP 输出)
    • 各种 梯度
  • 它们的数值范围、分布变化都要比权重要暴躁得多:

    • 有的层输出会到几十、上百
    • 有些梯度会非常小(1e-5、1e-6)
  • FP8 能训练,不是因为"只要把权重缩到 0~1 就行了",而是整个 pipeline:

    • 激活 / 梯度 也会用 动态 scale + 合适 FP8 格式 来控制
    • 再配合 BF16/FP16 累加,保证关键计算别太炸

👉 所以:

训练里 FP8 能活下来,是整个混合精度系统一起作用的结果,而不仅是"权重范围刚好不大"。


② 真正训练时用的是:FP8 + BF16/FP16 混合管线

大模型 FP8 训练一般是这样的(简化版流程):

  1. 主权重(master weights)

    • BF16 / FP16 存在内存 / 显存里
    • 这是"真正被优化器更新的版本"
  2. 前向 / 反向 做 matmul 时:

    • 临时把主权重 cast 成 FP8(带 scaling)
    • 激活也在 FP8(也经 scaling)
    • 矩阵乘法在 Tensor Core 上用 FP8 输入 + BF16/FP16 累加
  3. 反向算出的梯度

    • 会以更高精度(BF16/FP16)累加、更新到 master weights 上

也就是说:

FP8 只是"计算路径上的低精度影分身", 真正承载优化过程的是 BF16/FP16 那份主权重。

这个设计是 FP8 能稳住训练的关键保护伞。 如果你强行所有东西都只用 FP8 存 & 算,训练基本离寄不远了。


③ "那 INT8 不是也能缩放到 0–1 吗?为啥它不行?"

这个反直觉点其实非常重要:

  • 是的:INT8 也完全可以通过 scale,把 [-128,127] 映射到 [0,1] 或 [-1,1]

    • 比如 scale=1/128 → 范围 [-1, ~1],步长 ~0.0078
  • 所以"靠缩放能覆盖 0–1 区间"这点,INT8 和 FP8 都做得到 👉 这就说明:FP8 能训练,不是因为"INT8 做不到覆盖权重范围"

真正的差别在于:

  1. FP8 是浮点,有指数 → 自带多个数量级,适应 "小到大都能凑合表示"

  2. INT8 是线性格子,scale 选好了一段就没法兼顾别的段

    • 要兼顾,就得不断动态调 scale(训练时非常难搞)

再加上前面说的混合精度管线:

  • FP8:配合 Transformer Engine + BF16 累加,生态已经被 NVIDIA 完整打通
  • INT8:想在训练里用好,需要一堆复杂的 8bit optimizer / per-channel 动态量化 / STE, 工程成本大、鲁棒性差,所以现在几乎只在推理里做 INT8

3. 更准确的完整说法

原来的理解可以修正成:

✅ 大部分 LLM 权重分布确实相对集中(比如在一个不太大的区间里), 这让我们有机会用 FP8 这种低精度来近似它们;

✅ 再结合 FP8 是浮点(指数 + 尾数)、动态 scaling、以及 BF16/FP16 主权重 + 累加, 整个混合精度 pipeline 让 FP8 在训练过程中数值误差仍然可接受;

❌ 但 "只因为权重在 0–1 范围内所以 FP8 可以训练" 这个说法太简单了, 真正的关键是:浮点特性 + 动态范围 + 混合精度设计 + LLM 对噪声的容忍度 一起保证训练不崩。


4. FP8 训练的完整流水线

让我们用一个简化的流程图来说明 FP8 训练的实际工作方式:

典型 FP8 混合精度训练流程

第 0 步:初始化
├─ 主权重 (master weights): BF16/FP16 格式
├─ 优化器状态 (m, v): BF16/FP16 格式
└─ Scaling factors: 用于 FP8 转换

第 1 步:前向传播
├─ 输入数据: BF16/FP16
├─ 转换: 主权重 BF16 → FP8 (带 scaling)
├─ 计算: FP8 × FP8 输入
├─ 累加: 用 BF16/FP16 精度累加
└─ 输出激活: 存为 FP8 (带 scaling)

第 2 步:反向传播
├─ 梯度计算: 主要在 FP8
├─ 梯度累加: 用 BF16/FP16
└─ 梯度存储: BF16/FP16

第 3 步:参数更新
├─ 用 BF16/FP16 梯度
├─ 更新 BF16/FP16 主权重
├─ 更新 BF16/FP16 优化器状态
└─ 重新计算 scaling factors

循环...

关键观察

  1. FP8 只在计算密集的 matmul 中使用(为了利用 Tensor Core 加速)
  2. 主权重和优化器状态始终保持高精度(BF16/FP16)
  3. 累加操作使用高精度(避免舍入误差累积)
  4. 动态 scaling 帮助适应数值范围变化

5. FP8 vs BF16 训练对比

纯 BF16 训练流程

主权重: BF16
前向计算: BF16 × BF16 → BF16 累加
激活: BF16
梯度: BF16
优化器: BF16

特点:

  • 简单,数值稳定
  • 显存占用较大
  • Tensor Core 利用率较低(相比 FP8)

FP8 混合精度训练流程

主权重: BF16 (存储)
前向计算: FP8 × FP8 → BF16 累加
激活: FP8 (带 scaling)
梯度: BF16
优化器: BF16

特点:

  • 复杂,需要精心设计 scaling
  • 显存占用中等(激活省了,权重没省)
  • Tensor Core 吞吐量可以更高
  • 需要硬件支持(H100/B100 等)

6. 为什么 FP8 训练能成功?

总结一下,FP8 训练能成功的关键因素:

① 硬件支持

  • H100/B100 的 Tensor Core 原生支持 FP8
  • FP8 输入 + BF16 累加的硬件管线

② 算法设计

  • 混合精度策略:计算用低精度,存储和更新用高精度
  • 动态 scaling:适应训练过程中的数值分布变化
  • 延迟 scaling 更新:避免频繁重算 scale 因子

③ 软件生态

  • NVIDIA Transformer Engine
  • PyTorch/JAX 的原生支持
  • 自动混合精度(AMP)框架

④ 模型特性

  • 大模型对噪声容忍度高:参数多,单个参数的量化误差影响小
  • 权重分布相对集中:大部分值在一个合理范围内
  • 过参数化:模型容量足够大,有冗余来吸收误差

7. 实践建议

何时使用 FP8 训练?

推荐使用:

  • 有 H100/H200/B100 等支持 FP8 的硬件
  • 训练大型 Transformer 模型(LLM, Vision Transformer)
  • 需要最大化训练吞吐量
  • 显存受限但又想训大模型

谨慎使用:

  • 小模型(< 1B 参数):收益不明显
  • 对数值敏感的任务:如高精度科学计算
  • 硬件不支持 FP8:退化成模拟会很慢

最佳实践

  1. 渐进式采用:先用 BF16 训练验证,再切换到 FP8
  2. 监控数值健康:关注 loss spike、梯度范数
  3. 调优 scaling:根据实际分布调整 scale 策略
  4. 选择性应用:关键层(如 embedding)可以保持高精度

8. 总结

FP8 能用于训练不是因为单一原因,而是多个因素协同作用的结果:

因素作用
权重分布集中降低对精度的需求
浮点格式提供动态范围,适应多数量级
混合精度设计计算加速,存储/更新保精度
动态 scaling适应训练过程的分布变化
硬件支持高效的 FP8 Tensor Core
模型容忍度大模型可以吸收量化噪声

核心要点:

FP8 只是"计算路径上的低精度影分身",真正承载优化过程的是 BF16/FP16 的主权重。 整个混合精度系统的设计,才是 FP8 训练成功的关键。


参考资料