Skip to main content

FlashAttention 技术详解

1. 概述:它解决了什么问题?

FlashAttention 是一种革命性的算法,旨在优化 Transformer 模型中核心的注意力机制(Self-Attention)的计算过程。

它是当前大语言模型(LLM)技术栈中的基石。如果没有它,在现有硬件条件下,训练和部署支持**超长上下文窗口(Long Context,如 128k)**的模型几乎是不可能的。

一句话核心: FlashAttention 通过一种**"IO感知(IO-Aware)"**的方法,极大地减少了 GPU 显存的读写次数,使得注意力计算的速度提升了 2-4 倍,同时将显存占用从平方级爆炸降低为线性增长。

作者与版本演进

FlashAttention 由 Tri Dao(斯坦福大学)主导开发,合作者包括 Daniel Y. Fu、Stefano Ermon、Atri Rudra、Christopher Ré 等人。

版本发布时间主要改进
FlashAttention v12022 年 5 月首次提出 IO-Aware 的分块算法,实现 2-4x 加速
FlashAttention-22023 年 7 月优化并行策略和工作分配,速度提升约 2x(相比 v1)
FlashAttention-32024 年 7 月针对 Hopper GPU(H100)优化,利用异步执行和 FP8 支持

2. 背景痛点:标准 Attention 的瓶颈

要理解 FlashAttention,首先必须理解标准 Attention 慢在哪里。

2.1. 二次方复杂度瓶颈 (O(N2)O(N^2))

标准的 Attention 计算公式涉及一个巨大的中间矩阵:

Attention(Q,K,V)=Softmax(QKTdk)VAttention(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

其中,QQKTK^T 相乘会生成一个形状为 [N,N][N, N] 的注意力分数矩阵(NN 为序列长度)。

  • NN 较小时(如 512),这个问题不明显。
  • NN 增大时(如 32k),这个矩阵的元素数量呈平方级爆炸(32000210亿32000^2 \approx 10亿 个元素)。存储这个矩阵会瞬间耗尽 GPU 的显存(OOM)。

2.2. 真正的瓶颈:显存墙(Memory Wall)

在实际硬件运行中,最大的瓶颈往往不是计算速度(FLOPs),而是数据搬运速度。GPU 的存储结构分为两层:

  1. HBM (High Bandwidth Memory): 也就是常说的显存(如 24GB/80GB)。容量大,但读写速度相对较慢。
  2. SRAM (Static RAM): 位于计算单元旁边的高速缓存。读写极快,但容量极小(每组单元仅几百 KB)。

标准 Attention 的笨拙之处:

它会频繁地在 HBM 和 SRAM 之间搬运那个巨大的 N×NN \times N 矩阵。计算单元大部分时间都在"等待数据从 HBM 搬过来",而不是在计算。这被称为**"内存受限(Memory Bound)"**。


3. FlashAttention 的核心机制:分块计算 (Tiling)

FlashAttention 的核心思想是:IO感知(IO-Aware)。即算法的设计充分考虑了 HBM 和 SRAM 的速度差异,目标是尽可能让计算在快速的 SRAM 中完成,最大限度减少对慢速 HBM 的访问。

它通过 Tiling(分块) 技术实现这一目标。

3.1. Tiling 的具体做法

FlashAttention 不会一次性计算整个 N×NN \times N 矩阵。相反,它将输入的 Q,K,VQ, K, V 矩阵切分成许多小块(Blocks)。

  • 切分方向: 沿着序列长度 NN 的方向切分,但保留完整的特征维度 dd
  • 加载: 将一小块 QQ (Block QiQ_i) 和一小块 KK (Block KjK_j) 从 HBM 加载到极快的 SRAM 中。

3.2. 核心流程(在 SRAM 内完成)

在 SRAM 内部,进行以下"流水线"操作:

  1. 局部矩阵乘法: 计算这一小块的 Qi×KjTQ_i \times K_j^T。由于特征维度 dd 是完整的,计算出的这一小块 Nblock×NblockN_{block} \times N_{block} 的分数是准确的。
  2. 局部 Softmax: 对这小块分数进行局部的 Softmax 计算。
  3. 立刻相乘: 加载对应的一小块 VjV_j 到 SRAM,立刻与刚才的结果相乘。
  4. 写回与丢弃: 将计算出的部分结果累加到最终输出中,然后直接丢弃 SRAM 中的中间计算结果。

关键点: 那个巨大的 N×NN \times N 矩阵确实被计算出来了,但它是**"分批次在 SRAM 中产生,用完即焚"**,从未完整地存在于慢速的 HBM 显存中。


4. 核心技术难点与突破:Online Softmax

Tiling 策略面临一个巨大的数学挑战:Softmax 的全局依赖性

Softmax 的公式是 exiexj\frac{e^{x_i}}{\sum e^{x_j}}。这意味着,要算出任何一个 token 的最终概率,你需要知道它相对于整行所有 token 的分数之和(分母)。如果只是分块计算,分母是不全的。

解决方案:Online Softmax(在线 Softmax)

FlashAttention 采用了一种"边算边修正"的技巧。

  1. 先算局部: 在处理第一块 KK 时,基于目前见到的最大值和总和,算一个"临时的"Softmax 结果。
  2. 动态修正(Rescaling): 当处理后续的 KK 块时,如果发现了更大的分数,或者累加了更多的分母项,就利用数学公式回过头来**按比例缩放(Rescale)**之前计算的临时结果。
  3. 结果精确: 当遍历完所有的块,最终得到的结果在数学上与标准 Attention 一次性算出的结果是**完全一致(Exact)**的,并非近似解。

5. PyTorch 集成:SDPA(推荐方式)

PyTorch 2.0 引入了统一的 API:torch.nn.functional.scaled_dot_product_attention(简称 SDPA)。

这个 API 会自动选择当前硬件支持的最快内核(FlashAttention V2、Memory-Efficient Attention 或标准 C++ 实现)。你不需要显式导入 FlashAttention 库。

5.1. 迁移前:标准的慢速 Attention 实现

import torch
import torch.nn.functional as F
import math

def standard_attention(q, k, v, mask=None, dropout_p=0.0):
# q, k, v shape: (B, H, N, D)
d_k = q.size(-1)

# 1. 计算 QK^T 并缩放
# 结果 shape: (B, H, N, N)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

# 2. 应用 Mask (如果有)
if mask is not None:
# mask 通常是加性 mask,需要把不想关注的地方设为 -inf
scores = scores.masked_fill(mask == 0, float('-inf'))

# 3. Softmax
attn_weights = F.softmax(scores, dim=-1)

# 4. Dropout
if dropout_p > 0.0:
attn_weights = F.dropout(attn_weights, p=dropout_p)

# 5. 乘 V
# 结果 shape: (B, H, N, D)
output = torch.matmul(attn_weights, v)
return output

5.2. 迁移后:使用 PyTorch SDPA(一行代码)

FlashAttention 的复杂性全被封装在这一行代码里了:

import torch
import torch.nn.functional as F

# 确保你的输入是 fp16 或 bf16,并且在 GPU 上
# q, k, v = ... (move to cuda and cast to float16/bfloat16)

# 迁移后的代码
output = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask, # 可选,传递你的 mask
dropout_p=dropout_p, # 可选,dropout 概率
is_causal=False # 如果是 GPT 类的解码模型,这里设为 True
)

5.3. 使用注意事项

  • 数据类型: 输入应为 float16bfloat16 以获得最佳性能
  • 设备: 输入张量需要在 GPU 上
  • 因果注意力: 对于 GPT 类解码模型,设置 is_causal=True 可自动应用因果 mask
  • 自动后端选择: PyTorch 会自动选择最优的后端实现

6. 显存节省分析

FlashAttention 最显著的优势之一是大幅降低显存占用。以下是具体的分析。

6.1. 显存占用对比公式

假设输入序列长度为 NN,特征维度为 dd,batch size 为 BB,注意力头数为 HH

方法注意力矩阵显存复杂度
标准 AttentionB×H×N×NB \times H \times N \times NO(N2)O(N^2)
FlashAttention不存储完整矩阵O(N)O(N)

6.2. 具体数值示例

float16(2 bytes) 精度为例,单个注意力头的显存占用:

序列长度 NN标准 AttentionFlashAttention节省比例
1K (1,024)2 MB~几 KB~99%
4K (4,096)32 MB~几 KB~99.9%
16K (16,384)512 MB~几 KB~99.99%
64K (65,536)8 GB~几 KB~99.999%
128K (131,072)32 GB~几 KBOOM → 可运行

计算方式: 标准 Attention 显存 = N2×2N^2 \times 2 bytes 例如:163842×2=536,870,91216384^2 \times 2 = 536,870,912 bytes 512\approx 512 MB

6.3. 为什么能节省这么多?

标准 Attention 必须存储的中间结果:

  1. QKTQK^T 矩阵:[B,H,N,N][B, H, N, N] — 这是显存爆炸的元凶
  2. Softmax 输出:[B,H,N,N][B, H, N, N] — 同样巨大
  3. Dropout mask(如果有):[B,H,N,N][B, H, N, N]

FlashAttention 的做法:

  • 分块计算,每次只在 SRAM 中处理一小块(如 128×128128 \times 128
  • 计算完立即丢弃,不写回 HBM
  • 只存储最终输出 [B,H,N,d][B, H, N, d] 和用于反向传播的少量统计信息(softmax 归一化因子)

6.4. 实际训练中的影响

在训练 GPT 类模型时,显存节省带来的实际好处:

场景标准 AttentionFlashAttention
GPT-2 (1.5B), seq=1024可运行可运行,更快
GPT-3 (175B), seq=2048需要模型并行减少并行需求
LLaMA-2, seq=4096显存紧张轻松运行
Claude/GPT-4, seq=128K不可能成为可能

7. 总结:FlashAttention 的优势

特性标准 AttentionFlashAttention优势说明
显存占用O(N2)O(N^2) (二次方)O(N)O(N) (线性)序列长度翻倍,显存只增加一点点,不再爆炸
运行速度较慢 (受限于显存带宽)快 2-4 倍 (受限于计算能力)充分利用了 Tensor Cores 的算力,减少了等待时间
计算精度精确精确 (Exact)巧妙的数学处理保证了结果无损,不是近似计算
支持能力仅短序列 (如 < 4k)超长序列 (如 > 128k)使得大模型处理长文档、长对话成为可能

结论:

FlashAttention 并没有改变 Attention 的数学本质,而是通过极致的系统级优化(Tiling 和 Online Softmax),解决了硬件存储层级带来的效率瓶颈。它是算法与硬件完美结合的典范。


8. 参考资料