目录

Summary: FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

本博客使用o3翻译,如有冲突请优先参考英文原文

  • 提出了 FlashAttention-2 —— 一种面向 Transformer 精确(非近似)注意力的新 GPU 内核。
  • 通过将显存访问复杂度降至 O(N),同时将吞吐率逼近矩阵乘 (GEMM) 的效率,从而瞄准长上下文训练/推理场景。
  • 相较 FlashAttention-v1 提速 2–3 ×,相较朴素 PyTorch 提速约 10 ×;在 GPT 训练中可达到 ≈ A100 峰值 FLOPs 的 73 %,模型级 FLOPs 利用率达 72 %。
  • 将最终 soft-max 重标定延后,仅存储每行的 log-sum-exp ⇒ 大幅减少标量(非 matmul)FLOPs。
  • 在线程块除了 batch × heads 维度外,新增对 序列长度轴 的划分 ⇒ 在长序列/小 batch 场景下提高 SM 占用率。
  • “split-Q” 取代 “split-K”(每个 warp 切分 Q 而非 K/V) ⇒ 近乎消除 warp 间通信与共享内存流量。
  • L ∈ {512…16 k}、head dim {64, 128}、是否因果掩码等条件下,测量前向、反向及前+反向 TFLOPs/s。
  • 与 PyTorch 标准实现、xFormers-cutlass、FlashAttention-v1 CUDA 及 FlashAttention-Triton 进行对比。
  • 在 8× A100 上端到端训练 GPT-风格模型(1.3 B & 2.7 B 参数,seq-len 2 k & 8 k)。
  • 同一内核 无需修改即可运行于 H100,裸算力最高达 ≈ 335 TFLOPs/s。
  • 目前仅支持 NVIDIA GPU;尚未适配 AMD/Intel GPU 及 TPU。
  • 仅提供四种手动挑选的 tile 形状;缺乏自动调优器。
  • 公开代码尚未利用 H100 新特性(TMA、第四代 Tensor Core、FP8)。
  • 关注点仍是稠密注意力;未触及稀疏/局部性算法以进一步扩展上下文长度。
  • 基于 H100 的 TMA/FP8、AMD ROCm 及其他加速器移植内核。
  • 将 FlashAttention-2 融入 TVM/Triton 自动调优 框架,让最优 block 大小与 warp 布局自动发现。
  • 与块稀疏模式融合,目标在高效率下处理 100 k+ token。
  • 评估 FP8/BF8 或混合 量化 以在不损失精度的前提下进一步压缩带宽需求。
  • 集成至视觉、语音与多模态 Transformer,验证语言建模之外的端到端收益。
  • TMA (Tensor Memory Accelerator):Hopper 架构新增硬件通路,可在 HBM 与寄存器/SRAM 间流式传输 tile,无需显式加载,减轻 CUDA 核心负担并降低延迟。
  • 第四代 Tensor Core:H100 上的矩阵乘单元,新增 FP8/BF16 支持,且每周期吞吐高于 A100 的第三代单元。