Early Exit Decoding: Stop Computing Once the Answer Is Clear
A transformer normally sends every token through every layer, even when an intermediate representation is already confident about the next token. Early exit asks a provocative question: can an easy token leave the network after fewer layers than the hardest token needs?
This article starts with intuition, then moves into the mechanisms and production details. You can stop after the worked example and retain the core idea, or continue into the performance model and operational edge cases.
Start with the intuition
A medical triage desk does not send every case through every specialist. Routine cases follow a shorter path; ambiguous cases receive the full review. Early exit applies the same adaptive-depth idea to token prediction.
What actually happens
An early-exit model exposes usable logits at intermediate layers. This usually requires training support, such as auxiliary losses or a shared output head, so early representations learn to predict tokens rather than merely feed later layers.
A controller decides whether to stop. Confidence can come from maximum probability, entropy, margin between top candidates, learned calibration, or a fixed layer schedule. Raw softmax confidence is often miscalibrated, so thresholds must be tuned against task-specific quality measurements.
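The confidence signals listed above can be computed directly from the logits at an exit point. The sketch below is a minimal, hypothetical Python helper (the names `confidence_signals` and `should_exit` and the default thresholds are illustrative assumptions, not tuned values). It combines maximum probability, entropy, and top-two margin, and deliberately leaves calibration to a separate step.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def confidence_signals(logits):
    """Return max probability, entropy (in nats), and top-1 minus top-2 margin."""
    probs = softmax(logits)
    top = sorted(probs, reverse=True)
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    margin = top[0] - top[1] if len(top) > 1 else top[0]
    return {"max_prob": top[0], "entropy": entropy, "margin": margin}

def should_exit(logits, min_max_prob=0.9, max_entropy=0.5, min_margin=0.5):
    """Exit only if all three raw signals clear their thresholds.

    The defaults are placeholders (an assumption of this sketch): raw softmax
    scores are often miscalibrated, so real values need task-specific tuning.
    """
    s = confidence_signals(logits)
    return (s["max_prob"] >= min_max_prob
            and s["entropy"] <= max_entropy
            and s["margin"] >= min_margin)
```

For example, `should_exit([8.0, 1.0, 0.5])` returns True, while uniform logits such as `[1.0, 1.0, 1.0]` do not exit. Requiring all three signals is one conservative choice; a learned calibrator would usually replace these raw cutoffs.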
Self-speculative variants use early layers to draft tokens and the remaining layers to verify them. Because the full model checks every drafted token, the output can match full-depth decoding (exactly, under greedy verification), which returning shallow predictions directly cannot guarantee, and no separate draft-model checkpoint is needed.
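A greedy draft-and-verify loop shows why verification preserves full-depth behavior. This is a simplified sketch under stated assumptions: `draft_next` and `full_next` are hypothetical callables standing in for the early-exit path and the full model, and verification is written one position at a time for clarity, whereas a real implementation scores all drafted positions in a single full-depth forward pass and reuses the early layers' computation.

```python
from typing import Callable, List

Token = int

def self_speculative_decode(
    prompt: List[Token],
    draft_next: Callable[[List[Token]], Token],  # hypothetical: shallow-exit greedy prediction
    full_next: Callable[[List[Token]], Token],   # hypothetical: full-depth greedy prediction
    max_new_tokens: int,
    block_size: int = 4,
) -> List[Token]:
    """Greedy draft-then-verify loop; the result equals full-depth greedy decoding."""
    seq = list(prompt)
    target_len = len(prompt) + max_new_tokens
    while len(seq) < target_len:
        # 1. Draft up to block_size tokens with the cheap early-exit path.
        draft: List[Token] = []
        for _ in range(min(block_size, target_len - len(seq))):
            draft.append(draft_next(seq + draft))
        # 2. Verify each drafted position against the full model.
        for i, tok in enumerate(draft):
            expected = full_next(seq + draft[:i])
            if tok != expected:
                seq.extend(draft[:i])   # keep the verified prefix
                seq.append(expected)    # replace the first mismatch with the full-depth token
                break
        else:
            seq.extend(draft)           # every drafted token was accepted
    return seq
```

Because every accepted token equals what `full_next` would have produced, the output is identical to greedy full-depth decoding; the saving comes from how many drafts are accepted per verification pass.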
A worked example
Take a 32-layer model. Common punctuation and boilerplate may become predictable by layer 12, while code identifiers or reasoning steps need all 32. If half the tokens safely exit after 16 layers and the other half run all 32, average work falls from 32 to 24 layer evaluations per token (0.5 × 16 + 0.5 × 32 = 24), a 25% reduction. Real speedup will be smaller because of controller, batching, and memory overhead.
The performance model
Expected work is the sum, over exit depths d, of the probability of exiting at d multiplied by the number of layers executed: E[layers] = Σ_d P(exit at d) · d. Wall-clock speedup also depends on whether a batch can stop individual sequences without forcing every member through the deepest path.
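The expected-work formula and the batching caveat can be checked with a few lines of Python. The function names are illustrative, and the sketch assumes exit depths are known after the fact, as they would be in an offline trace.

```python
from statistics import mean
from typing import Dict, Sequence

def expected_layers(exit_dist: Dict[int, float]) -> float:
    """Expected layers per token: sum over depths d of P(exit at d) * d."""
    if abs(sum(exit_dist.values()) - 1.0) > 1e-9:
        raise ValueError("exit probabilities must sum to 1")
    return sum(p * depth for depth, p in exit_dist.items())

def divergent_layers_per_token(batch_exit_depths: Sequence[int]) -> float:
    """Average layers per token if each sequence could stop independently."""
    return float(mean(batch_exit_depths))

def lockstep_layers_per_token(batch_exit_depths: Sequence[int]) -> float:
    """Layer passes per decode step when the batch runs until its deepest member finishes."""
    return float(max(batch_exit_depths))
```

`expected_layers({16: 0.5, 32: 0.5})` gives 24.0, matching the worked example. For a batch whose exit depths are `[16, 16, 16, 32]`, independent stopping averages 20 layers per token, but lockstep execution still runs 32 layer passes per decode step, so step latency does not improve.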
Expert lens
Adaptive depth creates scheduler divergence. A GPU kernel prefers uniform work, while early exit creates per-token variation. Grouping by exit policy, verifying in blocks, or combining early exit with speculative decoding can make the hardware path more regular.
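Grouping by exit depth can be illustrated with a toy cost model in which each group runs every member to the group's deepest exit. This is a hypothetical sketch: it assumes a predicted exit depth is available per sequence before scheduling, which real systems can only approximate (for example by regrouping at each exit point).

```python
from typing import List, Sequence

def padded_layer_work(groups: Sequence[Sequence[int]]) -> int:
    """Layer-token evaluations when each group runs all members to its deepest exit."""
    return sum(len(g) * max(g) for g in groups if g)

def chunk(depths: Sequence[int], size: int) -> List[List[int]]:
    """Split sequences into fixed-size groups in arrival order."""
    return [list(depths[i:i + size]) for i in range(0, len(depths), size)]

def group_by_depth(depths: Sequence[int], size: int) -> List[List[int]]:
    """Sort by (predicted) exit depth before chunking so each group is more uniform."""
    return chunk(sorted(depths), size)
```

With depths `[32, 12, 32, 12, 16, 16, 32, 12]` and groups of four, arrival-order grouping costs 256 layer-token evaluations, depth-sorted grouping costs 192, and the divergence-free ideal (the plain sum) is 164.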
Where it wins
- Models trained explicitly for intermediate exits
- Workloads with many easy or repetitive tokens
- Latency-sensitive systems with measurable quality tolerances
Where it disappoints
- Attaching a classifier to an untrained intermediate layer
- Using confidence thresholds without calibration data
- Reporting layer savings as equal wall-clock savings
- Letting one deep sequence stall a divergent batch
Production checklist
- Use checkpoints trained for early prediction
- Calibrate thresholds on representative tasks
- Keep a full-depth fallback path
- Evaluate rare and safety-sensitive token classes
- Profile batching and kernel divergence
What to measure
- Exit-depth distribution by token and task
- Quality delta against full-depth decoding
- Average layers executed per output token
- Verification rejection rate
- TPOT (time per output token) and throughput under mixed difficulty
From one GPU to a production service
A research checkpoint proves that intermediate layers can predict. A product must decide which errors matter. Exiting one layer early on punctuation is different from exiting early on a medical entity, a code symbol, or a policy decision. Route task class into the exit policy.
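One way to route task class into the exit policy is a lookup table that combines a per-task risk budget with token classes that may never exit early. Everything in this sketch is hypothetical: the task and token class names, the budgets, and the assumption that an upstream component (for instance parser state or the drafted token) supplies `token_class` before the exit decision.

```python
from dataclasses import dataclass
from typing import FrozenSet, Mapping

@dataclass(frozen=True)
class TaskExitPolicy:
    max_risk: float                                       # error budget for an early exit
    forbidden_token_classes: FrozenSet[str] = frozenset()
    allow_early_exit: bool = True

# Hypothetical task classes, token classes, and budgets, for illustration only.
POLICIES: Mapping[str, TaskExitPolicy] = {
    "chat": TaskExitPolicy(max_risk=0.02),
    "code": TaskExitPolicy(max_risk=0.005, forbidden_token_classes=frozenset({"identifier"})),
    "medical": TaskExitPolicy(max_risk=0.001, forbidden_token_classes=frozenset({"entity", "dosage"})),
}
# Unknown task classes always run at full depth.
FULL_DEPTH_ONLY = TaskExitPolicy(max_risk=0.0, allow_early_exit=False)

def may_exit(task_class: str, token_class: str, risk: float) -> bool:
    """Apply the task's routing rules before comparing calibrated risk to its budget."""
    policy = POLICIES.get(task_class, FULL_DEPTH_ONLY)
    if not policy.allow_early_exit or token_class in policy.forbidden_token_classes:
        return False
    return risk <= policy.max_risk
```

Unknown task classes fall back to full depth, which keeps a new product surface from silently inheriting a budget tuned for chat.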
Batching introduces a collective decision: allow each token to diverge, regroup sequences by depth, or continue the whole batch to the deepest requested layer. The most accurate per-token controller may be the least efficient GPU schedule.
Deployment needs a kill switch and a full-depth comparison stream. Sample a fraction of early-exited tokens for full-depth evaluation, estimate counterfactual quality, and detect calibration drift before users report it.
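The comparison stream can be sketched as deterministic sampling plus an agreement check against full-depth replays. This is an illustrative design, not a prescribed one: hashing a hypothetical `request_id` with the token position makes sampling reproducible across replicas, token agreement is only a proxy for counterfactual quality, and the drift tolerance is an assumed parameter.

```python
import hashlib
from typing import Iterable, Tuple

def in_shadow_sample(request_id: str, position: int, rate: float) -> bool:
    """Deterministically select about `rate` of early-exited tokens for full-depth replay."""
    digest = hashlib.sha256(f"{request_id}:{position}".encode()).digest()
    bucket = int.from_bytes(digest[:8], "big") / 2**64  # uniform in [0, 1)
    return bucket < rate

def agreement_rate(pairs: Iterable[Tuple[int, int]]) -> float:
    """Fraction of sampled tokens where the early-exit token equals the full-depth token."""
    pairs = list(pairs)
    if not pairs:
        return float("nan")
    return sum(early == full for early, full in pairs) / len(pairs)

def calibration_drifted(current: float, baseline: float, tolerance: float) -> bool:
    """Alert when agreement falls more than `tolerance` below the calibration baseline."""
    return current < baseline - tolerance
```

A drift alert is a natural trigger for the kill switch: disable early exit, keep serving at full depth, and recalibrate.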
Design-review questions
- Which token classes are forbidden from early exit?
- How is confidence calibrated after a model update?
- Does divergence erase the saved layer compute?
- What full-depth shadow rate is affordable?
- Which product metric defines acceptable quality loss?
How it connects to the rest of the series
Self-speculative decoding connects early exits to exact target verification. Parallel decoding predicts multiple positions, while continuous batching must absorb the resulting variable work.
From equation to implementation
A controller needs a risk function, not only a confidence number. Exiting at layer d saves the remaining layer cost but incurs expected quality loss conditioned on the token state. The threshold should minimize expected latency subject to a measured error budget, and that budget may differ for code, safety decisions, and conversational filler.
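Under the simplifying assumption that latency falls monotonically as more tokens exit, minimizing latency subject to an error budget reduces to picking the lowest confidence threshold whose exited tokens stay within budget. The sketch below does that on hypothetical calibration pairs of (exit-point confidence, whether the early token matched full depth); running it separately per task class yields different thresholds for code, safety decisions, and filler.

```python
from typing import Optional, Sequence, Tuple

def pick_threshold(
    calib: Sequence[Tuple[float, bool]],  # (exit-point confidence, agrees with full depth?)
    error_budget: float,
) -> Optional[float]:
    """Lowest confidence threshold whose exited tokens stay within the error budget.

    Tokens with confidence >= threshold exit early. A lower threshold lets more
    tokens exit, so among thresholds that meet the budget we keep the lowest.
    Returns None when no threshold meets it (serve full depth instead).
    """
    ordered = sorted(calib, key=lambda item: item[0], reverse=True)
    best: Optional[float] = None
    exited = errors = 0
    i = 0
    while i < len(ordered):
        threshold = ordered[i][0]
        # Admit every sample tied at this confidence before checking the budget.
        while i < len(ordered) and ordered[i][0] == threshold:
            exited += 1
            if not ordered[i][1]:
                errors += 1
            i += 1
        if errors / exited <= error_budget:
            best = threshold
    return best
```

The error rate of the exited set is not monotonic in the threshold, so the loop checks every candidate rather than stopping at the first violation.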
Calibration drifts when prompts, sampling temperature, or model weights change. A threshold tuned on news summarization can be unsafe for mathematical reasoning. Treat the exit policy as versioned model metadata and rerun calibration with every checkpoint update.
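Versioning the exit policy can be as simple as storing the calibration conditions next to the thresholds and refusing to use them when serving conditions differ. The field names below are hypothetical, and matching on exact revision and temperature is an assumed, deliberately strict rule.

```python
from dataclasses import dataclass
from typing import Mapping, Optional

@dataclass(frozen=True)
class ExitPolicyVersion:
    model_revision: str              # checkpoint the thresholds were calibrated against
    calibration_set: str             # identifier of the calibration data
    sampling_temperature: float      # decoding temperature used during calibration
    thresholds: Mapping[str, float]  # task class -> confidence threshold

def usable_policy(
    policy: ExitPolicyVersion, serving_revision: str, serving_temperature: float
) -> Optional[ExitPolicyVersion]:
    """Return the policy only if it was calibrated for the current serving conditions.

    None means the caller should fall back to full-depth decoding until the
    policy is recalibrated.
    """
    if policy.model_revision != serving_revision:
        return None
    if policy.sampling_temperature != serving_temperature:
        return None
    return policy
```

Returning None rather than raising lets the serving path fall back to full depth while recalibration runs.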
Implementation sketch
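def early_exit_forward(token, embed, layers, exit_points, shared_head,
                       final_head, calibrated_exit_risk, task_class, policy):
    # exit_points holds 1-based layer depths where the shared head is evaluated.
    hidden = embed(token)
    for depth, layer in enumerate(layers, start=1):
        hidden = layer(hidden)
        if depth in exit_points and depth < len(layers):
            logits = shared_head(hidden)
            # Compare calibrated risk, not raw confidence, against the task's budget.
            risk = calibrated_exit_risk(logits, task_class)
            if risk <= policy.max_risk:
                return logits, depth  # early exit: remaining layers are skipped
    # Full-depth fallback: every layer ran.
    return final_head(hidden), len(layers)
Capacity planning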
Exit heads add parameters and bandwidth, while checkpoints trained with a shared head avoid some duplication. Capacity planning must include the deepest-path batch because hard tokens can cluster. Average exit depth does not guarantee peak memory or latency.
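The gap between average exit depth and peak depth is easy to see on a decode trace. This sketch assumes a hypothetical trace format in which each entry lists the exit depths of every sequence in the batch at one decode step, and it uses nearest-rank percentiles.

```python
import math
from statistics import mean
from typing import Sequence

def percentile(values: Sequence[float], q: float) -> float:
    """Nearest-rank percentile for q in (0, 100]."""
    if not values:
        raise ValueError("percentile of empty sequence")
    ordered = sorted(values)
    rank = math.ceil(q / 100 * len(ordered))
    return ordered[max(rank, 1) - 1]

def depth_summary(steps: Sequence[Sequence[int]]) -> dict:
    """steps[i] holds the exit depth of every sequence in the batch at decode step i."""
    all_depths = [d for step in steps for d in step]
    deepest_per_step = [max(step) for step in steps]
    return {
        "mean_exit_depth": mean(all_depths),
        "p50_step_depth": percentile(deepest_per_step, 50),
        "p99_step_depth": percentile(deepest_per_step, 99),
        "max_step_depth": max(deepest_per_step),
    }
```

In a trace where nine steps are easy and one step has three hard tokens, `[[12, 12, 16, 12]] * 9 + [[32, 32, 32, 12]]`, the mean exit depth is 14.4 and the median step needs 16 layers, yet the p99 step still needs all 32. Memory and latency have to be provisioned for the latter.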
Benchmarking without fooling yourself
- Stratify by task difficulty and token category.
- Report quality versus full-depth output, not only a static reference.
- Stress batches where one sequence repeatedly takes the deepest path.
- Recalibrate across temperature and model revisions.
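A stratified report can be produced from per-token records that pair each early-exit decision with the full-depth token at the same position. The record fields and labels are hypothetical; the point of the sketch is that agreement and layer savings are reported per (task, token category) cell instead of as one aggregate.

```python
from collections import defaultdict
from typing import Dict, Iterable, NamedTuple, Tuple

class TokenRecord(NamedTuple):
    task: str                  # hypothetical label, e.g. "code" or "chat"
    category: str              # hypothetical label, e.g. "punctuation" or "identifier"
    exit_depth: int            # layers executed for this token
    matches_full_depth: bool   # early-exit token == full-depth token at this position

def stratified_report(records: Iterable[TokenRecord]) -> Dict[Tuple[str, str], dict]:
    """Per (task, token category): token count, mean layers, and full-depth agreement."""
    buckets = defaultdict(list)
    for r in records:
        buckets[(r.task, r.category)].append(r)
    report = {}
    for key, rs in sorted(buckets.items()):
        report[key] = {
            "tokens": len(rs),
            "mean_layers": sum(r.exit_depth for r in rs) / len(rs),
            "full_depth_agreement": sum(r.matches_full_depth for r in rs) / len(rs),
        }
    return report
```

A cell such as ("code", "identifier") with low agreement is exactly what an aggregate number would average away.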
A production failure to design for
A traffic shift introduces many structured JSON outputs. Punctuation exits early with high confidence, but quoted identifiers require deeper context and start failing schema validation. The aggregate accuracy looks stable while operational error rises. Calibrate on product-level validators, not only token accuracy.
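Calibrating on a product-level validator means measuring whether whole outputs pass the check the product depends on. For the JSON scenario above, a minimal sketch, assuming the validator only needs a parseable JSON object with required keys, could compare failure rates of early-exit and full-depth outputs for the same prompts.

```python
import json
from typing import Iterable, Sequence

def passes_schema(text: str, required_keys: Sequence[str]) -> bool:
    """Product-level check: output parses as a JSON object containing the required keys."""
    try:
        obj = json.loads(text)
    except json.JSONDecodeError:
        return False
    return isinstance(obj, dict) and all(k in obj for k in required_keys)

def validator_failure_rate(outputs: Iterable[str], required_keys: Sequence[str]) -> float:
    """Fraction of outputs that fail the schema check (0.0 for an empty batch)."""
    outputs = list(outputs)
    if not outputs:
        return 0.0
    return sum(not passes_schema(o, required_keys) for o in outputs) / len(outputs)
```

An output like `{"id": 1, "nmae": "a"}` is valid JSON with correct punctuation but fails on a misspelled identifier, the kind of failure that aggregate token accuracy can hide.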
The takeaway
Early exit turns model depth from a constant into a budget. The engineering challenge is proving when less computation is enough.
