Summary: FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

  • Introduces FlashAttention-2, a new GPU kernel for exact (non-approximate) Transformer attention.

  • Targets long-context training/inference by keeping attention's extra memory at O(N) (no N × N score matrix materialised in HBM) while pushing throughput close to matrix-multiply (GEMM) efficiency.

  • Achieves up to 2–3 × speed-up over FlashAttention-v1 and ~10 × over naïve PyTorch, hitting ≈ 73 % of A100 peak FLOPs and 72 % model-level FLOPs utilisation in GPT training.

  • Defers the final soft-max rescaling and stores only the per-row log-sum-exp ⇒ far fewer scalar (non-matmul) FLOPs (see the online-softmax sketch at the end of this summary).

  • In addition to batch × heads, thread-blocks now also split the sequence-length axis, raising SM occupancy in long-sequence / small-batch regimes (see the parallelism sketch at the end of this summary).

  • Replaces “split-K” (each warp slices K/V) with “split-Q” (each warp slices Q) ⇒ almost zero inter-warp communication and shared-memory traffic (see the partitioning sketch at the end of this summary).

  • Benchmarks measure forward, backward, and forward+backward TFLOPs/s across sequence lengths L ∈ {512 … 16 k}, head dimensions {64, 128}, with and without a causal mask (see the FLOP-accounting sketch at the end of this summary).

  • Compared against standard PyTorch attention, the xFormers cutlass kernel, the FlashAttention-v1 CUDA kernel, and the Triton implementation of FlashAttention.

  • End-to-end training of GPT-style models (1.3 B & 2.7 B params, seq-len 2 k & 8 k) on 8× A100 GPUs.

  • Same kernels run unmodified on H100, showing further raw-throughput gains (up to ≈ 335 TFLOPs/s).

  • Current kernels target NVIDIA architectures; AMD/Intel GPUs and TPUs are not yet supported.

  • Four hand-chosen tile shapes; no auto-tuner provided.

  • No exploitation of H100-specific features (TMA, 4th-generation Tensor Cores, FP8) in the released code.

  • Focuses on dense attention; does not address algorithmic sparsity / locality that could extend context lengths further.

  • Port the kernels to H100 (exploiting TMA and FP8), to AMD ROCm GPUs, and to other accelerators.

  • Embed FlashAttention-2 in TVM/Triton autotuners so optimal block sizes and warp layouts are discovered automatically.

  • Fuse FlashAttention-2 kernels with block-sparse patterns to reach 100 k+ token context at high efficiency.

  • Evaluate FP8/BF8 or hybrid quantisation to trim memory bandwidth further without losing accuracy.

  • Integrate into vision, speech and multimodal Transformers to measure end-to-end gains beyond language modelling.

  • TMA (Tensor Memory Accelerator): A Hopper-generation hardware path that streams tiles between HBM and on-chip shared memory (SRAM) without per-thread load instructions, freeing the CUDA cores for compute and reducing latency.

  • 4th-generation Tensor Cores: The matrix-multiply units in H100 GPUs, which add FP8 support and deliver higher per-cycle throughput than A100’s third-generation units.
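
Online-softmax sketch. A minimal PyTorch sketch of the bookkeeping behind the deferred rescaling: a running max and running denominator are carried across K/V blocks, the output accumulator is divided by the denominator only once at the end of the loop, and only the per-row log-sum-exp (O(N) extra memory) is kept for a backward pass. This is an illustrative single-head loop, not the fused CUDA kernel; the real kernels also tile Q and keep the loop state on-chip.

```python
import torch

def blockwise_attention_forward(q, k, v, block_size=128):
    """Exact single-head attention computed over K/V blocks (illustrative)."""
    n, d = q.shape
    scale = d ** -0.5
    o = torch.zeros_like(q)                 # unnormalised output accumulator
    m = torch.full((n,), float("-inf"))     # running row-wise max of scores
    l = torch.zeros(n)                      # running softmax denominator

    for start in range(0, k.shape[0], block_size):
        kb = k[start:start + block_size]
        vb = v[start:start + block_size]
        s = (q @ kb.T) * scale              # scores for this K block
        m_new = torch.maximum(m, s.max(dim=-1).values)
        p = torch.exp(s - m_new[:, None])   # block-local exponentials
        alpha = torch.exp(m - m_new)        # rescale factor for the old state
        l = alpha * l + p.sum(dim=-1)
        o = alpha[:, None] * o + p @ vb     # note: no division by l here
        m = m_new

    o = o / l[:, None]                      # single final rescaling
    lse = m + torch.log(l)                  # per-row log-sum-exp, O(N) memory
    return o, lse

# Quick check against the naive O(N^2)-memory reference.
torch.manual_seed(0)
q, k, v = (torch.randn(256, 64) for _ in range(3))
o, lse = blockwise_attention_forward(q, k, v)
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(o, ref, atol=1e-5)
```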
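
Parallelism sketch. A back-of-the-envelope illustration of why adding the sequence axis to the launch grid matters in the long-sequence / small-batch regime; the Q-tile size of 128 is an illustrative assumption, not the released kernels' actual tile shape (A100 has 108 SMs).

```python
import math

def grid_size(batch, heads, seqlen, block_q=128, seq_parallel=True):
    """Count independent thread blocks in the forward-pass launch grid.

    v1-style: one block per (batch, head) pair.
    v2-style: additionally one block per Q tile of size block_q along
    the sequence axis (block_q = 128 is an illustrative choice).
    """
    seq_tiles = math.ceil(seqlen / block_q) if seq_parallel else 1
    return batch * heads * seq_tiles

# Long-sequence / small-batch regime: batch 1, 16 heads, 16k tokens.
A100_SMS = 108
old = grid_size(1, 16, 16384, seq_parallel=False)  # 16 blocks  -> most SMs idle
new = grid_size(1, 16, 16384, seq_parallel=True)   # 2048 blocks -> all SMs busy
print(f"batch*heads only : {old} blocks for {A100_SMS} SMs")
print(f"+ sequence axis  : {new} blocks for {A100_SMS} SMs")
```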
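
Partitioning sketch. A schematic NumPy simulation of the two warp layouts within one tile, with warps modelled as plain array slices (tile sizes and the warp count of 4 are arbitrary choices): under split-K the per-warp partial sums land in the same output rows and must be reduced across warps, whereas under split-Q each warp owns a disjoint set of rows and writes its result with no inter-warp communication.

```python
import numpy as np

def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
d = 32
Q = rng.standard_normal((64, d))    # one Q tile (64 rows)
K = rng.standard_normal((64, d))    # one K/V tile
V = rng.standard_normal((64, d))
scale = d ** -0.5
n_warps = 4

# "Split-K": every warp sees all of Q but only a slice of K/V.  The per-warp
# partial numerators/denominators cover the *same* output rows, so they must
# be reduced across warps (shared memory + synchronisation in a real kernel).
num = np.zeros_like(Q)
den = np.zeros(Q.shape[0])
for w in range(n_warps):
    Kw, Vw = K[w::n_warps], V[w::n_warps]
    e = np.exp(Q @ Kw.T * scale)    # unnormalised weights for this K/V slice
    num += e @ Vw                   # <-- cross-warp reduction
    den += e.sum(axis=-1)           # <-- cross-warp reduction
split_k_out = num / den[:, None]

# "Split-Q": every warp owns a disjoint slice of Q rows and loops over all of
# K/V, so it produces its own output rows with no inter-warp communication.
split_q_out = np.empty_like(Q)
rows_per_warp = Q.shape[0] // n_warps
for w in range(n_warps):
    rows = slice(w * rows_per_warp, (w + 1) * rows_per_warp)
    split_q_out[rows] = softmax(Q[rows] @ K.T * scale) @ V

ref = softmax(Q @ K.T * scale) @ V
assert np.allclose(split_k_out, ref) and np.allclose(split_q_out, ref)
```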
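
FLOP-accounting sketch. The arithmetic typically used to convert kernel timings into TFLOPs/s and utilisation figures; the conventions below (2 FLOPs per multiply-add, softmax FLOPs ignored, a causal mask halving the work, backward costed at about 2.5× forward) are common benchmarking assumptions rather than quotes from the released scripts, and the example timing is hypothetical.

```python
def attention_tflops_per_s(batch, heads, seqlen, head_dim, time_s,
                           causal=False, mode="fwd"):
    """Convert an attention-kernel timing into TFLOPs/s.

    Counts only the two big matmuls (Q K^T and P V), 2 FLOPs per
    multiply-add; a causal mask halves the work; the backward pass is
    costed at 2.5x the forward (five matmuls instead of two).
    """
    fwd = 4 * batch * heads * seqlen ** 2 * head_dim
    if causal:
        fwd /= 2
    total = {"fwd": fwd, "bwd": 2.5 * fwd, "fwd_bwd": 3.5 * fwd}[mode]
    return total / time_s / 1e12

# Hypothetical example: batch 2, 16 heads, seqlen 8192, head dim 128,
# forward pass in 4.8 ms  ->  ~229 TFLOPs/s, i.e. ~73% of the 312 TFLOPs/s
# FP16/BF16 tensor-core peak of an A100.
print(round(attention_tflops_per_s(2, 16, 8192, 128, 4.8e-3), 1))
```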