3/20 - FlashAttention: Why Moving Fewer Bytes Beats Doing Fewer FLOPs

The surprising idea behind FlashAttention is that the attention equation was not necessarily the main problem. Moving its intermediate matrices between GPU memory levels was. The algorithm gets faster by changing the order of work so fewer bytes travel to and from high-bandwidth memory.

This article starts with intuition, then moves into the mechanisms and production details. You can stop after the worked example and retain the core idea, or continue into the performance model and operational edge cases.

Start with the intuition

A cook can prepare a dish by carrying every ingredient to a distant pantry after every step, or by bringing a small tray of ingredients to the counter and finishing a tile of work locally. The recipe is unchanged. The traffic is radically lower.

[Diagram: FlashAttention request path. Stage 1, HBM: load Q, K, V tiles and avoid the full score matrix. Stage 2, on-chip SRAM: run online softmax and accumulate the output tile. Stage 3, HBM: write the final output, the exact attention result.]
Follow the state and work from left to right.

How to read this diagram: Start with HBM, where the kernel loads Q, K, and V tiles without materializing the full score matrix. The middle stage, on-chip SRAM, runs the online softmax and accumulates the output tile. The final stage, HBM again, shows the observable result: the final output is written once. The arrows describe dependency order, not necessarily separate services.

What actually happens

Standard attention forms scores QK^T, applies softmax, and multiplies by V. A naive implementation materializes the large score and probability matrices in HBM. FlashAttention tiles Q, K, and V so blocks fit in on-chip SRAM and fuses the operations.

Softmax normally needs the maximum and normalization sum over a full row. FlashAttention uses an online softmax recurrence, maintaining a running maximum and rescaling partial sums as new tiles arrive. That makes tiling exact up to normal floating-point behavior.
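The recurrence can be checked on a single row of scores. The following is a minimal sketch in plain Python, not a kernel: the function names, the sample row, and the tile size are all illustrative assumptions. It streams the row two scores at a time, keeps only a running maximum and a running sum, and then compares the result with a softmax that sees the whole row.

```python
import math

def softmax_full(scores):
    """Reference softmax that sees the whole row at once."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def online_softmax_stats(scores, tile_size):
    """Return (running max, running sum) after streaming the row tile by tile."""
    m = -math.inf
    l = 0.0
    for start in range(0, len(scores), tile_size):
        tile = scores[start:start + tile_size]
        m_new = max(m, max(tile))
        # Rescale the old sum to the new maximum, then add this tile's terms.
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in tile)
        m = m_new
    return m, l

row = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25]  # illustrative scores
m, l = online_softmax_stats(row, tile_size=2)
tiled = [math.exp(s - m) / l for s in row]
full = softmax_full(row)
max_diff = max(abs(a - b) for a, b in zip(tiled, full))
```

The two results agree up to floating-point rounding, which is the sense in which tiling is exact.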

The asymptotic arithmetic complexity remains quadratic for dense attention. The gain comes from IO complexity and fusion: fewer HBM reads and writes, fewer intermediate tensors, and fewer kernel boundaries. This distinction matters when explaining why it is fast.
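A back-of-envelope traffic model makes the distinction concrete. This is a deliberately simplified sketch, not a measurement. It assumes one head, no cache reuse between kernels, a materialized path that writes and re-reads both the score and probability matrices, and a fused path whose outer loop runs over query tiles. The function names and the tile size are hypothetical.

```python
def naive_hbm_bytes(n, d, elem_bytes=2):
    """Rough HBM traffic for one head when scores and probabilities are materialized."""
    qkv_reads = 3 * n * d          # read Q, K, V once each
    out_write = n * d              # write the output
    score_traffic = 4 * n * n      # write S, read S, write P, read P
    return (qkv_reads + out_write + score_traffic) * elem_bytes

def tiled_hbm_bytes(n, d, q_tile_rows, elem_bytes=2):
    """Rough HBM traffic for a fused kernel with an outer loop over query tiles."""
    q_tiles = -(-n // q_tile_rows)     # ceiling division
    q_read_out_write = 2 * n * d       # each Q tile read once, each output tile written once
    kv_reads = q_tiles * 2 * n * d     # K and V streamed once per query tile
    return (q_read_out_write + kv_reads) * elem_bytes

n, d = 8192, 128                       # illustrative shape
naive = naive_hbm_bytes(n, d)
tiled = tiled_hbm_bytes(n, d, q_tile_rows=128)
print(f"naive ~{naive / 2**20:.0f} MiB, tiled ~{tiled / 2**20:.0f} MiB")
```

In this model both paths perform the same dot products. Only the bytes differ, and the tiled figure shrinks further as the query tile grows, because K and V are re-read fewer times.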

A worked example

For a sequence of 8,192 tokens, the attention score matrix contains about 67 million elements per head. At two bytes each, that is about 128 MiB per head before probabilities and other intermediates. A tiled kernel avoids writing the full matrix to HBM, even though it still evaluates the needed dot products.
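The arithmetic is short enough to verify directly. In the sketch below, the head count of 32 is an assumption added only to show how the per-head figure scales; it does not come from any particular model.

```python
seq_len = 8192
bytes_per_element = 2            # FP16 or BF16
num_heads = 32                   # hypothetical head count, for illustration only

elements = seq_len * seq_len
score_bytes = elements * bytes_per_element
all_heads_gib = score_bytes * num_heads / 2**30

print(f"{elements:,} elements, {score_bytes / 2**20:.0f} MiB per head")
# prints "67,108,864 elements, 128 MiB per head"
print(f"{all_heads_gib:.0f} GiB across {num_heads} heads")
# prints "4 GiB across 32 heads"
```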

The performance model

The improvement depends on sequence length, head dimension, data type, GPU architecture, masks, and kernel support. Short sequences or unsupported shapes may show modest gains because launch overhead and other model layers dominate.

[Diagram: Where FlashAttention changes inference. Prefill: many prompt tokens in parallel, high arithmetic intensity, strong IO savings on long attention. Decode: one new token per iteration, weight and KV bandwidth pressure, benefit depends on decode kernel shape. Prove it with: TTFT and attention-kernel time. Deployment decision: validate the kernel path for every shape.]
Prefill and decode run the same model but expose different bottlenecks and SLOs.

How to read this diagram: The left panel asks how FlashAttention changes prompt processing and TTFT; the right asks how it changes iterative generation and inter-token latency. The bottom row names the metric that must improve and the deployment choice justified by that evidence. Optimizing the wrong phase can add complexity without changing the user-visible bottleneck.

Expert lens

FlashAttention is an algorithm family, not a magic compiler flag. Version-specific kernels change work partitioning, parallelism, and support for causal masks, grouped-query attention, variable lengths, and newer GPU instructions. Confirm which backend path actually ran rather than trusting a configuration label.

[Diagram: FlashAttention trade-off. Baseline, materialized attention: writes the score matrix to HBM, multiple kernel boundaries, high intermediate memory, simple conceptual graph. Optimized, FlashAttention: tiles through on-chip SRAM, fused exact computation, lower IO and memory use, shape-sensitive kernel path. Measure both sides under the same workload.]
The optimization changes where the system spends compute, memory, bandwidth, or waiting time.

How to read this diagram: The left panel is the baseline, materialized attention, which writes the score matrix to HBM and crosses multiple kernel boundaries. The right panel applies FlashAttention, which changes the cost profile by tiling through on-chip SRAM and fusing the exact computation. Compare both under the same request shape and load; the optimized side is not automatically better for every workload.

Where it wins

  • Long-context prefill and training
  • Attention-heavy models on supported GPUs
  • Memory pressure caused by score intermediates

Where it disappoints

  • Expecting linear-time attention when the arithmetic is still quadratic
  • Attention calls that never reach the flash kernel
  • Padded or variable-length batches whose packing is ignored
  • Microkernel benchmarks reported as full-model speedup

Production checklist

  • Verify dtype, head dimension, mask, and GPU support
  • Profile the selected attention backend
  • Use realistic variable sequence lengths
  • Separate prefill gains from decode gains
  • Compare peak memory as well as latency

What to measure

  • Attention kernel duration and achieved bandwidth
  • HBM bytes read and written
  • Kernel selection and fallback count
  • Prefill latency by prompt length
  • Peak temporary memory

From one GPU to a production service

A framework-level demo may call a supported attention function and see a speedup. A serving platform has many model architectures and request shapes. Qualification must map each model revision to a proven kernel path for causal prefill, decode, variable lengths, GQA, sliding windows, and multimodal positions.

Long-context traffic should be bucketed by the shapes that select different kernels. Averages hide cliffs. One unsupported head dimension or mask can send a minority of requests through a slower fallback and dominate p99 latency.
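A small sketch shows how an average can hide the cliff. The traffic mix, latencies, and bucket labels below are invented for illustration; the helper names are assumptions, and the percentile uses the simple nearest-rank definition.

```python
import math
from collections import defaultdict

def percentile(values, pct):
    """Nearest-rank percentile; pct is in (0, 100]."""
    ordered = sorted(values)
    rank = max(1, math.ceil(pct * len(ordered) / 100))
    return ordered[rank - 1]

def p99_by_bucket(samples):
    """Group (bucket, latency_ms) samples and return p99 per bucket."""
    grouped = defaultdict(list)
    for bucket, latency_ms in samples:
        grouped[bucket].append(latency_ms)
    return {bucket: percentile(vals, 99) for bucket, vals in grouped.items()}

# Hypothetical traffic: 95% of requests hit the fused kernel, 5% hit a fallback.
fast = ("head_dim=128", "flash")
slow = ("head_dim=80", "fallback")
samples = [(fast, 40.0)] * 95 + [(slow, 400.0)] * 5

mean_ms = sum(lat for _, lat in samples) / len(samples)
overall_p99 = percentile([lat for _, lat in samples], 99)
per_bucket = p99_by_bucket(samples)
```

Here the mean is 58 ms, which looks healthy, while the overall p99 is set entirely by the 5% of requests in the fallback bucket. Reporting per bucket shows which shape is responsible.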

Treat the attention backend as a deployable dependency. Pin versions, record kernel selection in traces, and rerun numerical plus performance canaries when CUDA, PyTorch, driver, or model configuration changes.
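The numerical half of such a canary can be as simple as an element-wise tolerance check against a stored reference output. The sketch below is framework-free and uses the common rule |candidate - reference| <= atol + rtol * |reference|. The function name, sample values, and tolerances are assumptions, not recommended thresholds.

```python
def within_tolerance(reference, candidate, atol, rtol):
    """True when every element satisfies |cand - ref| <= atol + rtol * |ref|."""
    if len(reference) != len(candidate):
        raise ValueError("outputs have different lengths")
    return all(abs(c - r) <= atol + rtol * abs(r)
               for r, c in zip(reference, candidate))

# Hypothetical flattened attention outputs from a pinned and an upgraded backend.
reference = [0.25, -1.5, 3.0]
after_upgrade = [0.2501, -1.4996, 3.001]
drifted = [0.25, -1.5, 3.1]

ok = within_tolerance(reference, after_upgrade, atol=1e-3, rtol=1e-3)
bad = within_tolerance(reference, drifted, atol=1e-3, rtol=1e-3)
```

The accepted tolerance is a product decision that depends on dtype and model; the point is that it is written down and enforced on every upgrade.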

Design-review questions

  • Which exact attention kernel handles every production shape?
  • What is the fallback and how is it detected?
  • Does variable-length packing preserve the intended mask?
  • Are gains in prefill, decode, or both?
  • What numerical tolerance is accepted across backend upgrades?
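The packing question can be turned into a test. The sketch below builds the mask that a correct packed causal kernel should be equivalent to: tokens attend only to earlier tokens in their own sequence. The name `cu_seqlens` for the cumulative boundaries is an assumption borrowed as a convention, and the dense boolean mask is only for verification at small sizes, not something a fused kernel would materialize.

```python
def packed_causal_mask(cu_seqlens):
    """Build a boolean mask for sequences packed back to back.

    cu_seqlens holds cumulative boundaries, e.g. [0, 3, 5] for lengths 3 and 2.
    mask[i][j] is True when query token i may attend to key token j.
    """
    total = cu_seqlens[-1]
    seq_id = [0] * total
    for idx in range(len(cu_seqlens) - 1):
        for pos in range(cu_seqlens[idx], cu_seqlens[idx + 1]):
            seq_id[pos] = idx
    return [[seq_id[i] == seq_id[j] and j <= i for j in range(total)]
            for i in range(total)]

def leaks_across_sequences(mask, cu_seqlens):
    """Return True if any query can see a key before its own sequence start."""
    for idx in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[idx], cu_seqlens[idx + 1]
        for i in range(start, end):
            if any(mask[i][j] for j in range(0, start)):
                return True
    return False

mask = packed_causal_mask([0, 3, 5])   # two sequences of length 3 and 2
leak = leaks_across_sequences(mask, [0, 3, 5])
```

A plain causal mask over the packed buffer would let token 3 attend to tokens 0 through 2, which belong to a different request; the check above is meant to catch that.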

How it connects to the rest of the series

PagedAttention addresses KV allocation rather than score-matrix IO. Mixed precision changes the bytes moved. Graph optimization and fused kernels reduce adjacent overhead around the attention kernel.

From equation to implementation

For one query tile, the kernel streams K and V tiles through SRAM. It tracks a running row maximum m and normalization l. When a new tile raises m, previously accumulated output is rescaled before adding the new tile contribution. This recurrence avoids storing the full probability matrix while producing exact dense attention.

The backward pass recomputes selected intermediates rather than reading a stored attention matrix. That trades extra arithmetic for much lower memory traffic. During inference, the most visible benefit is usually long-prompt prefill; single-token decode has a different shape and may use specialized paged or fused decode kernels.

Implementation sketch

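for Q_tile in Q_tiles:
    # Per-row running maximum, running softmax denominator, unnormalized output.
    m = -infinity; l = 0; out = 0
    for K_tile, V_tile in zip(K_tiles, V_tiles):
        scores = Q_tile @ transpose(K_tile)
        apply_causal_mask(scores)      # masked entries become -infinity
        m_new = max(m, rowmax(scores))
        p = exp(scores - m_new)
        # Rescale earlier partial results to the new maximum, then add this tile.
        out = out * exp(m - m_new) + p @ V_tile
        l = l * exp(m - m_new) + rowsum(p)
        m = m_new
    # Normalize once, after every K/V tile has been consumed.
    write(out / l)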
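The sketch above can be made runnable at toy scale. The following plain-Python version is an illustration under several assumptions, not how a GPU kernel is written: it uses a query tile of one row, nested lists instead of tensors, the conventional 1/sqrt(d) score scale that the sketch leaves out, and random inputs from a fixed seed. It skips fully masked K/V tiles instead of filling them with -infinity, which also avoids the undefined -infinity minus -infinity case.

```python
import math
import random

def reference_attention(Q, K, V, causal=True):
    """Materialize every score row, then softmax and weight V."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for i, q in enumerate(Q):
        limit = i + 1 if causal else len(K)
        scores = [scale * sum(a * b for a, b in zip(q, K[j])) for j in range(limit)]
        m = max(scores)
        p = [math.exp(s - m) for s in scores]
        total = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(limit)) / total
                    for c in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, kv_tile=2, causal=True):
    """Stream K/V tiles per query row, keeping only m, l, and one output row."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    dv = len(V[0])
    out = []
    for i, q in enumerate(Q):
        m, l, acc = -math.inf, 0.0, [0.0] * dv
        for start in range(0, len(K), kv_tile):
            if causal and start > i:
                break  # this tile and all later ones are fully masked for row i
            stop = min(start + kv_tile, len(K))
            if causal:
                stop = min(stop, i + 1)  # drop masked keys inside the tile
            scores = [scale * sum(a * b for a, b in zip(q, K[j]))
                      for j in range(start, stop)]
            m_new = max(m, max(scores))
            correction = math.exp(m - m_new)  # rescales earlier partial results
            p = [math.exp(s - m_new) for s in scores]
            acc = [acc[c] * correction
                   + sum(p[t] * V[start + t][c] for t in range(len(p)))
                   for c in range(dv)]
            l = l * correction + sum(p)
            m = m_new
        out.append([a / l for a in acc])
    return out

def max_diff(A, B):
    """Largest element-wise absolute difference between two matrices."""
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))

rng = random.Random(0)
n, d = 6, 4
Q, K, V = ([[rng.uniform(-1, 1) for _ in range(d)] for _ in range(n)] for _ in range(3))
causal_diff = max_diff(reference_attention(Q, K, V), tiled_attention(Q, K, V))
```

Only `m`, `l`, and one output row are live per query at any time, yet the result matches the materialized reference up to rounding.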

Capacity planning

Workspace and tile choices depend on head dimension, sequence length, and GPU SRAM/register limits. A kernel that is ideal for head dimension 128 may spill registers or fall back for another shape. Keep model-shape support in the deployment qualification matrix.
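A crude working-set estimate shows why one tile shape does not suit every head dimension. This sketch lumps shared memory and registers into a single hypothetical on-chip budget, assumes 2-byte inputs with 4-byte accumulators, and ignores double buffering and layout padding. The 128 KiB budget and the tile sizes are illustrative, not properties of any specific GPU.

```python
def tile_working_set_bytes(q_rows, kv_rows, head_dim, elem_bytes=2, acc_bytes=4):
    """Estimate on-chip bytes for one Q, K, and V tile plus scores and output."""
    q_tile = q_rows * head_dim * elem_bytes
    kv_tiles = 2 * kv_rows * head_dim * elem_bytes
    scores = q_rows * kv_rows * acc_bytes      # score/probability tile
    out_acc = q_rows * head_dim * acc_bytes    # output accumulator
    stats = 2 * q_rows * acc_bytes             # running max and sum per row
    return q_tile + kv_tiles + scores + out_acc + stats

def fits(q_rows, kv_rows, head_dim, budget_bytes):
    """True when the estimated working set stays within the on-chip budget."""
    return tile_working_set_bytes(q_rows, kv_rows, head_dim) <= budget_bytes

BUDGET = 128 * 1024  # hypothetical on-chip budget in bytes

fits_128 = fits(64, 64, 128, BUDGET)
fits_256 = fits(64, 64, 256, BUDGET)
```

Under these assumptions, a 64 by 64 tile fits at head dimension 128 but not at 256, so the larger head dimension would need smaller tiles or a different kernel.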

Benchmarking without fooling yourself

  • Profile kernel names to prove the intended backend ran.
  • Sweep sequence length, causal mode, head dimension, and dtype.
  • Measure HBM traffic and peak temporary memory, not just FLOPs.
  • Report full-model prefill latency alongside microkernel timing.

A production failure to design for

A model upgrade changes head dimension to an unsupported value. The framework silently selects a math fallback, and p99 TTFT doubles only for that model. Alert on backend selection and maintain a canary benchmark for every deployed architecture.
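The alert can be sketched as a comparison between the backend recorded in traces and the backend each model was qualified on. Everything named here is hypothetical: the model names, the backend labels, and the event format stand in for whatever your tracing system records.

```python
from collections import Counter

# Hypothetical qualification matrix: which backend each model was validated on.
EXPECTED_BACKEND = {"model-a": "flash", "model-b": "flash"}

def unexpected_backends(events, expected=EXPECTED_BACKEND):
    """Count attention calls whose recorded backend differs from the qualified one."""
    counts = Counter()
    for model, backend in events:
        if expected.get(model) != backend:
            counts[(model, backend)] += 1
    return dict(counts)

# Hypothetical trace events as (model, selected backend) pairs.
events = ([("model-a", "flash")] * 50
          + [("model-b", "math")] * 3
          + [("model-b", "flash")] * 47)

alerts = unexpected_backends(events)
should_page = bool(alerts)
```

An unknown model also counts as unexpected here, because `expected.get` returns `None` for it; that choice makes a missing qualification entry fail loudly.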

[Diagram: Operational loop. 1. Qualify: GPU and head shape, mask and dtype. 2. Profile: kernel and HBM bytes, fallback detection. 3. Compare: prefill latency, peak memory. 4. Ship: canary every model, pin known backend. Measure, learn, repeat.]
Treat optimization as a measured loop, not a one-time flag.

How to read this diagram: The operating cycle moves from Qualify to Profile, then Compare and Ship. The return arrow matters: production evidence from the fourth step must change the assumptions and limits in the first, otherwise the optimization gradually drifts away from the workload it serves.

Deeper engineering guide

Standard attention materializes or repeatedly moves the N x N score matrix through high-bandwidth memory. FlashAttention changes the execution order: tiles of Q, K, and V are loaded into on-chip SRAM, partial scores are computed, and an online softmax maintains running maxima and normalization terms. The exact attention result is preserved while HBM traffic falls dramatically.

[Diagram: The tiled attention dataflow. Load tile: move Q, K, V from HBM to SRAM; bound the working set. Score: compute the QKᵀ tile; apply scale and mask; keep values local. Normalize: update the row maximum; rescale the old sum (online softmax). Accumulate: multiply by the V tile; write the final output; never store N². The algorithm saves traffic by fusing operations around a bounded on-chip tile.]
FlashAttention is an IO-aware exact algorithm, not an approximation to attention.

How to read this diagram: Follow the state from Load tile through Score and Normalize to Accumulate. Each box is an ownership or computation boundary. In particular, the algorithm saves traffic by fusing operations around a bounded on-chip tile. A real implementation may fuse boxes, but it must preserve their ordering and correctness contract.

Online softmax is the mathematical key. When a later tile contains a larger score, previously accumulated exponentials and outputs are rescaled to the new maximum before adding the tile. This preserves numerical stability without retaining every score. Causal masks, sliding windows, dropout during training, and variable sequence lengths alter tile work and must be supported by the selected kernel.
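Masks change how many tiles a kernel has to touch at all. The sketch below counts (query tile, key/value tile) pairs that contain at least one unmasked score. It is an illustrative model with assumed names and tile sizes; the `window` argument means each query sees itself and the `window - 1` keys before it, and it is only meaningful together with `causal=True`.

```python
def kv_tiles_visited(seq_len, q_tile, kv_tile, causal, window=None):
    """Count (Q tile, K/V tile) pairs that contain at least one unmasked score."""
    visited = 0
    for q_start in range(0, seq_len, q_tile):
        q_end = min(q_start + q_tile, seq_len) - 1   # last query row in the tile
        for k_start in range(0, seq_len, kv_tile):
            k_end = min(k_start + kv_tile, seq_len) - 1
            if causal and k_start > q_end:
                continue  # every key is in the future of every query in the tile
            if window is not None and k_end < q_start - window + 1:
                continue  # every key is older than the sliding window allows
            visited += 1
    return visited

full = kv_tiles_visited(1024, 128, 128, causal=False)
causal = kv_tiles_visited(1024, 128, 128, causal=True)
windowed = kv_tiles_visited(1024, 128, 128, causal=True, window=256)
```

With 8 tiles per side, full attention visits 64 pairs, causal attention 36, and a 256-token causal window 21. A kernel that does not implement the skip still computes the right answer but does the full amount of work.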

[Diagram: Attention cost is often bytes moved, not FLOPs issued. Materialized scores: high HBM traffic. Tiled fused kernel: bounded traffic. Kernel fallback: unsupported shapes can silently restore the slow path. Operational benefit: long contexts fit with lower latency and memory.]
Relative bars illustrate data movement, not a universal benchmark ratio.

How to read this diagram: The bars compare materialized scores with the tiled fused kernel on the article's dominant cost axis, bytes moved. Their lengths are explanatory, not universal benchmark values. The design is worthwhile only when the stated gain (long contexts fit with lower latency and memory) remains larger than the risk (unsupported shapes can silently restore the slow path) under production traffic.

Hardware fit matters. Head dimension, dtype, sequence length, GPU generation, causal mode, and layout determine tile shape and occupancy. A theoretically efficient kernel can underperform when register pressure reduces active warps or when short sequences do not amortize launch cost. Dispatch should select among validated kernels rather than assume one implementation wins everywhere.

[Diagram: One row of online softmax. Initialize: max = -∞, sum = 0. Consume tile: find the local maximum. Rescale: align old and new. Accumulate: update the output row. The running maximum and denominator make tile order numerically safe.]
The state is tiny compared with a materialized attention matrix.

How to read this diagram: State advances from Initialize to Consume tile, Rescale, and finally Accumulate, and the last three steps repeat for each new tile. The labels below each state identify what becomes true at that boundary. The governing invariant is that the running maximum and denominator make tile order numerically safe.

[Diagram: Kernel eligibility must be explicit. Shape: head width and sequence; padding and ragged layout. Semantics: causal or window mask; bias and positional mode. Precision: FP16, BF16, or FP8; accumulator guarantees. Hardware: SRAM and register budget; compiler and GPU target. Record the selected kernel in traces so regressions are explainable.]
FlashAttention performance depends on a complete execution contract.

How to read this diagram: The four panels are independent review axes: Shape, Semantics, Precision, and Hardware. A design is incomplete when one panel is optimized while another is left implicit. Use the bottom note as the cross-panel operating rule: Record the selected kernel in traces so regressions are explainable.

[Diagram: The silent fallback failure chain. Shape changes: new head dimension; kernel is ineligible. Fallback runs: scores materialize; HBM traffic jumps. Tail degrades: long prompts dominate; OOM risk rises. Control: alert by kernel path; gate model rollout. Benchmark every deployed shape and fail validation when an unexpected fallback appears.]
Correct output can hide a severe serving regression unless kernel selection is observable.

How to read this diagram: This is a causal chain, not four unrelated symptoms. A shape change triggers the fallback, which in turn degrades tail latency. The Control box is the intervention that should break the chain before users observe the final failure. The control must be tested under the initiating condition.

The takeaway

FlashAttention is a lesson in systems thinking: the fastest equation is often the one that respects the memory hierarchy.