Summary: FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
0. Materials
1. What is the paper about?
Introduces FlashAttention-2, a new GPU kernel for exact (non-approximate) Transformer attention.
Targets long-context training/inference by avoiding materialisation of the N×N attention matrix in HBM (extra memory is O(N) in sequence length) while pushing throughput close to matrix-multiply (GEMM) efficiency.
Achieves up to a 2–3× speed-up over FlashAttention-v1 and ~10× over naïve PyTorch attention, reaching ≈73% of A100 peak FLOPs/s in kernel benchmarks and 72% model-level FLOPs utilisation in end-to-end GPT training.
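To make the memory argument concrete, here is a back-of-the-envelope sketch (my own arithmetic, not a figure from the paper) comparing the size of a materialised N×N score matrix per head against the O(N) running statistics a FlashAttention-style kernel keeps:

```python
# Back-of-the-envelope memory arithmetic (not from the paper): size of the
# materialised N x N attention score matrix vs. the O(N) per-row statistics
# (one log-sum-exp scalar per query row) that FlashAttention-style kernels keep.

def score_matrix_bytes(seq_len: int, bytes_per_el: int = 2) -> int:
    """Full N x N score matrix for one head, e.g. in BF16 (2 bytes/element)."""
    return seq_len * seq_len * bytes_per_el

def lse_bytes(seq_len: int, bytes_per_el: int = 4) -> int:
    """One FP32 log-sum-exp value per query row."""
    return seq_len * bytes_per_el

for n in (2_048, 8_192, 16_384):
    print(f"N={n:6d}  scores={score_matrix_bytes(n)/2**20:8.1f} MiB/head  "
          f"lse={lse_bytes(n)/2**10:6.1f} KiB/head")
```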
2. What is new compared to prior work?
Defers the final soft-max rescaling and stores only the per-row log-sum-exp ⇒ far fewer scalar (non-matmul) FLOPs.
In addition to batch × heads, thread-blocks now also split the sequence-length axis, raising SM occupancy in long-sequence / small-batch regimes.
Replaces “split-K” (each warp slices K/V) with “split-Q” (each warp slices Q) ⇒ almost zero inter-warp communication and shared-memory traffic.
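A minimal NumPy sketch of these ideas, assuming my reading of the algorithm rather than the released CUDA kernels: the running output is rescaled only when the row maximum changes, the division by the softmax denominator is deferred to a single step at the end, and only the per-row log-sum-exp is written out for the backward pass.

```python
import numpy as np

def flash_attn_fwd_block(q, K, V, block_k=128):
    """Online-softmax forward pass for one block of queries q (Bq x d)
    against full K, V (N x d), processed in K/V tiles of size block_k.
    Illustrative NumPy sketch, not the released CUDA kernel."""
    Bq, d = q.shape
    scale = 1.0 / np.sqrt(d)
    o = np.zeros((Bq, d))                 # unnormalised running output
    m = np.full(Bq, -np.inf)              # running row max of scores
    l = np.zeros(Bq)                      # running softmax denominator
    for start in range(0, K.shape[0], block_k):
        Kj = K[start:start + block_k]
        Vj = V[start:start + block_k]
        s = (q @ Kj.T) * scale            # Bq x block_k score tile
        m_new = np.maximum(m, s.max(axis=1))
        p = np.exp(s - m_new[:, None])    # unnormalised tile probabilities
        alpha = np.exp(m - m_new)         # rescale factor for old state
        l = alpha * l + p.sum(axis=1)
        o = alpha[:, None] * o + p @ Vj   # defer division by l to the end
        m = m_new
    o = o / l[:, None]                    # single final rescaling
    lse = m + np.log(l)                   # only value saved for the backward pass
    return o, lse

# Sanity check against a reference softmax(QK^T / sqrt(d)) V
rng = np.random.default_rng(0)
q = rng.standard_normal((64, 64))
K = rng.standard_normal((1024, 64))
V = rng.standard_normal((1024, 64))
s = q @ K.T / np.sqrt(64)
ref = (np.exp(s - s.max(1, keepdims=True)) /
       np.exp(s - s.max(1, keepdims=True)).sum(1, keepdims=True)) @ V
o, _ = flash_attn_fwd_block(q, K, V)
assert np.allclose(o, ref, atol=1e-6)
```

Loosely, the split-Q partitioning corresponds to each warp owning a slice of q in this loop: the Kj/Vj tiles are read by all warps, but no cross-warp reduction of o is needed.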
3. What experiments were run to support the arguments in this paper?
Measured forward, backward, and forward+backward TFLOPs/s across sequence lengths L ∈ {512 … 16 k}, head dims {64, 128}, with and without causal mask (a FLOPs-accounting sketch follows this list).
Compared against standard PyTorch attention, xFormers (CUTLASS), FlashAttention-v1 CUDA, and FlashAttention-Triton.
End-to-end training of GPT-style models (1.3 B & 2.7 B params, seq-len 2 k & 8 k) on 8× A100 GPUs, reporting per-GPU training throughput.
Same kernels run unmodified on H100, showing further raw-throughput gains (up to ≈ 335 TFLOPs/s).
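For context on how throughput figures like “≈73% of peak” are usually derived, here is a hedged sketch of matmul-only FLOP accounting for attention; the exact convention (forward ≈ 4·B·H·N²·d FLOPs, halved under a causal mask, backward ≈ 2.5× forward) is my assumption about common benchmarking practice, not quoted from the paper.

```python
def attn_flops(batch, heads, seq_len, head_dim, causal=False, mode="fwd"):
    """Matmul-only FLOP count for scaled dot-product attention.
    Forward = 2 matmuls (QK^T and PV), 2 FLOPs per multiply-add.
    Assumed convention: backward ~ 2.5x forward (5 matmuls of the same shape)."""
    fwd = 4 * batch * heads * seq_len * seq_len * head_dim
    if causal:
        fwd //= 2                  # roughly half the score tiles are computed
    factor = {"fwd": 1.0, "bwd": 2.5, "fwd_bwd": 3.5}[mode]
    return fwd * factor

# Example: convert a measured runtime into achieved TFLOPs/s on an A100
# (312 TFLOPs/s is the published dense BF16 Tensor-Core peak; the runtime
# below is a made-up, hypothetical measurement).
flops = attn_flops(batch=8, heads=16, seq_len=8192, head_dim=128, mode="fwd")
runtime_s = 20e-3
print(f"achieved: {flops / runtime_s / 1e12:.1f} TFLOPs/s "
      f"({flops / runtime_s / 312e12:.0%} of A100 BF16 peak)")
```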
4. What are the shortcomings/limitations of this paper?
Current kernels target NVIDIA architectures; AMD/Intel GPUs and TPUs are not yet supported.
Four hand-chosen tile shapes; no auto-tuner provided.
No exploitation of H100-specific features (TMA, fourth-generation Tensor Cores, FP8) in the released code.
Focuses on dense attention; does not address algorithmic sparsity / locality that could extend context lengths further.
5. What is a reasonable next step to build upon this paper?
Port the kernels to H100 (TMA, FP8), AMD ROCm, and other accelerators.
Embed FlashAttention-2 in TVM/Triton autotuners so optimal block sizes and warp layouts are discovered automatically.
Fuse FlashAttention-2 kernels with block-sparse patterns to reach 100 k+ token context at high efficiency.
Evaluate FP8/BF8 or hybrid quantisation to trim memory bandwidth further without losing accuracy.
Integrate into vision, speech and multimodal Transformers to measure end-to-end gains beyond language modelling.
Appendix
- TMA (Tensor Memory Accelerator): A Hopper-generation hardware unit that asynchronously copies tiles between HBM and shared memory (SRAM) without per-thread load instructions, freeing the SM's threads for compute and hiding latency.
- 4th-gen Tensor Cores: The fourth-generation NVIDIA matrix-multiply units in H100 GPUs, which add FP8 support and higher per-cycle throughput compared with A100’s third-generation units.