Machine Learning / 2022 / arXiv
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
If you profile a transformer on a modern GPU, attention isn't compute-bound. It's waiting on memory. FlashAttention is what happens when you stop counting FLOPs and start counting bytes.
There's a folk model of GPU performance that goes: bigger matrix multiplies, more FLOPs, slower. Under that model, attention is slow because the N×N attention matrix has N² entries, so doubling the sequence length quadruples the work. Make the FLOPs cheaper or skip some of them and attention gets faster.
That model is wrong. Or rather, it's measuring the wrong thing. If you profile a real transformer with a tool like Nsight or nvprof, the attention layer spends most of its wallclock time not computing. It spends it waiting on memory transfers. The arithmetic units sit idle while the GPU shovels values back and forth between two kinds of memory. FlashAttention's contribution was, in a real sense, noticing this and writing a kernel that takes it seriously.
The paper is from 2022. Tri Dao, Dan Fu, Stefano Ermon, Atri Rudra, Christopher Ré. Title: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. The word that does the most work in that title is exact. Every prior attempt to make attention cheaper had relaxed the math: linear attention, low-rank approximations, locality-sensitive hashing, sparse patterns chosen heuristically. FlashAttention computes the same softmax, the same outputs, the same gradients. It just reorders the work so the GPU stops thrashing.
Two kinds of memory, very different speeds
A modern data-center GPU has a memory hierarchy that looks roughly like this. HBM (high-bandwidth memory) is the big DRAM you see on the spec sheet: 40 to 80 gigabytes, with roughly 1.5 to 2 TB/s of bandwidth. SRAM (also called shared memory or L1 cache) is on-chip and per-streaming-multiprocessor, and it's tiny: 192 KB per SM on an A100, about 20 MB across the whole chip, but with aggregate bandwidth around 19 TB/s. That's about 10× faster than HBM, and you can read and write it without going off-chip.
The names are misleading, because both are RAM. The difference is where they sit. HBM is mounted on the package next to the GPU die, connected by a very wide bus (5120 bits on an A100). It's fast for DRAM; most CPUs would kill for 2 TB/s. But the GPU die has compute units that, if you actually keep them fed, can do over 300 trillion fp16 floating-point operations per second. At those rates, even 2 TB/s starts to look slow. Every time the kernel says "go fetch this tensor from HBM," the streaming multiprocessor stalls for a few hundred cycles waiting for the bytes to arrive. Multiply that by however many tensors and however many round trips and you have a problem.
SRAM is the opposite. It lives on the same silicon as the compute units, which means electrons travel a few millimeters instead of a few centimeters. It's small: on an A100 you get 192 KB per streaming multiprocessor and there are 108 SMs, so call it 20 MB across the whole chip if you're optimistic about how it's split. But you can read and write it at near-register speed, and the tensor cores can be kept fed from it without ever touching HBM.
Every operation on the GPU follows the same rhythm: read inputs from HBM into SRAM, do the math, write outputs back to HBM. If your operation reads a lot of bytes per FLOP, you're memory-bound — the SMs sit idle waiting for HBM. If your operation does a lot of FLOPs per byte, you're compute-bound — HBM is fast enough to keep the SMs fed. The ratio that matters is arithmetic intensity (FLOPs per byte), and matrix multiply has high intensity, which is why GEMMs scream.
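To make arithmetic intensity concrete, here is a back-of-the-envelope sketch. It assumes fp16 storage (2 bytes per element) and that every operand crosses HBM exactly once, ignoring caches; the function names are illustrative, not from any library.

```python
def matmul_intensity(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for C = A @ B, with A of shape m x k and B of shape k x n.
    Assumes each matrix crosses HBM exactly once: read A, read B, write C."""
    flops = 2 * m * k * n  # one multiply and one add per inner-product term
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved


def elementwise_add_intensity(bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for z = x + y: one add per element, two reads and one write."""
    return 1 / (3 * bytes_per_elem)


print(f"fp16 4096^3 GEMM:     {matmul_intensity(4096, 4096, 4096):7.1f} FLOPs/byte")
print(f"fp16 elementwise add: {elementwise_add_intensity():7.3f} FLOPs/byte")
```

A large square GEMM lands above a thousand FLOPs per byte, while an elementwise op does a fraction of a FLOP per byte; that gap of four orders of magnitude is why the two kinds of kernel live in different regimes.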
The roofline puts attention in the wrong half
There's a beautiful piece of folk knowledge from the HPC world called the roofline model. Plot arithmetic intensity on the x-axis and achievable throughput on the y-axis, both log-scale. You get two lines: a slope-1 line for the memory-bound regime (throughput = AI × bandwidth), and a flat line at peak compute. They meet at the ridge point. On an A100 with fp16 tensor cores, the ridge sits at about 156 FLOPs per byte (312 TFLOPS divided by about 2 TB/s). Anything below that, you're starving the cores. Anything above, the cores are saturated and the bus has slack.
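The roofline itself is a one-line formula. A minimal sketch, assuming the A100 figures used in the text (312 TFLOPS dense fp16 peak, about 2 TB/s of HBM bandwidth); both constants are approximations, not measured values.

```python
PEAK_FP16_FLOPS = 312e12  # A100 dense fp16 tensor-core peak, FLOP/s
HBM_BYTES_PER_S = 2.0e12  # approximate A100 HBM bandwidth, bytes/s


def attainable_flops(intensity: float) -> float:
    """Roofline: throughput is capped by bandwidth * intensity or by peak compute."""
    return min(PEAK_FP16_FLOPS, HBM_BYTES_PER_S * intensity)


# The ridge point is the intensity at which the two caps are equal.
ridge = PEAK_FP16_FLOPS / HBM_BYTES_PER_S

for ai in (1, 4, 32, ridge, 1000):
    frac = attainable_flops(ai) / PEAK_FP16_FLOPS
    print(f"AI = {ai:7.1f} FLOPs/byte -> {frac:6.1%} of peak")
```

At 1 FLOP/byte the kernel can reach well under 1% of peak no matter how clever its arithmetic is; the only lever is moving fewer bytes.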
Attention, written naively, has terrible arithmetic intensity. Here's why.
Figure: Where the cycles actually go on an A100 (roofline with an animated memory bus). Below the ridge point of about 156 FLOPs/byte, the GPU is starved by HBM bandwidth and the tensor cores idle; above it, the cores are saturated. Naive softmax-attention sits deep in the memory-bound regime, near 1–4 FLOPs/byte. FlashAttention keeps intermediates in SRAM and raises its effective intensity by roughly an order of magnitude: same FLOPs, different bytes.
Standard attention, and the bytes it moves
Naive attention does this: compute S = QK^T (an N×N matrix). Compute P = softmax(S) (still N×N). Compute O = PV (N×d, the same shape as Q). The two intermediates, S and P, are each N×N. For N = 8K and fp16 storage, that's 128 MB per matrix per head. With 32 heads, that's 4 GB per matrix per layer, and each one gets written and then read back. Across a few dozen layers, you're paging hundreds of gigabytes through HBM per forward pass.
More importantly, the intermediate S and P matrices get written to HBM and then read back. You compute S, write it. You read S, compute P, write it. You read P, compute O, write it. Each round trip is a stall. The actual matrix multiplies are cheap; you're paying for the bus.
Let's count. The forward pass of standard attention does roughly 4N²d FLOPs (the two matmuls dominate). It reads Q, K, V once each (3Nd elements), writes S (N² elements), reads S back for the softmax, writes P (another N²), reads P back, and writes O (Nd). The dominant term in HBM traffic is N²: quadratic in sequence length, exactly like the FLOP count, which is only a factor of about d larger. That's a recipe for a memory-bound kernel: when the bytes you move scale with the FLOPs, arithmetic intensity stays pinned at a small constant no matter how long the sequence gets, and the bus is going to be the bottleneck.
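Splitting that count by stage shows where the bus bites. This is a rough sketch for one head, assuming fp16 storage and about 5 FLOPs per softmax element (max, subtract, exp, sum, divide); that per-element figure is a guess, not a measurement.

```python
def stage_intensities(n: int, d: int, bytes_per_elem: int = 2,
                      softmax_flops_per_elem: int = 5) -> dict[str, float]:
    """FLOPs per byte for each unfused stage of one attention head.
    softmax_flops_per_elem is a rough assumption, not a measured count."""
    b = bytes_per_elem
    return {
        # 2*n*n*d FLOPs; read Q and K (n x d each), write S (n x n)
        "S = QK^T": (2 * n * n * d) / (b * (2 * n * d + n * n)),
        # a few FLOPs per score; read S, write P
        "P = softmax(S)": (softmax_flops_per_elem * n * n) / (b * 2 * n * n),
        # 2*n*n*d FLOPs; read P (n x n) and V (n x d), write O (n x d)
        "O = PV": (2 * n * n * d) / (b * (n * n + 2 * n * d)),
    }


RIDGE_A100 = 156  # fp16 ridge point quoted in the text, FLOPs/byte

for n in (1024, 8192):
    for stage, ai in stage_intensities(n, d=64).items():
        side = "memory-bound" if ai < RIDGE_A100 else "compute-bound"
        print(f"N={n:5d}  {stage:15s} {ai:6.2f} FLOPs/byte  ({side})")
```

Both matmul stages hover just under d FLOPs/byte however large N gets, and the softmax stage sits near 1. With d = 64, every stage is below the A100 ridge.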
And memory grows as N². Not just compute — memory itself becomes a hard wall. The reason 100K-context inference was hard wasn't that we couldn't compute it. It's that the attention matrix wouldn't fit. At N = 100K and fp32, that's 40 GB per head per layer just for the S tensor. You don't get to train that on a single A100, no matter how patient you are.
The idea: never write the attention matrix
FlashAttention's central observation is that you don't actually need the full N×N attention matrix to exist anywhere. You only need the final O = softmax(QK^T) V. The attention matrix is a means, not an end. So: tile Q, K, and V into blocks small enough to fit in SRAM, compute attention block-by-block, accumulate the output, and never let the full matrix materialize.
Done right, the only tensors that touch HBM are the inputs Q, K, V and the output O. The intermediate S and P live their entire lives on-chip and vanish when the kernel exits. In the idealized case where each input is read once and the output written once, HBM traffic drops from N² to N·d: linear in sequence length instead of quadratic. For N = 8K, d = 64, and fp16, each of Q, K, V, and O is 1 MB per head, so that's about 3 MB read and 1 MB written per head, instead of 256 MB for S and P.
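The same arithmetic as a small calculator. It is a sketch of the idealized accounting only: fp16 storage is assumed, and real tiled kernels re-stream K and V more than once, which this ignores.

```python
MIB = 2**20  # bytes


def unfused_intermediates(n: int, bytes_per_elem: int = 2) -> tuple[int, int]:
    """(size of S and P together, HBM traffic to write and re-read both)."""
    size = 2 * n * n * bytes_per_elem
    return size, 2 * size


def fused_io(n: int, d: int, bytes_per_elem: int = 2) -> tuple[int, int]:
    """Idealized fused kernel: (bytes read for Q, K, V; bytes written for O)."""
    tensor = n * d * bytes_per_elem
    return 3 * tensor, tensor


for n in (2048, 8192):
    size, traffic = unfused_intermediates(n)
    rd, wr = fused_io(n, d=64)
    print(f"N={n}: S+P take {size / MIB:.0f} MiB ({traffic / MIB:.0f} MiB of traffic); "
          f"fused reads {rd / MIB:.2f} MiB, writes {wr / MIB:.2f} MiB")
```

Doubling N quadruples the first column and only doubles the second, which is the whole argument in two numbers.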
This sounds straightforward and it is, except for one thing: softmax is not blockwise. You can't softmax a chunk of a row independently and then stitch the results together, because softmax requires normalization across the whole row. To make tiling exact rather than approximate, you need a streaming algorithm that maintains the right normalization as new blocks arrive.
This is the algorithmic kernel of the paper. Every other piece — the CUDA tiling, the recomputation trick for the backward pass, the block size choices — follows from getting this right. So spend a minute on it.
Online softmax, the trick that makes it exact
Standard softmax of a vector x is softmax(x_i) = exp(x_i - max(x)) / Σ exp(x_j - max(x)). The subtraction of max(x) is for numerical stability — it doesn't change the result mathematically, but without it you'd routinely overflow when x_i was on the order of 80 or more (which happens; attention scores are not bounded). It means you need a global pass to find the max before you can compute any output.
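A minimal two-pass softmax makes the stability point concrete. Python floats are float64, which overflow near exp(710) rather than fp32's exp(89), so the example uses logits around 1000 to trigger the failure.

```python
import math


def softmax_naive(xs):
    """exp(x_i) / sum_j exp(x_j) with no max subtraction: overflows for large logits."""
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_stable(xs):
    """Two passes: find the global max first, then exponentiate relative to it."""
    m = max(xs)                           # pass 1: global max over the whole row
    exps = [math.exp(x - m) for x in xs]  # pass 2: every exponent is <= 0
    total = sum(exps)
    return [e / total for e in exps]


logits = [1000.0, 999.0, 998.0]
print(softmax_stable(logits))
try:
    softmax_naive(logits)
except OverflowError:
    print("naive softmax overflowed")
```

Note that `softmax_stable` cannot emit anything until it has seen the whole row, which is exactly the property that blocks tiling.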
Online softmax says: as you stream blocks, maintain two running statistics per row. Keep m (running max so far) and ℓ (running sum of exp(x - m) so far). When a new block arrives with its own max m_tile and partial sum ℓ_tile:
- Compute the new global max: m* = max(m, m_tile).
- Rescale the old running sum: it was computed relative to the old max, so multiply it by exp(m - m*) to put it on the new scale.
- Rescale the new block's sum the same way, by exp(m_tile - m*).
- Add them. That's the new running ℓ.
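Those four steps translate almost line for line into code. This is a sketch of the normalizer only (m and ℓ, no output accumulator), over a plain Python list standing in for one row of scores; the function names are illustrative.

```python
import math


def tile_stats(tile):
    """Max and sum of exp(x - max) for one tile of a row."""
    m = max(tile)
    return m, sum(math.exp(x - m) for x in tile)


def merge(m, l, m_tile, l_tile):
    """Online-softmax update: move both sums onto the new max, then add them."""
    m_new = max(m, m_tile)
    l_new = l * math.exp(m - m_new) + l_tile * math.exp(m_tile - m_new)
    return m_new, l_new


def streaming_normalizer(row, tile_size):
    """Running (m, l) over a row that arrives tile by tile."""
    m, l = -math.inf, 0.0  # empty prefix: exp(-inf) == 0, so the first merge ignores it
    for start in range(0, len(row), tile_size):
        m, l = merge(m, l, *tile_stats(row[start:start + tile_size]))
    return m, l


row = [0.5, 3.0, -1.0, 7.0, 2.0, 6.5, 0.0]
m, l = streaming_normalizer(row, tile_size=3)
l_ref = sum(math.exp(x - max(row)) for x in row)  # two-pass reference
print(m, l, l_ref)
```

Any tile size gives the same (m, ℓ) up to rounding, which is the property the kernel relies on.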
The output accumulator gets the same treatment. Each time the running max changes, you scale the partial output by exp(m_old - m_new) to correct for the fact that earlier contributions were exponentiated relative to a smaller max. When the last block has been processed, divide the accumulated output by the final ℓ. The result is mathematically identical to running standard softmax over the full row; in floating point it agrees up to rounding, since the additions happen in a different order.
Why does this work? Because softmax is shift-invariant: subtracting the same constant from every logit doesn't change the output. The rescaling by exp(m_old - m_new) is exactly that shift, applied retroactively to contributions that were computed under a stale max. Algebraically: if O_partial was accumulated as Σ exp(x_j - m_old) · v_j, then O_partial · exp(m_old - m_new) = Σ exp(x_j - m_new) · v_j. The running ℓ gets the identical factor, so the substitution preserves the ratio between numerator and denominator, which is all softmax cares about.
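Here is the same bookkeeping for one query row with the output accumulator included. It is a sketch, assuming scores and value vectors arrive as Python lists and omitting the 1/√d scaling, as the text does; it is not the paper's CUDA code.

```python
import math


def online_row_attention(scores, values, tile_size):
    """Softmax-weighted average of `values` for one query row, streaming tiles.
    Carries m (running max), l (running sum) and an unnormalized accumulator o."""
    dim = len(values[0])
    m, l, o = -math.inf, 0.0, [0.0] * dim
    for start in range(0, len(scores), tile_size):
        s_tile = scores[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        m_new = max(m, max(s_tile))
        scale = math.exp(m - m_new)  # exp(m_old - m_new): shrinks the stale carry
        p = [math.exp(s - m_new) for s in s_tile]
        l = l * scale + sum(p)
        o = [oc * scale + sum(pj * v[c] for pj, v in zip(p, v_tile))
             for c, oc in enumerate(o)]
        m = m_new
    return [oc / l for oc in o]  # normalize once, after the last tile


def reference_row_attention(scores, values):
    """Standard two-pass softmax followed by the weighted sum."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wj * v[c] for wj, v in zip(w, values)) / total
            for c in range(len(values[0]))]


scores = [0.1, 2.0, -0.5, 4.0, 1.5]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, -1.0], [0.5, 0.5]]
for tile_size in (1, 2, 5):
    print(tile_size, online_row_attention(scores, values, tile_size))
print("ref", reference_row_attention(scores, values))
```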
You can verify this on paper in about ten minutes. As each tile arrives, the running statistics update, and at the end the answer matches naive softmax up to floating-point rounding.
Figure: The rescaling trick, step by step (online softmax: streaming the unstreamable). Tiles arrive one at a time, so there is no global max to subtract. The running m, ℓ, and O are carried from tile to tile, and every time the max rises, the carry is rescaled by exp(m_old − m_new).
The history of online softmax
Milakov and Gimelshein at NVIDIA published the streaming-max softmax in 2018 as a way to fuse softmax with the previous matmul. It got applied to attention by Markus Rabe and Charles Staats in 2021 ("Self-attention Does Not Need O(n²) Memory") — they showed you could do attention in O(N) memory if you were willing to pay extra FLOPs, but the result was slow on actual GPUs because their formulation didn’t play well with tensor-core utilization.
FlashAttention's contribution wasn't inventing online softmax. It was making it fast: getting the tile sizes right so each block stays in SRAM, fusing the entire forward pass into a single CUDA kernel, getting the warp-level reductions right, and showing that the resulting kernel is genuinely faster than the usual sequence of cuBLAS matmuls plus separate softmax kernels. That last part sounds like a footnote and it isn't: hand-written kernels usually lose to vendor library code, and the FlashAttention paper had to demonstrate that this was an exception.
The forward pass, in a paragraph of pseudocode
Algorithm 1 of the paper, in plain English, with the loops nested the way FlashAttention-2 later settled on (the original Algorithm 1 nests them the other way around, with key/value tiles in the outer loop and query tiles in the inner one). Outer loop over query tiles Q_i. For each query tile, load it into SRAM and initialize per-row running statistics (m_i = -∞, ℓ_i = 0, O_i = 0). Inner loop over key/value tiles K_j, V_j. For each, load K_j and V_j into SRAM, compute the score tile S_ij = Q_i K_j^T, find the per-row max of S_ij, exponentiate, run the online-softmax update on m_i, ℓ_i, and accumulate O_i. After the inner loop, normalize: O_i = O_i / ℓ_i. Write O_i back to HBM.
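The loop structure above fits in a few dozen lines of pure Python. This is a sketch for checking the math, not a kernel: lists stand in for SRAM tiles, there is no 1/√d scaling (matching the text), and the names `flash_forward` and `naive_forward` are illustrative.

```python
import math
import random


def matmul_t(a, b):
    """a @ b^T for row-major lists of lists."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]


def flash_forward(q, k, v, br, bc):
    """Tiled forward pass: query tiles outside, key/value tiles inside.
    Only br x bc score tiles ever exist; no N x N matrix is built."""
    d = len(v[0])
    out = []
    for i in range(0, len(q), br):                 # outer loop: query tile Q_i
        q_i = q[i:i + br]
        rows = len(q_i)
        m = [-math.inf] * rows                     # per-row running max
        l = [0.0] * rows                           # per-row running sum
        o = [[0.0] * d for _ in range(rows)]       # unnormalized accumulator
        for j in range(0, len(k), bc):             # inner loop: K_j, V_j
            k_j, v_j = k[j:j + bc], v[j:j + bc]
            s_ij = matmul_t(q_i, k_j)              # rows x bc score tile
            for r in range(rows):
                m_new = max(m[r], max(s_ij[r]))
                scale = math.exp(m[r] - m_new)
                p = [math.exp(s - m_new) for s in s_ij[r]]
                l[r] = l[r] * scale + sum(p)
                o[r] = [o[r][c] * scale + sum(pj * vr[c] for pj, vr in zip(p, v_j))
                        for c in range(d)]
                m[r] = m_new
        out.extend([x / l[r] for x in o[r]] for r in range(rows))  # O_i / l_i
    return out


def naive_forward(q, k, v):
    """Materializes the full N x N score matrix, then softmax, then P V."""
    out = []
    for row in matmul_t(q, k):
        mx = max(row)
        w = [math.exp(x - mx) for x in row]
        tot = sum(w)
        out.append([sum(wj * vr[c] for wj, vr in zip(w, v)) / tot
                    for c in range(len(v[0]))])
    return out


random.seed(0)
N, D = 10, 4
Q = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
K = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
V = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
O_flash = flash_forward(Q, K, V, br=4, bc=3)
O_ref = naive_forward(Q, K, V)
print(max(abs(a - b) for ra, rb in zip(O_flash, O_ref) for a, b in zip(ra, rb)))
```

Tile sizes that don't divide N (here 4 and 3 into 10) still work, since the last tile is simply shorter.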
The remarkable thing is what isn’t in that algorithm. There’s no S = QK^T step that allocates an N×N tensor. There’s no separate softmax pass. There’s no second matmul that reads the softmax output back from HBM. The entire forward pass is one fused loop, and the only HBM traffic is reading Q, K, V and writing O.
Tile sizes matter. The paper chooses B_r (rows per query tile) and B_c (columns per key/value tile) so that Q_i, K_j, V_j, plus the score tile S_ij, plus the running stats, all fit in SRAM simultaneously. On an A100 with 192 KB of SRAM per SM and d = 64, the constraint works out to something like B_r = B_c = 64, so each iteration of the inner loop processes a 64×64 block of attention scores. There are N/B_r outer iterations, each with N/B_c inner iterations, so the total number of inner iterations is N²/(B_r·B_c).
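A rough way to sanity-check a tile choice is to add up the bytes. This sketch assumes Q_i, K_j, V_j are fp16 while S_ij, O_i, m_i and ℓ_i are fp32; real kernels keep much of this in registers rather than shared memory, so treat it as a budget estimate, not the paper's exact constraint.

```python
def tile_sram_bytes(br: int, bc: int, d: int) -> int:
    """Approximate on-chip footprint of one inner-loop step (assumed dtypes)."""
    fp16_elems = br * d + 2 * bc * d          # Q_i, K_j, V_j
    fp32_elems = br * bc + br * d + 2 * br    # S_ij, O_i, m_i and l_i
    return 2 * fp16_elems + 4 * fp32_elems


def inner_iterations(n: int, br: int, bc: int) -> int:
    """Total inner-loop steps: ceil(n / br) query tiles times ceil(n / bc) K/V tiles."""
    return -(-n // br) * -(-n // bc)


SRAM_PER_SM = 192 * 1024
need = tile_sram_bytes(64, 64, 64)
print(f"{need / 1024:.1f} KiB of {SRAM_PER_SM / 1024:.0f} KiB per SM; "
      f"{inner_iterations(8192, 64, 64)} inner steps at N = 8K")
```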
Total HBM accesses for the forward pass: Θ(N²d²/M), where M is the SRAM size (counted in elements). Compare to standard attention's Θ(Nd + N²). The win is roughly a factor of M/d²: on an A100, 192 KB of SRAM holds about 98K fp16 elements, so with d = 64 that's about 24× fewer accesses, before constants. In practice overheads eat into that, but you still see 7–9× wallclock speedups on the attention computation.
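Plugging numbers into the two asymptotic bounds gives a feel for the ratio. Constants are dropped, so only the order of magnitude means anything; M is taken as 192 KB of fp16 elements, an assumption matching the A100 figures above.

```python
def hbm_accesses_standard(n: int, d: int) -> int:
    """Theta(N*d + N^2), constants dropped."""
    return n * d + n * n


def hbm_accesses_flash(n: int, d: int, m_elems: int) -> float:
    """Theta(N^2 * d^2 / M), constants dropped; M is SRAM size in elements."""
    return n * n * d * d / m_elems


M = 192 * 1024 // 2  # 192 KB of SRAM holds about 98K fp16 elements
for n in (2048, 16384):
    ratio = hbm_accesses_standard(n, 64) / hbm_accesses_flash(n, 64, M)
    print(f"N={n}: ~{ratio:.0f}x fewer HBM accesses (asymptotic, constants dropped)")
```

The ratio barely moves with N because both bounds are quadratic; what changes it is d and M, which is why head dimension and SRAM size show up in every tile-size discussion.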
What you actually save
The forward pass goes from Θ(N² + Nd) HBM accesses to Θ(N²d²/M) HBM accesses, where M is the SRAM size. In practice, with realistic block sizes, this is roughly a 10× reduction in memory traffic for typical sequence lengths. Wallclock speedups in the original paper's A100 benchmarks were 2–4× for the forward pass, and more for the backward pass.
The bigger win is memory itself. Standard attention's memory is O(N²) — that's the matrix you have to allocate. FlashAttention's memory is O(N) — just the inputs, outputs, and a couple of running statistics. You can train sequences three or four times longer in the same VRAM. That's why long-context training became practical roughly when FlashAttention shipped.
The backward pass needs a small additional trick. To compute gradients you'd normally need S and P, which we deliberately didn't store. The fix is recomputation: in the backward pass, recompute S and P tile by tile, on the fly, using the same blockwise machinery. Recomputing them is cheap (compute is not the bottleneck, remember) and saves the memory you'd otherwise burn storing them. This is a tiny gem of an idea on its own. It inverts the usual checkpoint-vs-recompute tradeoff: ordinarily recomputation saves memory at the cost of speed, but here rebuilding S and P from on-chip tiles is faster than reading a stored N×N matrix back from HBM, so you save memory and time.
Specifically: in standard attention, you save P (the softmax output) for backward. That's an N² tensor in HBM. In FlashAttention, you save only m and ℓ, two scalars per query row, totaling 2N values. That's N/2 times smaller (200,000× at N = 400K), and it grows linearly in N instead of quadratically. The backward pass walks the same tile structure as forward, recomputes S and P on the fly from Q, K, and the saved m and ℓ, and emits gradients tile by tile. The arithmetic count goes up, since S and P are effectively computed twice, but FLOPs were never the constraint, so the kernel comes out 4–5× faster overall.
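The recomputation rests on a simple identity: once a score tile has been rebuilt as Q_i K_j^T, the saved m and ℓ are enough to reconstruct its probabilities exactly. A sketch for one row, assuming the scores are already recomputed and skipping the gradient math itself; the function names are illustrative.

```python
import math


def forward_stats(s_row):
    """What the forward pass keeps for one query row: its max m and sum l."""
    m = max(s_row)
    return m, sum(math.exp(x - m) for x in s_row)


def recompute_p(s_tile_row, m, l):
    """Backward-pass recomputation for one tile of a row: probabilities from
    freshly recomputed scores plus the two saved scalars."""
    return [math.exp(x - m) / l for x in s_tile_row]


def saved_for_backward(n):
    """Elements kept between forward and backward (excluding Q, K, V, O)."""
    return {"standard: P": n * n, "flash: m and l": 2 * n}


s_row = [1.0, -2.0, 0.5, 3.0, 2.5, -1.0]
m, l = forward_stats(s_row)
p_stored = [math.exp(x - m) / l for x in s_row]  # what standard attention keeps
p_recomputed = recompute_p(s_row[:3], m, l) + recompute_p(s_row[3:], m, l)
sizes = saved_for_backward(400_000)
print(sizes["standard: P"] // sizes["flash: m and l"])  # 200000
```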
Figure: Same answer, different memory bill (IO accounting plus online softmax). Both panels compute identical attention. The naive version writes the N×N matrix to HBM; FlashAttention slides a tile through SRAM and never materializes it, carrying the running online-softmax statistics (m, ℓ, O) and rescaling ℓ and O by exp(m_old − m_new) whenever a new tile's max exceeds the running max. At N = 2K, the naive kernel moves 32 MB through HBM for the S and P matrices alone, while FlashAttention moves about 1 MB for the inputs and output.
Why this didn't exist before 2022
All the ingredients had been around. Online softmax was published in 2018 (Milakov & Gimelshein). Tiling for matrix kernels is older than CUDA. The fundamental observation that attention is memory-bound shows up in profiles anyone could have run. Rabe and Staats had the algorithm in 2021. So why did the wallclock breakthrough come a year later?
The answer is mostly engineering. Writing a CUDA kernel that beats cuBLAS at its own game is hard. Tri Dao's implementation gets a long list of details right: the warp tiling pattern, the Tensor Core MMA layouts, the memory access patterns for coalesced HBM reads, the on-chip register pressure, the choice of what lives in shared memory versus registers. Each of these decisions, made wrong, would give back the 7× win. The reason FlashAttention worked when the antecedents didn't is that someone sat down and wrote 1,500 lines of CUDA correctly.
What FlashAttention actually did was the integration: take all those pieces, write a fused CUDA kernel that does the whole forward pass without intermediate writes, get the numerics right, get the block sizes right for real GPUs, and prove that it's a strict Pareto improvement (faster and less memory and same outputs). It's a paper that's mostly engineering, and that's the right kind of paper. Many things that look like algorithmic problems turn out to be IO problems once you look at the hardware honestly.
Sparsity, masks, and what the kernel can actually skip
Most production attention is causal: token i only sees tokens 0..i. The upper triangle of the attention matrix is masked out (its scores are set to -∞ before the softmax, so their probabilities are zero). Naive attention computes the upper triangle anyway and then applies the mask. The FLOPs are wasted, but at least cuBLAS gets a clean matmul shape. FlashAttention can do better: skip the fully masked tile-blocks entirely and apply the mask only inside the tiles that straddle the diagonal. A causal pattern roughly halves the work.
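Counting tile-blocks shows where the "roughly half" comes from. A sketch, assuming square sequence lengths and the convention that query row r may attend to key column c when c ≤ r; the function name is illustrative.

```python
def causal_tile_counts(n: int, br: int, bc: int) -> tuple[int, int, int]:
    """Classify (query tile, key tile) pairs under a causal mask.
    Returns (skipped, partial, full) tile counts."""
    skipped = partial = full = 0
    for i0 in range(0, n, br):
        i1 = min(i0 + br, n)          # query rows [i0, i1)
        for j0 in range(0, n, bc):
            j1 = min(j0 + bc, n)      # key columns [j0, j1)
            if j0 > i1 - 1:           # every key is after every query row: skip
                skipped += 1
            elif j1 - 1 <= i0:        # every key is visible to every row: no mask
                full += 1
            else:                     # straddles the diagonal: mask inside the tile
                partial += 1
    return skipped, partial, full


for n in (1024, 8192):
    skipped, partial, full = causal_tile_counts(n, 128, 128)
    total = skipped + partial + full
    print(f"N={n}: {skipped}/{total} tiles skipped, "
          f"work fraction {(partial + full) / total:.3f}")
```

The work fraction approaches one half from above as N grows, because the diagonal tiles still have to be computed.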
More aggressive sparsity patterns extend the same idea. Block-sparse attention (Longformer, BigBird) defines a fixed pattern of which tile-blocks are active — a sliding window plus a few global anchor columns, say. With FlashAttention, the kernel iterates only over active blocks. The HBM traffic shrinks proportionally, the FLOPs shrink proportionally, and the math stays exact within whatever sparsity pattern you chose.
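The same idea works for any fixed pattern: precompute which tile-blocks are active and have the inner loop visit only those. The layout below (a sliding window of blocks plus a few global key blocks) is a hypothetical example in the spirit of Longformer and BigBird, not either model's exact pattern.

```python
def active_blocks(num_blocks: int, window: int, global_cols: int) -> set[tuple[int, int]]:
    """Hypothetical block-sparse layout: each query block attends to key blocks
    within `window` blocks of itself, plus the first `global_cols` key blocks."""
    return {
        (qb, kb)
        for qb in range(num_blocks)
        for kb in range(num_blocks)
        if abs(qb - kb) <= window or kb < global_cols
    }


def kv_blocks_for(query_block: int, active: set[tuple[int, int]]) -> list[int]:
    """The key/value tiles an inner loop would visit for one query tile."""
    return sorted(kb for qb, kb in active if qb == query_block)


blocks = active_blocks(num_blocks=16, window=1, global_cols=1)
print(f"density: {len(blocks) / 16**2:.3f}")
print("query block 5 visits key blocks", kv_blocks_for(5, blocks))
```

Loads and FLOPs both scale with `len(blocks)` rather than with the full grid, which is the proportional saving described above.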
The point worth absorbing: with naive attention, sparsity is a post-processing trick — you compute everything and then zero things out. The kernel doesn’t actually save work; it saves you from learning from masked positions, but the bus traffic is identical. With FlashAttention, sparsity is baked into the loop. An empty tile-block is not loaded, not multiplied, not normalized into. The savings show up on the wallclock, not just on the gradient.
Figure: Sparsity in tile-block units (skipping work in chunks the kernel can actually skip). The attention matrix is drawn as a 16×16 grid of 128×128 tile-blocks under different sparsity patterns. With dense softmax-attention, masked entries still go through the kernel; with FlashAttention's tiled loop, an empty tile-block is never loaded or multiplied, so the HBM savings are roughly proportional to the sparsity ratio. Causal masking alone is worth about 2×.
FlashAttention-2 and the missing parallelism
The original FlashAttention parallelized across batch and head dimensions: each (batch, head) gets its own SM, and the SM walks tiles serially. That works fine when batch × heads is large enough to fill the GPU — say, batch=32, heads=32, which gives 1024 work items for an A100 with 108 SMs. But for long-context inference, batch shrinks to 1 and the heads alone don’t fill the chip. Most SMs sit idle.
FlashAttention-2 (Dao 2023) fixes this by parallelizing across the sequence dimension as well. Each query tile becomes its own work item, a thread block scheduled onto whichever SM is free. With N = 8K and B_r = 64, that's 128 query tiles per head, plenty of work to keep the GPU busy even at batch 1. The paper also reworks the algorithm to cut non-matmul FLOPs (the softmax exponentials and rescales), for example by leaving O_i unnormalized until the end of the inner loop; those had been a small but real overhead in FA-1.
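A quick count of work items against the A100's 108 SMs shows why the extra axis matters. This is a simplified model, assuming one thread block per work item and ignoring that an SM can host several blocks at once.

```python
import math

NUM_SMS = 108  # A100


def work_items_fa1(batch: int, heads: int) -> int:
    """One work item per (batch, head) pair."""
    return batch * heads


def work_items_fa2(batch: int, heads: int, n: int, br: int) -> int:
    """One work item per (batch, head, query tile)."""
    return batch * heads * math.ceil(n / br)


def busy_fraction(items: int) -> float:
    """Fraction of SMs with at least one work item in the first wave."""
    return min(items, NUM_SMS) / NUM_SMS


for label, items in [("FA-1, batch 32", work_items_fa1(32, 32)),
                     ("FA-1, batch 1", work_items_fa1(1, 32)),
                     ("FA-2, batch 1, N=8K", work_items_fa2(1, 32, 8192, 64))]:
    print(f"{label:20s} {items:5d} items, {busy_fraction(items):.0%} of SMs busy")
```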
The result: roughly 2× over FA-1 on A100, and it scales much better at long sequence lengths. FA-2 was the version that made 32K and 64K training contexts genuinely cheap, not just possible.
FA-2 also sped up the backward pass by parallelizing it across the sequence dimension. Each worker takes a block of key/value columns and accumulates dK and dV for it on-chip; because several workers contribute to the same rows of dQ, those updates are combined with atomic adds.
FlashAttention-3 and the H100 features
The H100 introduced two hardware features that FA-2 couldn’t fully exploit: TMA (the Tensor Memory Accelerator, an asynchronous DMA engine for shared memory loads) and wgmma (warp-group matrix multiply-accumulate, async tensor-core ops). FA-3 (Shah, Dao, et al. 2024) is a from-scratch kernel that uses both.
The high-level idea: while one warp-group does an MMA on the current tile, a different warp-group asynchronously loads the next tile via TMA. Memory and compute overlap. FA-2 was synchronous — load, compute, load, compute — and the compute units sat idle during loads. FA-3 keeps them busy. The reported throughput is roughly 740 TFLOPS of attention forward on H100, against an FP16 peak of 989 — about 75% of theoretical max, which is excellent for a kernel that includes softmax.
FA-3 also introduces FP8 attention, using block quantization and incoherent processing to keep the error within acceptable bounds for typical transformer use. FP8 doubles the H100's peak tensor-core throughput, and FA-3's FP8 path reaches close to 1.2 PFLOPS, though quality caveats limit it to inference and certain training configurations.
Figure: Wallclock vs sequence length. Forward attention time (fp16, d = 64, 8 heads, batch 1) for naive attention and FlashAttention on A100 and FlashAttention-3 on H100, with numbers calibrated from Dao 2022, Dao 2023, and Shah et al. 2024 and smoothed for visualization. Naive time scales like a square; FlashAttention also scales like a square, with a much smaller constant, and FlashAttention-3 on H100 shrinks the constant again. At N = 8K on A100, FlashAttention-2 already cuts the forward pass to a fraction of naive; at N = 64K, the gap is the difference between a model you can train and one you can't.
What this enabled, downstream
Long context, basically. Pre-FlashAttention, going from 2K to 8K tokens was a memory crisis. Post-FlashAttention, going from 32K to 128K is a (still hard, but tractable) systems problem. Every long-context model you've used (Claude with its 200K+ window, Gemini with its million-token window, GPT-4's expanding context) almost certainly runs on FlashAttention or a descendant IO-aware kernel. The labs don't publish their kernels, but there's no practical version of these models without IO-aware attention.
It also changed how people think about kernel writing. After FlashAttention, the field stopped treating PyTorch ops as the unit of optimization and started treating fused custom kernels as the unit. Mamba's selective scan, the SSD kernels in linear-attention variants, the various MoE routing kernels — they all owe a debt to FlashAttention's demonstration that a hand-written fused kernel can beat a stack of optimized library calls by an order of magnitude when memory traffic is the constraint.
And it made a strong case for the field paying attention to hardware. The transformer's quadratic-attention problem looked like an algorithmic problem ("replace softmax with something cheaper"). Two parallel research threads spent years on linear-attention variants that approximated softmax. FlashAttention basically said: don't approximate, just write the kernel correctly. The exact algorithm was fine; the implementation was leaving 10× on the floor.
The ripple effects are still showing up. Triton, OpenAI's Python DSL for GPU kernels, got a huge boost in popularity because people wanted to write FlashAttention-style kernels for their own ops without learning CUDA. PyTorch added scaled_dot_product_attention as a built-in that dispatches to a FlashAttention kernel under the hood when the inputs allow it. The Hugging Face Transformers library picked it up as a default. Within about eighteen months of the paper, it was effectively the only attention kernel anyone ran in production.
Inference, KV caches, and paged attention
Inference has its own constraints. At decode time you have a single new query attending to a long KV cache, and the bottleneck shifts from the attention matrix to the cache itself. The query is one token; the keys and values are N tokens. The arithmetic intensity is even worse than at training time, because you can’t amortize the K, V load across many queries.
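The decode-time intensity is easy to estimate. A sketch for one head and one new token, assuming fp16 K and V that must each be streamed from HBM once; the context lengths and head dimension are illustrative.

```python
def decode_intensity(n_ctx: int, d: int, bytes_per_elem: int = 2) -> float:
    """One new query row against a KV cache of n_ctx tokens, per head:
    q.K^T and p.V cost about 4 * n_ctx * d FLOPs, while K and V
    (n_ctx x d each) must be streamed from HBM once."""
    flops = 4 * n_ctx * d
    bytes_moved = bytes_per_elem * (2 * n_ctx * d + 2 * d)  # K, V, plus q in and o out
    return flops / bytes_moved


for n_ctx in (4096, 32768, 131072):
    print(f"context {n_ctx:6d}: {decode_intensity(n_ctx, 128):.3f} FLOPs/byte")
```

About one FLOP per byte regardless of context length, two orders of magnitude below the A100 ridge: decode attention is a bandwidth problem from the start.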
FlashDecoding (Dao et al. 2023, blog post) addressed this by splitting the work along the sequence axis to keep all SMs busy. Instead of one SM walking the full KV cache, multiple SMs each walk a slice, and a final reduction rescales and combines their partial results with the same max-and-sum bookkeeping as online softmax. Decode throughput goes from being limited by how few SMs have work to being limited by the actual HBM bandwidth, a 2–8× improvement depending on context length.
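The split-and-combine step is the online-softmax merge applied to whole slices instead of tiles. A pure-Python sketch for one query, with hypothetical function names; each "split" here would be a separate SM in a real kernel, and the real implementation's data layout is not modeled.

```python
import math


def partial_attention(scores, values):
    """One worker's slice of the KV cache: returns (m, l, unnormalized o)."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    o = [sum(wj * v[c] for wj, v in zip(w, values)) for c in range(len(values[0]))]
    return m, sum(w), o


def combine(partials):
    """Reduction: rescale every slice onto the global max, then normalize once."""
    m = max(pm for pm, _, _ in partials)
    l = sum(pl * math.exp(pm - m) for pm, pl, _ in partials)
    dim = len(partials[0][2])
    o = [sum(po[c] * math.exp(pm - m) for pm, _, po in partials) for c in range(dim)]
    return [x / l for x in o]


def split_kv_decode(scores, values, num_splits):
    """Hypothetical split-KV decode for one query row."""
    chunk = -(-len(scores) // num_splits)  # ceil division; every slice is non-empty
    partials = [partial_attention(scores[i:i + chunk], values[i:i + chunk])
                for i in range(0, len(scores), chunk)]
    return combine(partials)


scores = [0.3, 2.2, -1.0, 5.0, 4.1, 0.0, 1.7]
values = [[float(i), float(-i)] for i in range(7)]
print(split_kv_decode(scores, values, num_splits=3))
```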
vLLM's paged attention (Kwon et al. 2023) tackled a different inference problem: KV cache memory fragmentation. When you serve many concurrent requests with different sequence lengths, naive KV cache allocation leaves a lot of dead VRAM. Paged attention manages KV cache memory in fixed-size pages, similar to how an OS manages virtual memory. It's not a kernel speedup so much as a memory layout change, but it leans on the same blockwise view of attention that FlashAttention made standard: K and V are consumed in independent blocks, so those blocks don't need to be contiguous in memory. The combination is what makes high-throughput LLM serving practical.
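The page-table idea can be modeled in a few lines. This is a toy sketch, not vLLM's actual data structures: `PagedKVCache`, `PAGE_SIZE`, and the method names are hypothetical, and only the bookkeeping (logical token position to physical page and offset) is shown.

```python
PAGE_SIZE = 16  # tokens per page (illustrative)


class PagedKVCache:
    """Toy model: a shared pool of fixed-size pages plus a per-sequence
    block table mapping logical page number -> physical page."""

    def __init__(self, num_pages: int):
        self.free = list(range(num_pages))
        self.block_tables: dict[int, list[int]] = {}
        self.lengths: dict[int, int] = {}

    def append_token(self, seq_id: int) -> tuple[int, int]:
        """Reserve a slot for one more token; returns (physical page, offset)."""
        table = self.block_tables.setdefault(seq_id, [])
        pos = self.lengths.get(seq_id, 0)
        if pos % PAGE_SIZE == 0:       # current page is full (or there is none yet)
            table.append(self.free.pop())
        self.lengths[seq_id] = pos + 1
        return table[pos // PAGE_SIZE], pos % PAGE_SIZE

    def release(self, seq_id: int) -> None:
        """Return a finished sequence's pages to the shared pool."""
        self.free.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)


cache = PagedKVCache(num_pages=8)
slots = [cache.append_token(seq_id=0) for _ in range(20)]
print(cache.block_tables[0], slots[15], slots[16])
```

A sequence wastes at most one partially filled page, and an attention kernel that walks K and V block by block only needs the block table to find each page.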
The broader pattern: the attention-is-memory-bound insight keeps paying off, just in different shapes. At training, the bottleneck is the N² materialization. At long-context inference, the bottleneck is the KV cache size. At serving, the bottleneck is fragmentation. Each one yields to a different fix, but they all share DNA with FlashAttention’s starting move: look at the bytes, not the FLOPs.
Limits, and what comes after
FlashAttention is still O(N²) in compute. It just doesn't materialize the matrix in HBM. For very long sequences (a million tokens, a billion tokens), the compute itself becomes the bottleneck again, and you need actual sub-quadratic methods — sparse attention, sliding windows, retrieval, or genuinely linear-time architectures like Mamba. FlashAttention extends the regime where exact attention is practical; it doesn't extend it forever.
The FLOP count gets ugly fast. At N = 1M, each head in each layer has 10¹² attention scores, and with d = 64 the two matmuls cost about 4N²d ≈ 2.6 × 10¹⁴ FLOPs. Even at 700 TFLOPS, that's roughly 0.37 seconds per head per layer. With 32 heads and 80 layers, that's over 15 minutes for a single forward pass over the sequence, and still several minutes with causal masking. That's a lower bound, before you start adding feedforward layers or any of the other realities. There's a regime past which exact attention is not just slow; it's untenable.
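The arithmetic is worth checking by hand once. A sketch that counts only the two attention matmuls (4N²d FLOPs per head per layer), assuming d = 64, 32 heads, 80 layers, and a sustained 700 TFLOPS; the causal option simply halves the count.

```python
def attention_seconds(n: int, d: int, heads: int, layers: int,
                      flops_per_s: float, causal: bool = False) -> float:
    """Lower bound on time for the attention matmuls of one forward pass
    over an n-token sequence: 4 * n^2 * d FLOPs per head per layer."""
    flops = 4 * n * n * d * heads * layers
    if causal:
        flops /= 2  # roughly half the tile-blocks are skipped
    return flops / flops_per_s


t = attention_seconds(n=1_000_000, d=64, heads=32, layers=80, flops_per_s=700e12)
t_causal = attention_seconds(1_000_000, 64, 32, 80, 700e12, causal=True)
print(f"dense:  {t:.0f} s (~{t / 60:.0f} min)")
print(f"causal: {t_causal:.0f} s (~{t_causal / 60:.0f} min)")
```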
The current frontier is mixing FlashAttention with sub-quadratic structure: Mamba-style state-space mixing for long-range dependencies, FlashAttention for local context, retrieval for the rest. It’s a portfolio approach — different mechanisms for different scales — and it works because none of them are trying to do everything. FlashAttention’s contribution to that portfolio is that exact local attention is now cheap enough to be a building block rather than a budget item.
Type the rescale once and the kernel stops being scary
If you want to read the paper, focus on Algorithm 1 (the forward pass) and convince yourself that the rescaling math is correct. If you want to understand it, write a numpy version of online softmax over a 1D array. Once your hand has typed the rescale step once, the whole kernel stops being scary and starts being a slightly fiddly accounting exercise — which is what it actually is.
The deeper lesson, and the one the paper is becoming famous for, is about where to look. Attention had spent four years being studied as an algorithmic problem. The flame chart said it wasn’t one. Sometimes the bottleneck isn’t the bottleneck you think it is. Profile first.