openalicelabs / academy
COURSE ARCH-01 LESSON 01 · 06 TOPIC FLASHATTENTION EST. READ ~12 MIN
OPENALICE LABORATORIES · EDUCATION PATH · ARCHITECTURE 01 · 06

Same answer. Far less memory.

Attention isn't slow because of the math. It's slow because the GPU keeps writing a giant N×N score matrix to slow memory and reading it back. FlashAttention computes exactly the same result (identical math, differing only by floating-point rounding) but never lets that matrix touch slow memory. It is faster, and its extra memory grows linearly with sequence length instead of quadratically.

FIG.00 — WHERE THE TIME GOES
FIG.0A — THE ONE IDEA · don't change the math, change WHERE it happens

Standard attention materializes the whole N×N matrix in slow HBM, then reads it all back. FlashAttention slices the same computation into small tiles that fit in fast on-chip SRAM, fuses everything into one kernel, and writes back only the final output plus a couple of numbers per row. The slow round-trip of the N×N matrix simply never happens.

INPUT: Q, K, V, each an N×d matrix (length × head-dim)
OUTPUT: O = softmax(QKᵀ)V, the same result as standard attention
KIND: IO-aware · EXACT (not an approximation)
MEMORY: O(N) softmax statistics instead of an O(N²) score matrix
FROM: Dao et al. 2022; now used in PyTorch, vLLM and Hugging Face Transformers
01 / 07
What attention actually computes

Every token looks at every token.

Self-attention builds a table: for each query token, how much should it attend to each key token? For sequence length N that's an N×N table — and that single fact is the source of all the pain.

FIG.01 — THE N×N SCORE MATRIX · drag length N
Example readout at N = 8 (head-dim d = 64): 64 score-matrix entries (N²), 512 output entries (N·d).
S = Q Kᵀ         // N×N raw scores
P = softmax(S)   // N×N, row-wise
O = P V          // N×d output

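Three steps, and all three touch the full N×N matrix: the first writes S, the second turns it into P, the third reads P back. Double the context and the table quadruples. A 16K-token sequence (16,384 tokens) wants a matrix of about 268 million entries, or 512 MiB in fp16, per head, per layer.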

The naive recipe builds that whole matrix in memory. FlashAttention's bet: you can compute the exact same O without ever holding S or P in full.
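To make the N² concrete, here is a minimal pure-Python sketch of the naive recipe. It assumes nested lists as matrices and omits the usual 1/√d scale to match the lesson's O = softmax(QKᵀ)V; the helper names (`matmul`, `standard_attention`, and so on) are illustrative, not from any library.

```python
import math
import random


def matmul(a, b):
    """Nested-list matrix product: (n x k) times (k x m) gives (n x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax_row(row):
    """Numerically stable softmax of one row: subtract the row max before exp."""
    mx = max(row)
    exps = [math.exp(x - mx) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def standard_attention(Q, K, V):
    """Naive attention that materialises S and P in full (1/sqrt(d) scale omitted)."""
    S = matmul(Q, transpose(K))        # N x N raw scores, all held at once
    P = [softmax_row(r) for r in S]    # N x N probabilities, all held at once
    O = matmul(P, V)                   # N x d output
    return O, S, P


if __name__ == "__main__":
    random.seed(0)
    N, d = 8, 64
    Q, K, V = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(N)] for _ in range(3))
    O, S, P = standard_attention(Q, K, V)
    print(len(S) * len(S[0]), len(O) * len(O[0]))   # 64 512, the FIG.01 readout
    print(16384 ** 2 * 2 / 2**20, "MiB")            # fp16 S for one 16K head: 512.0 MiB
```

Both S and P exist in full before O is produced; that is exactly the memory the rest of the lesson removes.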

02 / 07
The bandwidth gap · why it's slow

The bottleneck is memory traffic, not math.

A modern GPU is compute-rich and bandwidth-poor. Its arithmetic units sit idle, waiting for the score matrix to crawl back from slow memory. Drag N and watch the N×N matrix that standard attention parks in HBM explode quadratically, while the per-row statistics FlashAttention keeps there grow only linearly.

FIG.02 — HBM FOOTPRINT · STANDARD (N×N score matrix) vs FLASH (O(N) per-row stats) · slider N = 512 to 2,048
THE NUMBERS (NVIDIA A100)

On-chip SRAM runs at ≈19 TB/s but holds only ~20 MB in total (192 KB on each of 108 SMs). The big HBM is 40–80 GB but delivers only ~1.5–2.0 TB/s, roughly 10× slower. An elementwise step like softmax does almost no arithmetic per byte, so it's memory-bound: the data arrives long after the math could have finished.
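A back-of-envelope roofline check shows why softmax is memory-bound. Treat it as a sketch, not a profile: the bandwidth is the low end quoted above, the 19.5 TFLOP/s FP32 figure is the A100's published non-tensor-core peak (brought in here as an assumption), and the ~5 FLOPs per softmax element is a rough count that treats exp as one operation.

```python
def ridge_point(peak_flops, bandwidth):
    """Roofline ridge: FLOPs per byte a kernel needs before compute, not bandwidth, limits it."""
    return peak_flops / bandwidth


def is_memory_bound(flops, bytes_moved, peak_flops, bandwidth):
    """True if the kernel's arithmetic intensity sits left of the ridge point."""
    return flops / bytes_moved < ridge_point(peak_flops, bandwidth)


HBM_BW = 1.5e12      # bytes/s, low end of the A100 HBM range quoted above
FP32_PEAK = 19.5e12  # FLOP/s, A100 non-tensor-core FP32 peak (assumed; softmax runs here)

# One fp16 score through softmax: read 2 bytes, write 2 bytes, and roughly
# 5 FLOPs (max compare, subtract, exp, accumulate, divide). Rough assumption.
SOFTMAX_FLOPS, SOFTMAX_BYTES = 5, 4

print(ridge_point(FP32_PEAK, HBM_BW))           # 13.0 FLOP/byte needed
print(SOFTMAX_FLOPS / SOFTMAX_BYTES)            # 1.25 FLOP/byte available
print(is_memory_bound(SOFTMAX_FLOPS, SOFTMAX_BYTES, FP32_PEAK, HBM_BW))  # True
```

Softmax sits about 10× below the ridge, so faster arithmetic would not help; only moving fewer bytes does.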

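Standard attention writes S to HBM, reads it back to apply softmax, writes P, then reads P again to multiply by V: Θ(Nd + N²) HBM accesses, dominated by N² traffic that exists only to shuttle the intermediate matrices. FlashAttention keeps S and P inside SRAM as tiles and persists only O(N) extra statistics. Its HBM traffic is Θ(N²d²/M) for an SRAM of size M; with typical head dims (64 to 128), d² is many times smaller than M, so that is far fewer accesses, though still quadratic in N. FA1 also proves a matching lower bound: no exact-attention algorithm can asymptotically beat this across the whole range of SRAM sizes.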
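The two asymptotic access counts can be compared directly. This minimal sketch drops the Θ constants and assumes an SRAM capacity M of 100,000 elements (roughly 192 KB of fp16), so the ratio is an order-of-magnitude illustration rather than a measurement. N starts at 2,048 so that M ≤ N·d, the range where the bound applies.

```python
def standard_hbm_accesses(N, d):
    """Θ(N·d + N²): inputs/outputs plus the N x N matrices S and P (constants dropped)."""
    return N * d + N * N


def flash_hbm_accesses(N, d, M):
    """Θ(N²·d²/M) for SRAM holding M elements, valid for d <= M <= N·d (constants dropped)."""
    return N * N * d * d / M


d, M = 64, 100_000   # head dim and assumed SRAM capacity in elements
for N in (2048, 4096, 16384):
    std, fl = standard_hbm_accesses(N, d), flash_hbm_accesses(N, d, M)
    print(f"N={N:>6}: standard {std:.3g}  flash {fl:.3g}  ratio {std / fl:.1f}x")
```

Under these assumptions the ratio settles near M/d² ≈ 24. Both columns still grow with N², which is why the traffic saving is a large constant factor while the extra-memory saving is quadratic to linear.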

03 / 07
The core algorithm · never build the whole matrix

Compute it one tile at a time.

FlashAttention chops Q into row-blocks and K, V into column-blocks sized to fit in SRAM. Press Step → and walk the loop: load a tile, score it, update a running output — then drop it. The grey N×N grid below is never fully held in memory at once.

FIG.03 — LIVE TILE LOOP · only the lit tile lives in SRAM
// the fused single-kernel loop (FA1 loop order)
for each block of K, V (columns of S):
    load Kⱼ, Vⱼ  HBM → SRAM
    for each block of Q (rows of S):
        load Qᵢ, Oᵢ, mᵢ, ℓᵢ  HBM → SRAM
        Sᵢⱼ = Qᵢ Kⱼᵀ                    // one tile, in SRAM
        update Oᵢ, mᵢ, ℓᵢ               // online softmax
        write Oᵢ, mᵢ, ℓᵢ  SRAM → HBM    // O(N) state only; Sᵢⱼ is dropped, never written

Each query row carries a tiny running state — its output so far plus two scalars (max mᵢ, sum ℓᵢ). Tiles arrive, contribute, and are discarded. The full N×N matrix exists only as a sequence of fleeting SRAM tiles — never all at once.
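The tile loop and its online-softmax update can be sketched in pure Python. This is a CPU illustration under stated assumptions, not a GPU kernel: Python lists stand in for HBM, the loop order follows the FA1-style pseudocode above, block sizes are arbitrary, and the 1/√d scale is again omitted. `flash_attention_forward` and `reference_attention` are illustrative names.

```python
import math


def flash_attention_forward(Q, K, V, block_q=2, block_kv=2):
    """Exact attention computed tile by tile with an online softmax."""
    N, dv = len(Q), len(V[0])
    O = [[0.0] * dv for _ in range(N)]   # unnormalised running outputs
    m = [-math.inf] * N                  # running row maxima
    l = [0.0] * N                        # running row sums of exp(s - m)
    for j0 in range(0, len(K), block_kv):            # outer: one K, V block
        Kj, Vj = K[j0:j0 + block_kv], V[j0:j0 + block_kv]
        for i0 in range(0, N, block_q):              # inner: one Q block
            for i in range(i0, min(i0 + block_q, N)):
                # row i of the tile S_ij = Q_i K_j^T
                s = [sum(a * b for a, b in zip(Q[i], k)) for k in Kj]
                m_new = max(m[i], max(s))
                scale = math.exp(m[i] - m_new)       # 0.0 on the first block
                p = [math.exp(x - m_new) for x in s]
                l[i] = scale * l[i] + sum(p)
                O[i] = [scale * O[i][c] + sum(pk * v[c] for pk, v in zip(p, Vj))
                        for c in range(dv)]
                m[i] = m_new
            # nothing from this tile is kept except the per-row state O, m, l
    return [[x / l[i] for x in O[i]] for i in range(N)]


def reference_attention(Q, K, V):
    """Row-by-row full softmax(QK^T)V, for checking the tiled version."""
    out = []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) for k in K]
        mx = max(s)
        e = [math.exp(x - mx) for x in s]
        z = sum(e)
        out.append([sum(w * v[c] for w, v in zip(e, V)) / z for c in range(len(V[0]))])
    return out


if __name__ == "__main__":
    Q = [[math.sin(i + 2 * c) for c in range(4)] for i in range(7)]
    K = [[math.cos(1.3 * i - c) for c in range(4)] for i in range(7)]
    V = [[0.5 * i - c for c in range(3)] for i in range(7)]
    fast = flash_attention_forward(Q, K, V, block_q=3, block_kv=2)
    ref = reference_attention(Q, K, V)
    print(max(abs(a - b) for ra, rb in zip(fast, ref) for a, b in zip(ra, rb)))
```

The printed gap should be at floating-point rounding level, whatever block sizes you pick; uneven final blocks (7 rows split into 3, 3, 1) are handled by the slicing.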

04 / 07
The trick that keeps it EXACT

Softmax, computed in one streaming pass.

Softmax needs the max and the sum-of-exponentials over the whole row — but tiling only shows you one block at a time. The online softmax fixes this: keep a running max and running sum, and rescale whenever a bigger value appears. Step through the blocks and watch the running answer stay correct.

FIG.04 — ONLINE SOFTMAX · stream blocks · rescale on a new max
Readouts for one query row, with scores arriving block by block: running max mᵢ (starts at −∞), running sum ℓᵢ (starts at 0), the streaming softmax so far, the full-row softmax as a reference, and whether the two match.
// a new block arrives with scores s and local max m_new
m  ← max(mᵢ, m_new)
ℓᵢ ← exp(mᵢ − m)·ℓᵢ + Σ exp(s − m)
Oᵢ ← exp(mᵢ − m)·Oᵢ + Σ exp(s − m)·Vⱼ
mᵢ ← m
// at the end: Oᵢ ← Oᵢ / ℓᵢ

The magic is the rescale factor exp(mᵢ−m). When a block brings a larger score, every partial sum and partial output so far is corrected to reference the new maximum. Nothing is approximated: only the order of summation changes, so the answer is the same up to floating-point rounding.
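The same update can be watched on a single row. A minimal sketch, assuming made-up scores split into three blocks of two: it keeps only mᵢ and ℓᵢ while streaming, then checks that the probabilities implied by those two scalars match an ordinary full-row softmax. The function names are illustrative.

```python
import math


def stream_row_stats(blocks):
    """One pass over score blocks, keeping only a running max m and running sum l."""
    m, l = -math.inf, 0.0
    for block in blocks:
        m_new = max(m, max(block))
        # rescale the old sum to the new max, then add this block's terms
        l = math.exp(m - m_new) * l + sum(math.exp(s - m_new) for s in block)
        m = m_new
        print(f"after block {block}: m = {m}, l = {l:.4f}")
    return m, l


def full_softmax(row):
    mx = max(row)
    e = [math.exp(s - mx) for s in row]
    z = sum(e)
    return [x / z for x in e]


row = [0.5, -1.0, 2.0, 0.3, 3.5, -0.2]      # one query's scores (illustrative values)
m, l = stream_row_stats([row[0:2], row[2:4], row[4:6]])
streamed = [math.exp(s - m) / l for s in row]
print(m)  # 3.5
print(max(abs(a - b) for a, b in zip(streamed, full_softmax(row))))
```

Every block in this example brings a new maximum (0.5, then 2.0, then 3.5), so the rescale fires at each step, and the final difference from the full-row softmax is rounding-level.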

That single property — streaming softmax == full-row softmax — is exactly why FlashAttention can claim the word exact. Watch the two readouts above stay locked together as you feed blocks.

05 / 07
Backward pass · trade compute for memory

Don't store it. Recompute it.

Training needs gradients, which normally want the stored N×N probabilities P. Keeping P would bring back O(N²) memory. So FlashAttention throws P away and recomputes each tile during backprop — persisting only one scalar per row. Toggle the data-flow and watch the slow HBM round-trip disappear.
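What survives the forward pass, and how a tile comes back, fits in a few lines. A minimal sketch, assuming the per-row scalar is the log-sum-exp L = m + log ℓ, made-up Q and K, and fp32 storage for L; `logsumexp_row` and `recompute_p_tile` are illustrative names.

```python
import math


def logsumexp_row(row):
    """L = m + log(l): the single scalar per row kept from the forward pass."""
    m = max(row)
    return m + math.log(sum(math.exp(s - m) for s in row))


def recompute_p_tile(Q_i, K_j, L_i):
    """Rebuild the probability tile P_ij = exp(Q_i K_j^T - L_i) on demand."""
    return [[math.exp(sum(a * b for a, b in zip(q, k)) - L) for k in K_j]
            for q, L in zip(Q_i, L_i)]


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 2.0]]

# Forward pass: scores are formed, but only one float per row is kept.
L = [logsumexp_row([sum(a * b for a, b in zip(q, k)) for k in K]) for q in Q]

# Backward pass: rebuild the tile for query rows 0-1 and key columns 2-3.
tile = recompute_p_tile(Q[0:2], K[2:4], L[0:2])
print(tile)

# Storage for one 16K-token head: P in fp16 versus L in fp32.
N = 16384
print(2 * N * N // 2**20, "MiB for P;", 4 * N // 2**10, "KiB for L")  # 512 MiB for P; 64 KiB for L
```

Subtracting L divides by the row's full softmax denominator in one step, which is why no separate normalisation pass is needed when the tile is rebuilt.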

FIG.05 — MEMORY HIERARCHY · STANDARD vs FLASH DATA-FLOW

PERSIST ONE SCALAR

L = m + log ℓ

The forward pass stores just the log-sum-exp per row. From it, any probability tile P=exp(S−L) is reconstructed on demand in SRAM.

COMPUTE-FOR-MEMORY ★

re-do matmuls, skip traffic

You pay extra FLOPs recomputing the S and P tiles. Because attention is memory-bound, that's still a net speedup: the saved HBM traffic dwarfs the redundant math.

KERNEL FUSION

one kernel, no scratch

Score → softmax → ×V are fused into a single GPU kernel. No intermediate S or P tensor is ever written out to HBM — there's nothing to round-trip.
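The compute-for-memory trade can be sized with ideal-peak arithmetic. A rough sketch for one head at N = 16K, d = 64 in fp16, assuming the A100's 312 TFLOP/s dense FP16 tensor-core peak and the 1.5 TB/s HBM figure quoted earlier. It compares only recomputing S = QKᵀ against writing P in the forward pass and reading it back, and ignores the exp and everything else the backward pass does.

```python
N, d = 16384, 64
PEAK_FLOPS = 312e12   # FLOP/s, A100 dense FP16 tensor cores (assumed)
HBM_BW = 1.5e12       # bytes/s, low end of the quoted A100 HBM range

# Option 1: recompute S = Q K^T in SRAM (one multiply and one add per i, j, k).
recompute_flops = 2 * N * N * d
t_recompute = recompute_flops / PEAK_FLOPS

# Option 2: store P in fp16 during the forward pass, read it back in backward.
p_bytes = 2 * N * N
t_store = 2 * p_bytes / HBM_BW   # one write plus one read

print(f"recompute: {t_recompute * 1e3:.2f} ms")   # recompute: 0.11 ms
print(f"store P:   {t_store * 1e3:.2f} ms, plus {p_bytes // 2**20} MiB held")  # store P:   0.72 ms, plus 512 MiB held
```

Even at ideal peaks, redoing the matmul is several times cheaper than the round-trip, and it frees the 512 MiB outright.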

06 / 07
FA1 → FA2 → FA3 · the same idea, tuned harder

One algorithm, three hardware rewrites.

FA1 invented the algorithm. FA2 and FA3 are increasingly hardware-specific re-implementations of the same algorithm, each squeezing closer to the GPU's raw matmul ceiling.

Version | Year · GPU | Key engineering move | Utilisation (reported)
FlashAttention-1 | 2022 · A100 (Ampere) | tiling + online softmax + recomputation + the IO lower-bound proof | ~25–40% of peak
FlashAttention-2 | 2023 · A100 (Ampere) | fewer non-matmul FLOPs · parallelise over sequence length · better warp partitioning | ~50–73% of peak; up to ~225 TFLOP/s in end-to-end GPT training
FlashAttention-3 | 2024 · H100 (Hopper) | warp specialisation with TMA async copies · ping-pong matmul/softmax overlap · FP8 path | ~75% (~740 TFLOP/s FP16); ~1.2 PFLOP/s FP8
THE HEADLINE WINS

FA1: 3× faster GPT-2 training at 1K context, 15% faster BERT-large training (sequence length 512) than the MLPerf 1.1 record, and the first Transformer to beat chance on Path-X (16K). FA2 roughly doubled FA1's speed; FA3 added another 1.5–2× over FA2 on H100. Much of the recent growth in context windows builds on this.

Note the trend: FA2 and FA3 don't change what is computed. They re-shape how the same tiled, online-softmax algorithm maps onto a specific GPU's matmul units, async copies, and precision modes. The algorithm is portable — the peak performance is not.

HONEST CAVEATS

It does not reduce asymptotic compute — attention is still Θ(N²) FLOPs; only the memory quadratic is gone. Single-token decode has just one query row, so plain FA underutilises the GPU (hence FlashDecoding). And FA3's FP8 path trades a little accuracy for speed — validate per-model.
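The same rescaling rule is what lets a decode kernel split one query's work across the GPU. A minimal sketch of the split-KV idea behind that caveat, assuming a single query row, keys cut into chunks that could run in parallel, and a final merge of per-chunk (m, ℓ, unnormalised O) states. The function names are hypothetical, and this is not Flash-Decoding's actual code.

```python
import math


def chunk_state(q, K, V):
    """Partial attention over one key chunk: returns (m, l, unnormalised output)."""
    s = [sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    o = [sum(pk * v[c] for pk, v in zip(p, V)) for c in range(len(V[0]))]
    return m, sum(p), o


def merge_states(states):
    """Combine chunk states by rescaling each to the global max (the online-softmax rule)."""
    m = max(st[0] for st in states)
    l = sum(math.exp(mc - m) * lc for mc, lc, _ in states)
    dv = len(states[0][2])
    o = [sum(math.exp(mc - m) * oc[c] for mc, _, oc in states) for c in range(dv)]
    return [x / l for x in o]


q = [0.3, -1.2, 0.8]
K = [[math.sin(i + c) for c in range(3)] for i in range(10)]
V = [[float(i), float(i % 3)] for i in range(10)]

# Each chunk could go to a different thread block; only the small merge is serial.
states = [chunk_state(q, K[a:a + 4], V[a:a + 4]) for a in range(0, 10, 4)]
print(merge_states(states))
```

Merging a single chunk covering all keys gives plain softmax(qKᵀ)V, so the split version can be checked against it directly.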

Complementary, not rivals: FlashAttention makes exact O(N²) attention cheap; sparse/linear attention and state-space models change the math to escape O(N²). Long-context research uses both.

07 / 07
The takeaway · designing around the memory hierarchy

The lesson is bigger than attention.

FlashAttention is the canonical worked example of a deeper principle: design your algorithm around the memory hierarchy, not the FLOP count.

IO-AWARENESS

count the bytes moved

For memory-bound work, the right cost model is data movement, not arithmetic. FlashAttention took the bandwidth gap seriously and won — with a matching lower-bound proof to back it.

EXACT BEATS APPROX

when you can keep the answer

Approximate attention changes the output to save compute, and often fails to deliver a wall-clock speedup. FlashAttention delivers one and keeps the answer exact, up to floating-point rounding. Exactness is a feature you give up only when forced to.

UNDER OPENALICE

why big context is affordable

Alice is a consumer of FlashAttention via her providers — long system prompts and worker-mode drafting are economically possible because IO-aware kernels run underneath the stack.

01 · 06 — you made it

You understand FlashAttention.

The N×N cost. The memory wall. Tiling, the online-softmax trick that keeps it exact, recomputation in the backward pass, and the FA1→FA3 ladder. Same math, same answer — but the giant matrix never touches slow memory. You now see why long context became affordable.

01·01 Attention & Transformers · the operation FlashAttention accelerates ✓ prereq
01·06 FlashAttention · IO-aware exact attention · tiling + online softmax ✓ complete
01·07 State-space models · the sub-quadratic alternative that changes the math next
01·08 Quantization · the FP8 precision tricks behind FA3's speed locked
Revisit · 01 · 01

Attention & Transformers →

Go back to the operation FlashAttention makes fast — the QKᵀ table, softmax, and the values it weights. Now you'll read it knowing exactly where the time was going.
