FlashAttention: Why Moving Fewer Bytes Beats Doing Fewer FLOPs
The surprising idea behind FlashAttention is that the attention equation was not necessarily the main problem. Moving its intermediate matrices between GPU memory levels was. The algorithm gets faster by changing the order of work so fewer bytes travel to and from high-bandwidth memory.
This article starts with intuition, then moves into the mechanisms and production details. You can stop after the worked example and retain the core idea, or continue into the performance model and operational edge cases.
Start with the intuition
A cook can prepare a dish by walking to a distant pantry and back after every step, or by bringing a small tray of ingredients to the counter and finishing that portion of the work there. The recipe is unchanged. The traffic is radically lower.
What actually happens
Standard attention forms scores QK^T (scaled by 1/sqrt(d), where d is the head dimension), applies a row-wise softmax, and multiplies the result by V. A naive implementation materializes the large score and probability matrices in HBM. FlashAttention tiles Q, K, and V so that blocks fit in on-chip SRAM, and it fuses the three steps into one kernel.
Softmax normally needs the maximum and normalization sum over a full row. FlashAttention uses an online softmax recurrence, maintaining a running maximum and rescaling partial sums as new tiles arrive. That makes tiling exact up to normal floating-point behavior.
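The recurrence described above can be checked on a single row. The following is a minimal sketch in plain Python, not a kernel: the function names, the sample row, and the tile size are illustrative assumptions. It streams the row in tiles, keeps only a running maximum m and sum l, and compares the result with an ordinary two-pass softmax.

```python
import math

def softmax_two_pass(row):
    """Reference: needs the whole row to find the maximum before normalizing."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def online_softmax_stats(row, tile_size):
    """Stream the row tile by tile, keeping only a running max m and sum l."""
    m = -math.inf
    l = 0.0
    for start in range(0, len(row), tile_size):
        tile = row[start:start + tile_size]
        m_new = max(m, max(tile))
        # Rescale the old sum to the new maximum, then add this tile's terms.
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in tile)
        m = m_new
    return m, l

row = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25, 4.0, -2.0]   # illustrative scores
m, l = online_softmax_stats(row, tile_size=3)
streamed = [math.exp(x - m) / l for x in row]
reference = softmax_two_pass(row)
max_abs_diff = max(abs(a - b) for a, b in zip(streamed, reference))
print(max_abs_diff)   # differs from zero only by floating-point rounding
```

The streamed statistics reproduce the two-pass result because every earlier term is multiplied by exp(m_old - m_new) whenever the maximum grows, which is the same as having used the final maximum from the start.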
The asymptotic arithmetic complexity remains quadratic for dense attention. The gain comes from IO complexity and fusion: fewer HBM reads and writes, fewer intermediate tensors, and fewer kernel boundaries. The distinction matters because FlashAttention is faster by moving fewer bytes, not by performing fewer FLOPs.
A worked example
For a sequence of 8,192 tokens, the attention score matrix contains about 67 million elements per head. At two bytes each, that is about 128 MiB per head before probabilities and other intermediates. A tiled kernel avoids writing the full matrix to HBM, even though it still evaluates the needed dot products.
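The arithmetic above is easy to reproduce. In this short sketch the head count of 32 is a hypothetical value, added only to show how the per-head figure scales across one layer.

```python
seq_len = 8192
bytes_per_element = 2          # fp16 or bf16

score_elements = seq_len * seq_len
score_bytes = score_elements * bytes_per_element
score_mib = score_bytes / 2**20

heads = 32                     # hypothetical head count, not from the article
layer_gib = score_bytes * heads / 2**30

print(f"{score_elements:,} elements, {score_mib:.0f} MiB per head")
# prints "67,108,864 elements, 128 MiB per head"
print(f"{layer_gib:.0f} GiB per layer at {heads} heads")
# prints "4 GiB per layer at 32 heads"
```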
The performance model
The improvement depends on sequence length, head dimension, data type, GPU architecture, masks, and kernel support. Short sequences may show modest gains because launch overhead and other model layers dominate, and unsupported shapes may show none because they fall back to a different kernel.
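A back-of-envelope byte count can make the dependence on sequence length and head dimension concrete. The sketch below is a deliberately simple model under stated assumptions, not a measurement: the naive path is assumed to write and re-read both the score and probability matrices, the tiled path is assumed to read Q and write the output once while re-reading K and V for every query tile, and caches are ignored. The function names and the default tile height of 128 rows are hypothetical.

```python
def naive_hbm_bytes(n, d, elem_bytes=2):
    """Bytes moved when scores S and probabilities P round-trip through HBM."""
    qkvo = 4 * n * d             # read Q, K, V; write O
    intermediates = 4 * n * n    # write S, read S, write P, read P
    return (qkvo + intermediates) * elem_bytes

def tiled_hbm_bytes(n, d, q_tile_rows=128, elem_bytes=2):
    """Bytes moved when S and P never leave on-chip memory."""
    q_tiles = -(-n // q_tile_rows)    # ceiling division
    q_and_o = 2 * n * d               # read Q once, write O once
    k_and_v = 2 * n * d * q_tiles     # re-read K and V for every query tile
    return (q_and_o + k_and_v) * elem_bytes

def traffic_ratio(n, d, q_tile_rows=128):
    """Naive bytes divided by tiled bytes under the assumptions above."""
    return naive_hbm_bytes(n, d) / tiled_hbm_bytes(n, d, q_tile_rows)

for n in (128, 1024, 8192, 32768):
    print(n, round(traffic_ratio(n, d=64), 2), round(traffic_ratio(n, d=128), 2))
```

In this model both paths remain quadratic in sequence length, and the ratio approaches 2 * q_tile_rows / d for long sequences, so the saving is a constant factor set by tile size and head dimension. Real kernels, masks, and cache behavior will shift the numbers, which is why profiling on the target GPU remains necessary.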
Expert lens
FlashAttention is an algorithm family, not a magic compiler flag. Version-specific kernels change work partitioning, parallelism, and support for causal masks, grouped-query attention, variable lengths, and newer GPU instructions. Confirm which backend path actually ran rather than trusting a configuration label.
Where it wins
- Long-context prefill and training
- Attention-heavy models on supported GPUs
- Memory pressure caused by score intermediates
Where it disappoints
- Expectations of linear-time attention, since the arithmetic is still quadratic
- Attention calls that are assumed to use the flash kernel but do not
- Padded batches where variable-length packing was ignored
- Model speedups reported from a microkernel benchmark
Production checklist
- Verify dtype, head dimension, mask, and GPU support
- Profile the selected attention backend
- Use realistic variable sequence lengths
- Separate prefill gains from decode gains
- Compare peak memory as well as latency
What to measure
- Attention kernel duration and achieved bandwidth
- HBM bytes read and written
- Kernel selection and fallback count
- Prefill latency by prompt length
- Peak temporary memory
From one GPU to a production service
A framework-level demo may call a supported attention function and see a speedup. A serving platform has many model architectures and request shapes. Qualification must map each model revision to a proven kernel path for causal prefill, decode, variable lengths, GQA, sliding windows, and multimodal positions.
Long-context traffic should be bucketed by the shapes that select different kernels. Averages hide cliffs. One unsupported head dimension or mask can send a minority of requests through a slower fallback and dominate p99 latency.
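A small example shows how an average can hide such a cliff. The request log below is synthetic and purely illustrative: 97 requests on a supported shape at 40 ms and 3 requests on a hypothetical unsupported head dimension at 400 ms. The `percentile` helper is a simple nearest-rank implementation written for this sketch.

```python
from collections import defaultdict
import math

def percentile(values, pct):
    """Nearest-rank percentile; adequate for a sketch."""
    ordered = sorted(values)
    rank = max(1, math.ceil(pct * len(ordered) / 100))
    return ordered[rank - 1]

# Synthetic log: (head_dim, is_causal, prefill latency in ms). Illustrative only.
requests = [(128, True, 40.0)] * 97 + [(80, True, 400.0)] * 3

# Bucket by the shape fields that can select a different kernel.
by_shape = defaultdict(list)
for head_dim, is_causal, latency_ms in requests:
    by_shape[(head_dim, is_causal)].append(latency_ms)

all_latencies = [latency_ms for _, _, latency_ms in requests]
overall_mean = sum(all_latencies) / len(all_latencies)
overall_p99 = percentile(all_latencies, 99)
per_bucket_p99 = {shape: percentile(vals, 99) for shape, vals in by_shape.items()}

print(f"mean={overall_mean:.1f} ms  p99={overall_p99:.0f} ms")
for shape, p99 in sorted(per_bucket_p99.items()):
    print(shape, f"p99={p99:.0f} ms", f"count={len(by_shape[shape])}")
```

The overall mean stays near 51 ms while the overall p99 is 400 ms, and only the per-bucket view shows that a single shape is responsible.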
Treat the attention backend as a deployable dependency. Pin versions, record kernel selection in traces, and rerun numerical plus performance canaries when CUDA, PyTorch, driver, or model configuration changes.
Design-review questions
- Which exact attention kernel handles every production shape?
- What is the fallback and how is it detected?
- Does variable-length packing preserve the intended mask?
- Are gains in prefill, decode, or both?
- What numerical tolerance is accepted across backend upgrades?
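The packing and tolerance questions above can be turned into a test. The following NumPy sketch assumes sequences are concatenated along one axis and that the intended mask is causal within each sequence with no attention across sequences. The helper names are hypothetical, and the reference is a plain dense implementation rather than any particular backend.

```python
import numpy as np

def packed_causal_mask(seq_lens):
    """True where query i may attend to key j: same sequence and j <= i."""
    total = sum(seq_lens)
    seq_id = np.repeat(np.arange(len(seq_lens)), seq_lens)
    pos = np.arange(total)
    same_seq = seq_id[:, None] == seq_id[None, :]
    causal = pos[None, :] <= pos[:, None]
    return same_seq & causal

def masked_attention(q, k, v, mask):
    """Dense reference attention with a boolean mask."""
    scores = (q @ k.T) / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    p = np.exp(scores)
    return (p / p.sum(axis=-1, keepdims=True)) @ v

rng = np.random.default_rng(0)
seq_lens = [3, 5, 2]
total, d = sum(seq_lens), 8
q, k, v = (rng.standard_normal((total, d)) for _ in range(3))

packed_out = masked_attention(q, k, v, packed_causal_mask(seq_lens))

# Reference: run each sequence alone with an ordinary causal mask.
pieces, start = [], 0
for n in seq_lens:
    sl = slice(start, start + n)
    causal = np.tril(np.ones((n, n), dtype=bool))
    pieces.append(masked_attention(q[sl], k[sl], v[sl], causal))
    start += n
separate_out = np.concatenate(pieces)

max_abs_diff = np.abs(packed_out - separate_out).max()
print(max_abs_diff)
```

The same comparison, run against the production backend in its real dtype, gives a concrete number to hold against the accepted tolerance after each upgrade.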
How it connects to the rest of the series
PagedAttention addresses KV allocation rather than score-matrix IO. Mixed precision changes the bytes moved. Graph optimization and fused kernels reduce adjacent overhead around the attention kernel.
From equation to implementation
For one query tile, the kernel streams K and V tiles through SRAM. It tracks a running row maximum m and a running normalization sum l. When a new tile raises m, the previously accumulated output and l are both rescaled by exp(m_old - m_new) before the new tile's contribution is added. This recurrence avoids storing the full probability matrix while producing dense attention that is exact up to floating-point rounding.
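The recurrence for one query tile can be written as a runnable NumPy sketch and compared with a dense reference. This is an illustration of the data flow on the CPU, not a GPU kernel: the function names and the tile size of 4 are assumptions, query and key tiles are taken to be the same size, and the usual 1/sqrt(d) scaling is applied to the scores.

```python
import numpy as np

def tiled_attention(q, k, v, tile=4, causal=True):
    """Single-head attention computed tile by tile with an online softmax."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty_like(q)
    for qs in range(0, n, tile):
        q_tile = q[qs:qs + tile]
        rows = q_tile.shape[0]
        m = np.full((rows, 1), -np.inf)   # running row maximum
        l = np.zeros((rows, 1))           # running normalization sum
        acc = np.zeros((rows, d))         # unnormalized output
        for ks in range(0, n, tile):
            if causal and ks > qs + rows - 1:
                break                     # this key tile lies entirely in the future
            scores = (q_tile @ k[ks:ks + tile].T) * scale
            if causal:
                qi = np.arange(qs, qs + rows)[:, None]
                kj = np.arange(ks, ks + scores.shape[1])[None, :]
                scores = np.where(kj <= qi, scores, -np.inf)
            m_new = np.maximum(m, scores.max(axis=1, keepdims=True))
            p = np.exp(scores - m_new)
            alpha = np.exp(m - m_new)     # rescales what earlier tiles contributed
            acc = acc * alpha + p @ v[ks:ks + tile]
            l = l * alpha + p.sum(axis=1, keepdims=True)
            m = m_new
        out[qs:qs + rows] = acc / l       # normalize once per query tile
    return out

def reference_attention(q, k, v, causal=True):
    """Dense attention that materializes the full score matrix."""
    scores = (q @ k.T) / np.sqrt(q.shape[1])
    if causal:
        keep = np.tril(np.ones(scores.shape, dtype=bool))
        scores = np.where(keep, scores, -np.inf)
    p = np.exp(scores - scores.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((10, 8)) for _ in range(3))
max_abs_diff = np.abs(tiled_attention(q, k, v) - reference_attention(q, k, v)).max()
print(max_abs_diff)
```

The tiled version never holds more than one tile of scores at a time, yet it matches the dense result to rounding error, including when the sequence length is not a multiple of the tile size.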
The backward pass recomputes selected intermediates rather than reading a stored attention matrix. That trades extra arithmetic for much lower memory traffic. During inference, the most visible benefit is usually long-prompt prefill; single-token decode has a different shape and may use specialized paged or fused decode kernels.
Implementation sketch
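    for Q_tile in Q:
        m = -infinity; l = 0; out = 0        # running row max, normalizer, unnormalized output
        for K_tile, V_tile in blocks:        # stream K and V tiles through SRAM
            scores = Q_tile @ transpose(K_tile)
            apply_causal_mask(scores)
            m_new = max(m, rowmax(scores))
            p = exp(scores - m_new)
            out = out * exp(m - m_new) + p @ V_tile    # rescale earlier tiles, add this one
            l = l * exp(m - m_new) + rowsum(p)
            m = m_new
        write(out / l)                       # normalize once per query tile
Capacity planning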
Workspace and tile choices depend on head dimension, sequence length, and GPU SRAM/register limits. A kernel that is ideal for head dimension 128 may spill registers or fall back for another shape. Keep model-shape support in the deployment qualification matrix.
Benchmarking without fooling yourself
- Profile kernel names to prove the intended backend ran.
- Sweep sequence length, causal mode, head dimension, and dtype.
- Measure HBM traffic and peak temporary memory, not just FLOPs.
- Report full-model prefill latency alongside microkernel timing.
A production failure to design for
A model upgrade changes head dimension to an unsupported value. The framework silently selects a math fallback, and p99 TTFT doubles only for that model. Alert on backend selection and maintain a canary benchmark for every deployed architecture.
The takeaway
FlashAttention is a lesson in systems thinking: the fastest equation is often the one that respects the memory hierarchy.
