Attention isn't slow because of the math. It's slow because the GPU keeps writing a giant N×N score matrix to slow memory and reading it back. FlashAttention computes the same exact result (up to floating-point rounding), but never lets that matrix touch slow memory. It is faster, and its memory use grows linearly with sequence length instead of quadratically.
Standard attention materializes the whole N×N matrix in slow HBM, then reads it all back. FlashAttention slices the same computation into small tiles that fit in fast on-chip SRAM, fuses everything into one kernel, and only ever writes the final output. The slow round-trip simply never happens.
Self-attention builds a table: for each query token, how much should it attend to each key token? For sequence length N that's an N×N table — and that single fact is the source of all the pain.
Three steps: S = QKᵀ, P = softmax(S), O = PV. The first two each produce a full N×N matrix. Double the context and the table quadruples. A 16K-token sequence wants a matrix of about 268 million entries (16,384²), per head, per layer.
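The quadratic growth is easy to check with a few lines of arithmetic. This is a minimal sketch: the function names are made up for illustration, and the 2 bytes per entry assumes FP16/BF16 scores, which the text does not specify.

```python
def score_matrix_entries(n_tokens: int) -> int:
    """Number of entries in one N x N attention score matrix."""
    return n_tokens * n_tokens


def score_matrix_bytes(n_tokens: int, bytes_per_entry: int = 2) -> int:
    """Bytes for one N x N matrix; 2 bytes per entry assumes FP16/BF16."""
    return score_matrix_entries(n_tokens) * bytes_per_entry


# Doubling N quadruples the table, for every head in every layer.
for n in (4096, 8192, 16384):
    entries = score_matrix_entries(n)
    mib = score_matrix_bytes(n) / 2**20
    print(f"N={n:>6}: {entries:>12,} entries, {mib:>5.0f} MiB per head per layer")
```

Multiply the per-head figure by the number of heads and layers to see why materializing these matrices dominates memory at long context.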
The naive recipe builds that whole matrix in memory. FlashAttention's bet: you can compute the exact same O without ever holding S or P in full.
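For reference, here is what the naive recipe looks like when written out. It is a plain-Python sketch on nested lists, not how any library implements it; the `1/sqrt(d)` scaling is the conventional choice and is an assumption here, since the text does not mention it.

```python
import math


def naive_attention(Q, K, V):
    """Reference attention that materializes the full N x N matrices S and P."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)  # conventional scaling (assumption, not from the text)

    # Step 1: S = Q K^T, one score per (query, key) pair -> N x N entries.
    S = [[scale * sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]

    # Step 2: P = row-wise softmax(S), also N x N.
    P = []
    for row in S:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        P.append([e / total for e in exps])

    # Step 3: O = P V, back down to N x d_v.
    d_v = len(V[0])
    O = [[sum(p[j] * V[j][c] for j in range(len(V))) for c in range(d_v)] for p in P]
    return S, P, O


Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
S, P, O = naive_attention(Q, K, V)
print("S and P each hold", len(S) * len(S[0]), "entries; O holds", len(O) * len(O[0]))
```

Both S and P live in memory in full before O is produced, which is exactly the cost FlashAttention avoids.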
A modern GPU is compute-rich and bandwidth-poor. Its arithmetic units sit idle, waiting for the score matrix to crawl back from slow memory. Drag N and watch standard attention's HBM traffic explode quadratically, while FlashAttention's stays a small fraction of it.
On an A100, on-chip SRAM delivers ≈ 19 TB/s but totals only ~20 MB. The big HBM is 40–80 GB but only ~1.5–2.0 TB/s, roughly 10× slower. An elementwise step like softmax has almost no arithmetic per byte, so it's memory-bound: the data arrives long after the math could have finished.
Standard attention writes S to HBM, reads it back for the softmax, writes P, then reads P again to multiply by V. That's Θ(N²) HBM accesses of pure traffic. FlashAttention keeps S and P inside SRAM as tiles and persists only O(N) statistics, cutting HBM accesses to Θ(N²d²/M) for head dimension d and SRAM size M. FA1 also proves a matching lower bound: no exact-attention algorithm can do asymptotically better across the full range of SRAM sizes.
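To get a feel for the gap, the sketch below plugs numbers into the leading-order access counts from the FlashAttention paper's analysis: Θ(Nd + N²) for standard attention and Θ(N²d²/M) for FlashAttention, where d is the head dimension and M the SRAM size in elements. Constants are dropped, and the values of `D_HEAD` and `SRAM_ELEMS` are illustrative assumptions, so treat the ratio as an order-of-magnitude estimate only.

```python
def standard_hbm_accesses(n: int, d: int) -> float:
    """Leading-order HBM accesses for standard attention: N*d + N^2 (constants dropped)."""
    return float(n * d + n * n)


def flash_hbm_accesses(n: int, d: int, sram_elems: int) -> float:
    """Leading-order HBM accesses for FlashAttention: N^2 * d^2 / M (constants dropped)."""
    return n * n * d * d / sram_elems


D_HEAD = 64          # assumed head dimension
SRAM_ELEMS = 50_000  # assumed usable SRAM in elements (hypothetical figure)

for n in (1024, 4096, 16384):
    std = standard_hbm_accesses(n, D_HEAD)
    fa = flash_hbm_accesses(n, D_HEAD, SRAM_ELEMS)
    print(f"N={n:>6}: standard ~{std:.3g}, flash ~{fa:.3g}, ratio ~{std / fa:.1f}x")
```

The saving comes from d² being much smaller than M: both counts grow with N², but FlashAttention's is scaled down by roughly M/d².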
FlashAttention chops Q into row-blocks and K, V into column-blocks sized to fit in SRAM. Press Step → and walk the loop: load a tile, score it, update a running output — then drop it. The grey N×N grid below is never fully held in memory at once.
Each query row carries a tiny running state — its output so far plus two scalars (max mᵢ, sum ℓᵢ). Tiles arrive, contribute, and are discarded. The full N×N matrix exists only as a sequence of fleeting SRAM tiles — never all at once.
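The loop can be sketched in plain Python. This is an illustration of the idea, not the real kernel: the block sizes, function names, and `1/sqrt(d)` scaling are assumptions, and the loop order (query blocks outside, key/value blocks inside) is chosen for readability. Only one small score tile exists at any moment; each row keeps its running max `m`, running sum `l`, and unnormalized output `acc`.

```python
import math
import random


def reference_attention(Q, K, V):
    """Full-row attention, used only to check the tiled version."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        total = sum(p)
        out.append([sum(pj * v[c] for pj, v in zip(p, V)) / total
                    for c in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block_q=2, block_k=2):
    """Attention computed tile by tile; no N x N matrix is ever stored."""
    n_q, n_k, d_v = len(Q), len(K), len(V[0])
    scale = 1.0 / math.sqrt(len(Q[0]))
    O = [None] * n_q
    for q0 in range(0, n_q, block_q):                 # one row-block of Q
        rows = range(q0, min(q0 + block_q, n_q))
        m = {i: -math.inf for i in rows}              # running row max
        l = {i: 0.0 for i in rows}                    # running sum of exponentials
        acc = {i: [0.0] * d_v for i in rows}          # running unnormalized output
        for k0 in range(0, n_k, block_k):             # one column-block of K, V
            cols = range(k0, min(k0 + block_k, n_k))
            for i in rows:
                # Score this row against the current key block only.
                s = [scale * sum(a * b for a, b in zip(Q[i], K[j])) for j in cols]
                m_new = max(m[i], max(s))
                fix = math.exp(m[i] - m_new)          # rescale old state; 0.0 on the first tile
                p = [math.exp(x - m_new) for x in s]
                l[i] = l[i] * fix + sum(p)
                acc[i] = [acc[i][c] * fix + sum(pj * V[j][c] for pj, j in zip(p, cols))
                          for c in range(d_v)]
                m[i] = m_new
            # The tile (s, p) goes out of scope here and is discarded.
        for i in rows:
            O[i] = [a / l[i] for a in acc[i]]         # normalize once at the end
    return O


random.seed(0)
Q = [[random.gauss(0, 1) for _ in range(4)] for _ in range(5)]
K = [[random.gauss(0, 1) for _ in range(4)] for _ in range(5)]
V = [[random.gauss(0, 1) for _ in range(3)] for _ in range(5)]
tiled, full = tiled_attention(Q, K, V), reference_attention(Q, K, V)
worst = max(abs(a - b) for r1, r2 in zip(tiled, full) for a, b in zip(r1, r2))
print("largest difference vs. full-matrix attention:", worst)
```

The two results agree to within floating-point rounding, even though the sequence length (5) is not a multiple of the block size.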
Softmax needs the max and the sum-of-exponentials over the whole row — but tiling only shows you one block at a time. The online softmax fixes this: keep a running max and running sum, and rescale whenever a bigger value appears. Step through the blocks and watch the running answer stay correct.
The magic is the rescale factor exp(mᵢ − m), where mᵢ is the old running max and m is the new one. When a block brings a larger score, every partial sum and partial output so far is corrected to reference the new maximum. Nothing is approximated: the order of summation changes, the answer does not.
That single property — streaming softmax == full-row softmax — is exactly why FlashAttention can claim the word exact. Watch the two readouts above stay locked together as you feed blocks.
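The streaming update for a single row fits in a few lines. In this sketch the function names and the example scores are made up for illustration; the running state is just the two scalars described above, the max `m` and the sum of exponentials `l`.

```python
import math


def online_softmax_stats(blocks):
    """Stream over blocks of scores, keeping only a running max and running sum."""
    m, l = -math.inf, 0.0
    for block in blocks:
        m_new = max(m, max(block))
        # Rescale the old sum to the new max, then add this block's terms.
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return m, l


def full_softmax_stats(row):
    """The same two statistics computed with the whole row in view."""
    m = max(row)
    return m, sum(math.exp(x - m) for x in row)


blocks = [[1.0, 3.0], [2.0, 10.0], [-1.0]]  # the max jumps in the second block
row = [x for block in blocks for x in block]

m_stream, l_stream = online_softmax_stats(blocks)
m_full, l_full = full_softmax_stats(row)
print("max agrees:", m_stream == m_full)
print("sum agrees:", math.isclose(l_stream, l_full, rel_tol=1e-12))
```

Any softmax probability then follows as `exp(x - m) / l`, so matching `m` and `l` means the streamed softmax matches the full-row one.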
Training needs gradients, which normally want the stored N×N probabilities P. Keeping P would bring back O(N²) memory. So FlashAttention throws P away and recomputes each tile during backprop — persisting only one scalar per row. Toggle the data-flow and watch the slow HBM round-trip disappear.
The forward pass stores just the log-sum-exp per row. From it, any probability tile P=exp(S−L) is reconstructed on demand in SRAM.
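A small sketch of that reconstruction, with hypothetical function names and made-up scores: the forward pass keeps one log-sum-exp value L per row, and the backward pass rebuilds any tile of P from freshly recomputed scores and those stored values.

```python
import math


def row_logsumexp(scores):
    """L = log(sum(exp(s))) for one row, computed stably via the row max."""
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))


def recompute_prob_tile(score_tile, lse_rows):
    """Rebuild a tile of P = exp(S - L) from a score tile and per-row L values."""
    return [[math.exp(s - L) for s in row] for row, L in zip(score_tile, lse_rows)]


# Made-up score rows standing in for S.
S = [[0.2, 1.5, -0.7, 3.0],
     [2.1, 0.0, 0.4, -1.2]]

# Forward pass: persist one scalar per row.
L = [row_logsumexp(row) for row in S]

# Backward pass: recompute only the tile covering key columns 1..2.
tile = recompute_prob_tile([row[1:3] for row in S], L)

# Check against the full softmax of each row.
full_P = [[math.exp(s - row_logsumexp(row)) for s in row] for row in S]
print("tile matches full softmax:",
      all(math.isclose(t, f) for trow, frow in zip(tile, full_P)
          for t, f in zip(trow, frow[1:3])))
```

The stored state is one number per row, while the tile itself is rebuilt and thrown away each time it is needed.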
You pay a few extra FLOPs recomputing tiles. Because attention is memory-bound, that's still a net speedup — the saved HBM traffic dwarfs the redundant math.
Score → softmax → ×V are fused into a single GPU kernel. No intermediate S or P tensor is ever written out to HBM — there's nothing to round-trip.
FA1 invented the algorithm. FA2 and FA3 are increasingly hardware-specific re-implementations of the same algorithm, each squeezing closer to the GPU's raw matmul ceiling.
| Version | Year · GPU | Key engineering move | Utilisation (reported) |
|---|---|---|---|
| FlashAttention-1 | 2022 · A100 (Ampere) | tiling + online softmax + recomputation + the optimality proof | ~25–40% of peak |
| FlashAttention-2 | 2023 · A100 (Ampere) | fewer non-matmul FLOPs · parallelise over sequence length · better warp partitioning | ~50–73% (~225 TFLOP/s) |
| FlashAttention-3 | 2024 · H100 (Hopper) | warp-specialisation (TMA) · ping-pong matmul/softmax overlap · FP8 path | ~75% (~740 TFLOP/s); ~1.2 PFLOP/s FP8 |
FA1: 3× faster GPT-2 training at 1K context, 15% faster BERT-large training than the MLPerf 1.1 record, and the first Transformer to beat chance on Path-X (16K). FA2 roughly doubled FA1's speed; FA3 added another 1.5–2× on H100. Every modern context-window expansion stands on this.
Note the trend: FA2 and FA3 don't change what is computed. They re-shape how the same tiled, online-softmax algorithm maps onto a specific GPU's matmul units, async copies, and precision modes. The algorithm is portable — the peak performance is not.
FlashAttention does not reduce asymptotic compute: attention is still Θ(N²) FLOPs; only the quadratic memory is gone. Single-token decode has just one query row, so plain FA underutilises the GPU (hence FlashDecoding). And FA3's FP8 path trades a little accuracy for speed, so validate per model.
Complementary, not rivals: FlashAttention makes exact O(N²) attention cheap; sparse/linear attention and state-space models change the math to escape O(N²). Long-context research uses both.
FlashAttention is the canonical worked example of a deeper principle: design your algorithm around the memory hierarchy, not the FLOP count.
For memory-bound work, the right cost model is data movement, not arithmetic. FlashAttention took the bandwidth gap seriously and won — with a matching lower-bound proof to back it.
Approximate attention changes the output to save compute. FlashAttention saves far more and keeps the answer exact. Exactness is a feature you give up only when forced to.
Alice is a consumer of FlashAttention via her providers — long system prompts and worker-mode drafting are economically possible because IO-aware kernels run underneath the stack.
The N×N cost. The memory wall. Tiling, the online-softmax trick that keeps it exact, recomputation in the backward pass, and the FA1→FA3 ladder. Same math, same answer — but the giant matrix never touches slow memory. You now see why long context became affordable.
Go back to the operation FlashAttention makes fast — the QKᵀ table, softmax, and the values it weights. Now you'll read it knowing exactly where the time was going.