10/20 - Tensor Parallelism: Splitting One Layer Across Many GPUs

When one transformer layer is too large for one GPU, tensor parallelism cuts the layer itself into pieces. Every token then crosses multiple GPUs during the same layer, turning fast matrix multiplication into a choreography of compute and collective communication.

This article starts with intuition, then moves into the mechanisms and production details. You can stop after the worked example and retain the core idea, or continue into the performance model and operational edge cases.

Start with the intuition

Imagine several chefs preparing one enormous sandwich. Each owns part of every layer, so they must exchange ingredients before the next layer can begin. More chefs add capacity, but the handoffs can become the meal.

[Figure: Mechanism flow, tensor parallelism request path. 1 Input activation (shard matrix columns, local GEMM per GPU) → 2 Collective boundary (all-reduce or gather, reconstruct dependency) → 3 Next sublayer (shard rows, continue together).]
Follow the state and work from left to right.

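How to read this diagram: Start with Input activation, where the weight matrix is sharded by columns and each GPU runs its own local GEMM. The middle stage, Collective boundary, uses an all-reduce or gather to rebuild the result that the next operation depends on. The final stage, Next sublayer, shows the observable result: the following layer is sharded by rows and the ranks continue together. The arrows describe dependency order, not necessarily separate services.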

What actually happens

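In a transformer MLP, a column-parallel linear layer splits output features across ranks, so each rank computes one slice of the intermediate activation. A following row-parallel layer consumes the matching input slices. Each rank produces a partial output, and an all-reduce sums the partials. Because the activation function between the two layers is applied elementwise, the pair needs no communication in the middle. Attention uses the same pairing: the QKV projection is column-parallel, so each rank owns a subset of attention heads, and the output projection is row-parallel.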
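The column-then-row pairing can be checked numerically. The sketch below is a minimal pure-Python illustration under stated assumptions: tiny list-of-lists matrices, GeLU (tanh approximation) as the elementwise activation (the article does not name one), and a plain elementwise sum standing in for the all-reduce. The helper names are hypothetical. The sketch confirms that sharding the first weight by columns and the second by rows reproduces the unsharded MLP output.

```python
import math

def matmul(x, w):
    """Multiply x (n x k) by w (k x m); both are lists of lists."""
    return [[sum(x[i][t] * w[t][j] for t in range(len(w))) for j in range(len(w[0]))]
            for i in range(len(x))]

def gelu(v):
    # tanh approximation of GeLU; any elementwise activation works the same way
    return 0.5 * v * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * v ** 3)))

def elementwise(fn, m):
    return [[fn(v) for v in row] for row in m]

def column_shard(w, rank, world):
    # rank keeps a contiguous block of output columns
    cols = len(w[0]) // world
    return [row[rank * cols:(rank + 1) * cols] for row in w]

def row_shard(w, rank, world):
    # rank keeps the matching block of input rows
    rows = len(w) // world
    return w[rank * rows:(rank + 1) * rows]

def all_reduce_sum(partials):
    # stand-in for an all-reduce: elementwise sum of every rank's partial output
    out = [row[:] for row in partials[0]]
    for p in partials[1:]:
        for i, row in enumerate(p):
            for j, v in enumerate(row):
                out[i][j] += v
    return out

def tp_mlp(x, a, b, world):
    partials = []
    for rank in range(world):
        y_i = elementwise(gelu, matmul(x, column_shard(a, rank, world)))  # local, no communication
        partials.append(matmul(y_i, row_shard(b, rank, world)))           # partial output
    return all_reduce_sum(partials)

def dense_mlp(x, a, b):
    return matmul(elementwise(gelu, matmul(x, a)), b)

if __name__ == "__main__":
    x = [[((i * 7 + j * 3) % 5 - 2) / 4 for j in range(4)] for i in range(2)]
    a = [[((i * 5 + j) % 7 - 3) / 5 for j in range(8)] for i in range(4)]
    b = [[((i + j * 2) % 6 - 2.5) / 6 for j in range(4)] for i in range(8)]
    err = max(abs(g - w) for gr, wr in zip(tp_mlp(x, a, b, 4), dense_mlp(x, a, b))
              for g, w in zip(gr, wr))
    print(f"max abs difference, TP=4 vs dense: {err:.2e}")
```

The only difference between the two paths is floating-point summation order. Inserting a non-elementwise operation between the two layers, such as a normalization over the full width, would break the equality. The ranks would then need to exchange data first.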

Tensor parallelism reduces per-GPU weight storage and shards the intermediate activations, but it introduces communication inside nearly every transformer layer. It therefore needs high-bandwidth, low-latency links such as NVLink or NVSwitch, along with carefully formed process groups.

During autoregressive decode, matrices can be small in the batch dimension. Collective latency becomes visible because it repeats for every layer and token. A tensor-parallel degree chosen only to make the model fit may not be the degree that maximizes tokens per second.

A worked example

Split a 16,384-wide projection across four GPUs. Each rank stores and computes one quarter of the columns, 4,096 of the 16,384. Before a dependent operation needs the complete result, ranks exchange or reduce partial activations. That cuts weight memory per GPU to a quarter, but it adds a collective on every token path.
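The arithmetic behind the example can be made concrete. This sketch adds assumptions that the article does not state: the projection is a square 16,384 x 16,384 matrix, elements are 2 bytes (bf16 or fp16), and a decode step carries one token. The helper name `column_shard_bytes` is hypothetical.

```python
# Assumptions (not from the article): square 16,384 x 16,384 weight,
# 2-byte elements (bf16/fp16), one decode token per step.
HIDDEN = 16_384
TP = 4
BYTES_PER_ELEM = 2

def column_shard_bytes(rows: int, cols: int, tp: int, elem_bytes: int = BYTES_PER_ELEM) -> int:
    """Bytes one rank stores when `cols` output columns are split evenly across `tp` ranks."""
    if cols % tp:
        raise ValueError("output columns must divide evenly across ranks")
    return rows * (cols // tp) * elem_bytes

full_bytes = HIDDEN * HIDDEN * BYTES_PER_ELEM
shard_bytes = column_shard_bytes(HIDDEN, HIDDEN, TP)
token_activation_bytes = HIDDEN * BYTES_PER_ELEM  # one token's full-width activation

MiB, KiB = 2**20, 2**10
print(f"full weight: {full_bytes / MiB:.0f} MiB")                        # full weight: 512 MiB
print(f"per-rank shard: {shard_bytes / MiB:.0f} MiB")                    # per-rank shard: 128 MiB
print(f"activation per token: {token_activation_bytes / KiB:.0f} KiB")  # activation per token: 32 KiB
```

Weight memory shrinks with the TP degree, but the payload that must be combined stays on the order of one full-width activation per token. The communication therefore does not shrink in step with the compute.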

The performance model

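A simplified layer time is the sum of local compute and collective time when they run back to back. When communication overlaps with compute, it approaches the larger of the two. Collective time is a fixed latency term plus bytes divided by effective link bandwidth. As local shards shrink, compute time falls while the fixed latency does not, so communication takes up a growing fraction of each layer.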
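The model above can be written as a small alpha-beta calculation. This is a sketch under simplifying assumptions. Local compute divides perfectly by the TP degree. The collective message size and latency stay fixed as TP grows, which is optimistic, since ring latency grows with rank count. All numeric inputs in the usage block are hypothetical and chosen only to show the trend.

```python
from dataclasses import dataclass

@dataclass
class LayerCost:
    compute_s: float  # local GEMM time on one rank
    comm_s: float     # collective time for the same layer

    def total(self, overlapped: bool) -> float:
        # Back-to-back execution pays both; perfect overlap pays only the slower side.
        return max(self.compute_s, self.comm_s) if overlapped else self.compute_s + self.comm_s

def collective_time(latency_s: float, message_bytes: float, bandwidth_bps: float) -> float:
    """Alpha-beta model: fixed latency plus bytes over effective bandwidth."""
    return latency_s + message_bytes / bandwidth_bps

def comm_fraction(single_gpu_compute_s, tp, latency_s, message_bytes, bandwidth_bps):
    """Share of non-overlapped layer time spent in the collective at a given TP degree."""
    cost = LayerCost(single_gpu_compute_s / tp,
                     collective_time(latency_s, message_bytes, bandwidth_bps))
    return cost.comm_s / cost.total(overlapped=False)

if __name__ == "__main__":
    # Hypothetical inputs: 100 us of single-GPU compute, 10 us latency, 32 KiB, 100 GB/s.
    for tp in (2, 4, 8):
        f = comm_fraction(100e-6, tp, 10e-6, 32_768, 100e9)
        print(f"TP{tp}: {f:.0%} of layer time in the collective")
```

With these illustrative inputs, the collective's share rises from roughly 17% at TP2 to about 45% at TP8. Only measurements on the real topology can tell where that crossover sits.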

[Figure: Phase fit, where tensor parallelism changes inference. Prefill: many prompt tokens in parallel, high arithmetic intensity, large GEMMs amortize collectives. Decode: one new token per iteration, weight and KV bandwidth pressure, small steps expose collective latency. Prove it with: collective p99, TTFT, and TPOT. Deployment decision: keep rank groups on fast fabric.]
Prefill and decode run the same model but expose different bottlenecks and SLOs.

How to read this diagram: The left panel asks how Tensor parallelism changes prompt processing and TTFT; the right asks how it changes iterative generation and inter-token latency. The bottom row names the metric that must improve and the deployment choice justified by that evidence. Optimizing the wrong phase can add complexity without changing the user-visible bottleneck.

Expert lens

Topology-aware placement matters more than the TP number alone. TP across GPUs behind one NVSwitch is different from TP crossing PCIe hosts or a slower fabric. Keep latency-sensitive TP groups local and scale replicas across nodes when possible.

[Figure: Trade-off map. Baseline, single-GPU layer: no intra-layer collectives, full weights on one device, simple execution path, limited by one GPU's memory. Optimized, tensor-parallel layer: weights split across ranks, collectives every layer, more aggregate memory, topology-sensitive latency. Measure both sides under the same workload.]
The optimization changes where the system spends compute, memory, bandwidth, or waiting time.

How to read this diagram: The left panel is the baseline, Single-GPU layer, characterized by no intra-layer collectives and full weights on one device. The right panel applies Tensor-parallel layer, changing the cost profile to weights split across ranks and collectives every layer. Compare both under the same request shape and load; the optimized side is not automatically better for every workload.

Where it wins

  • Layers or models that do not fit one accelerator
  • Large hidden dimensions with fast interconnect
  • Deployments where fewer larger replicas meet demand

Where it disappoints

  • Stretching TP across slow network links
  • Assuming twice the GPUs means twice the speed
  • Ignoring small-batch collective latency
  • Using shard shapes that disable optimized kernels

Production checklist

  • Map TP groups to the fastest local topology
  • Benchmark TP degrees with production batch sizes
  • Inspect collective overlap and NCCL traces
  • Confirm divisible hidden and head dimensions
  • Compare scale-up TP with scale-out replicas

What to measure

  • Collective time per layer and token
  • Link bandwidth and congestion
  • Per-rank compute imbalance
  • Tokens per second per GPU
  • TTFT and TPOT by TP degree

From one GPU to a production service

A workstation test sees one fixed set of GPUs. Kubernetes or a fleet scheduler must preserve the topology every time a replica starts. Device selection, rank order, NCCL configuration, shared memory, and health checks become part of the model deployment specification.

Replica sizing should compare scale-up and scale-out. TP8 may fit the model but offer fewer total tokens per GPU than two TP4 replicas. The right answer depends on memory, queueing, batch efficiency, and whether one replica meets the largest request.

Failure domains are wider than one pod. If any rank fails, the TP replica usually fails. Readiness should be collective, draining should stop new admission across all ranks, and restart policy should avoid partial zombie groups.
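The group-level rules above can be expressed as a small state holder. This is a hypothetical sketch, not the API of Kubernetes or any serving framework. It models one TP replica in which readiness requires every rank, a single failure stops admission for the whole group, and recovery means restarting the group rather than one rank.

```python
from enum import Enum

class RankState(Enum):
    STARTING = "starting"
    READY = "ready"
    DRAINING = "draining"
    FAILED = "failed"

class TPReplica:
    """Hypothetical serving-side view of one tensor-parallel rank group."""

    def __init__(self, world_size: int):
        self.states = [RankState.STARTING] * world_size
        self.draining = False

    def set_rank(self, rank: int, state: RankState) -> None:
        self.states[rank] = state
        if state is RankState.FAILED:
            # One failed rank fails the whole group: stop admission on every rank.
            self.draining = True

    def ready(self) -> bool:
        # Readiness is collective: all ranks ready and the group not draining.
        return not self.draining and all(s is RankState.READY for s in self.states)

    def drain(self) -> None:
        # Stop new admission across all ranks; in-flight work may finish.
        self.draining = True
        self.states = [RankState.DRAINING if s is RankState.READY else s for s in self.states]

    def needs_full_restart(self) -> bool:
        # Avoid partial zombie groups: if any rank failed, restart all of them together.
        return any(s is RankState.FAILED for s in self.states)
```

A readiness probe would report `ready()` for the group as a whole, so a load balancer never routes to a replica whose ranks disagree.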

Design-review questions

  • What is the smallest TP degree that fits the worst request?
  • Are all ranks inside the intended fabric domain?
  • How does throughput per GPU change with TP degree?
  • What happens to in-flight requests when one rank fails?
  • Would more replicas beat a wider TP group?

How it connects to the rest of the series

Pipeline parallelism splits depth instead of tensors. Sequence parallelism reduces replicated activations around TP. Expert parallelism distributes MoE experts and adds all-to-all traffic.

From equation to implementation

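Column parallelism computes Y_i = X A_i for column shards A_i, leaving output features sharded. Row parallelism consumes input shards X_i and computes partials X_i B_i, which are then summed across ranks. Alternating the two avoids a gather between them. In the MLP, each column-parallel output shard is exactly the input shard that the row-parallel layer needs, because the activation between them is applied elementwise. One all-reduce at the end of the pair is therefore enough.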
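Written out for N ranks, and assuming an elementwise activation \sigma such as GeLU (the article does not name one), the column-then-row pair reduces to a single sum. The row-parallel input shard X_i is simply the column-parallel output Y_i:

```latex
\begin{aligned}
A &= \begin{bmatrix} A_1 & A_2 & \cdots & A_N \end{bmatrix},
\qquad
B = \begin{bmatrix} B_1 \\ B_2 \\ \vdots \\ B_N \end{bmatrix} \\
Y &= \sigma(XA) = \begin{bmatrix} \sigma(XA_1) & \sigma(XA_2) & \cdots & \sigma(XA_N) \end{bmatrix}
   = \begin{bmatrix} Y_1 & Y_2 & \cdots & Y_N \end{bmatrix} \\
Z &= YB = \sum_{i=1}^{N} Y_i B_i
\end{aligned}
```

The second line holds only because \sigma acts on each element independently. The final sum is the one all-reduce per pair.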

Collective cost follows topology and message size. A ring all-reduce sends roughly 2(N-1)/N times the tensor bytes per rank, and receives the same amount, over 2(N-1) steps that each pay link latency. Tree and NVSwitch-based algorithms have different cost profiles. The critical point is that decode repeats these collectives for every layer and every token.
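The 2(N-1)/N figure falls out of the two phases of a ring all-reduce. The sketch below simulates both phases in pure Python, using lists as per-rank buffers, and counts the elements each rank sends. It is an illustration under stated assumptions: the buffer length divides evenly into N chunks, and each step's sends happen simultaneously. It is not how NCCL is implemented; NCCL chooses among several algorithms, including ring and tree.

```python
def ring_all_reduce(buffers):
    """Simulate a ring all-reduce (reduce-scatter, then all-gather) over N equal-length lists.

    Returns (reduced_buffers, elements_sent_per_rank).
    """
    n = len(buffers)
    length = len(buffers[0])
    assert length % n == 0, "sketch assumes the buffer splits into N equal chunks"
    size = length // n
    data = [list(b) for b in buffers]
    sent = [0] * n

    def chunk(c):
        return range(c * size, (c + 1) * size)

    # Reduce-scatter: after N-1 steps, rank r holds the full sum of chunk (r + 1) % n.
    for step in range(n - 1):
        msgs = []
        for r in range(n):
            c = (r - step) % n
            msgs.append(((r + 1) % n, c, [data[r][i] for i in chunk(c)]))
            sent[r] += size
        for dst, c, payload in msgs:  # apply after all sends: simultaneous exchange
            for k, i in enumerate(chunk(c)):
                data[dst][i] += payload[k]

    # All-gather: pass the fully reduced chunks around the ring for N-1 more steps.
    for step in range(n - 1):
        msgs = []
        for r in range(n):
            c = (r + 1 - step) % n
            msgs.append(((r + 1) % n, c, [data[r][i] for i in chunk(c)]))
            sent[r] += size
        for dst, c, payload in msgs:
            for k, i in enumerate(chunk(c)):
                data[dst][i] = payload[k]

    return data, sent

if __name__ == "__main__":
    bufs = [[r * 10 + i for i in range(12)] for r in range(4)]
    reduced, sent = ring_all_reduce(bufs)
    print(reduced[0], sent)
```

Each rank sends N-1 chunks in each phase, which gives 2(N-1) chunks of size length/N, or 2(N-1)/N of the tensor. Every one of those 2(N-1) steps also pays the link latency, which is why small decode messages stay latency-bound.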

Implementation sketch

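tp_group = initialize_tp_group(ranks_on_same_nvlink_domain)
for layer in transformer_layers:
    # Attention: column-parallel QKV, so each rank owns a subset of heads
    qkv_shard = column_parallel_linear(layer_norm(hidden), layer.qkv)
    attention_shard = local_attention(qkv_shard)  # no communication needed
    # Row-parallel output projection; the all-reduce sums per-rank partials
    hidden = hidden + row_parallel_linear_and_all_reduce(attention_shard, layer.out, tp_group)
    # MLP: column-parallel up-projection, elementwise activation, row-parallel down-projection
    mlp_shard = activation(column_parallel_linear(layer_norm(hidden), layer.up))
    hidden = hidden + row_parallel_linear_and_all_reduce(mlp_shard, layer.down, tp_group)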

Capacity planning

Choose TP primarily to fit weights, KV shards, and workspace, then test the next smaller and larger degrees. More TP reduces bytes per rank but may shrink GEMMs below efficient tile sizes and add ranks to every collective.
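A first-pass fit check can be scripted before any benchmarking. The function below is a hypothetical sketch, not part of any serving framework. It assumes that weights and KV cache divide evenly by TP, which fails once TP exceeds the number of KV heads. It also assumes that per-rank workspace does not shrink with TP, and it treats head and hidden divisibility as hard constraints. All byte counts in the usage note are made up for illustration.

```python
def smallest_viable_tp(weight_bytes, kv_bytes, workspace_bytes, gpu_bytes,
                       num_heads, hidden, candidates=(1, 2, 4, 8)):
    """Return the smallest TP degree whose per-rank footprint fits, or None.

    Assumptions (hypothetical): weights and KV shard evenly across ranks,
    and per-rank workspace is a fixed cost that does not shrink with TP.
    """
    for tp in sorted(candidates):
        if num_heads % tp or hidden % tp:
            continue  # uneven shards: skip this degree
        per_rank = (weight_bytes + kv_bytes) / tp + workspace_bytes
        if per_rank <= gpu_bytes:
            return tp
    return None
```

For example, with hypothetical inputs of 140 GB of weights, 20 GB of KV cache, 4 GB of workspace, and 80 GB GPUs, TP2 needs 84 GB per rank, so the function returns 4. That degree and the next larger one are then the candidates to benchmark.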

Benchmarking without fooling yourself

  • Pin rank placement and record the exact topology.
  • Test batch-one decode and realistic continuous batches.
  • Capture NCCL traces and per-rank kernel gaps.
  • Report tokens per second per GPU, not aggregate alone.
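The last two checklist items amount to a small reporting step. The sketch below is one possible shape for it, under stated assumptions. Collective timings come from whatever tracing is already in place, in microseconds. Percentiles use the nearest-rank method, and the helper names are hypothetical.

```python
import math

def percentile(samples, pct):
    """Nearest-rank percentile of a non-empty sample list."""
    ordered = sorted(samples)
    k = max(1, math.ceil(pct / 100 * len(ordered)))
    return ordered[k - 1]

def benchmark_report(total_tokens, wall_s, num_gpus, collective_us):
    """Summarize one run so that different TP degrees can be compared per GPU."""
    tokens_per_s = total_tokens / wall_s
    return {
        "tokens_per_s": tokens_per_s,
        "tokens_per_s_per_gpu": tokens_per_s / num_gpus,
        "collective_p50_us": percentile(collective_us, 50),
        "collective_p99_us": percentile(collective_us, 99),
    }
```

Comparing `tokens_per_s_per_gpu` across runs is what exposes the case where a wider TP group raises aggregate throughput but lowers efficiency per GPU.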

A production failure to design for

A Kubernetes reschedule places half of a TP group across a slower inter-node link. The model remains healthy, but TPOT triples and collective time dominates. Enforce topology constraints and expose TP-group locality in readiness checks.
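A readiness gate that catches this failure needs both a placement check and a performance check. The function below is a hypothetical sketch. It assumes the deployment can report a fabric-domain label for each rank (for example, a node name) and a p99 from a short collective probe run at startup. It also assumes a known-good baseline p99 for the same message size; the 1.5x slack is an arbitrary illustrative default.

```python
def tp_group_readiness(rank_domains, probe_p99_us, baseline_p99_us, slack=1.5):
    """Hypothetical readiness gate for one TP group.

    rank_domains: fabric-domain label per rank (assumed available from the scheduler).
    probe_p99_us: p99 of a startup collective probe at the deployed message size.
    """
    reasons = []
    domains = set(rank_domains)
    if len(domains) != 1:
        reasons.append(f"ranks span {len(domains)} fabric domains")
    if probe_p99_us > baseline_p99_us * slack:
        reasons.append(f"collective p99 {probe_p99_us:.0f} us exceeds {slack}x baseline")
    return (not reasons, reasons)
```

Returning the reasons as well as the verdict makes the failure visible in logs. A group that is alive but slow then surfaces as a locality or performance problem instead of passing a bare liveness check.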

[Figure: Operating loop. 1 Place: fast local topology, stable rank group. 2 Fit: weights, KV, workspace; smallest viable TP. 3 Trace: compute versus NCCL, rank imbalance. 4 Scale: replicas across nodes, TP within node. Measure → learn → repeat.]
Treat optimization as a measured loop, not a one-time flag.

How to read this diagram: The operating cycle moves from Place to Fit, then Trace and Scale. The return arrow matters: production evidence from the fourth step must change the assumptions and limits in the first, otherwise the optimization gradually drifts away from the workload it serves.

Deeper engineering guide

Tensor parallelism shards the matrices inside a transformer layer. Column-parallel projections split output features; row-parallel projections split input features and reduce partial outputs. Attention heads and MLP intermediate dimensions are common shard axes. Every layer therefore alternates local GEMMs with collective communication.

[Figure: One tensor-parallel transformer layer. Shard input (replicate or scatter, align hidden state, enter local GEMM) → Local compute (each GPU owns weights, produces a partial output, uses tensor cores) → Collective (all-reduce or gather, synchronize ranks, move activation bytes) → Next block (restore expected layout, run following op, repeat every layer). Communication is on the token critical path and repeats across model depth.]
Tensor parallelism buys model capacity by inserting collectives inside each layer.

How to read this diagram: Follow the state from Shard input through Local compute and Collective to Next block. Each box is an ownership or computation boundary. In particular, communication is on the token critical path and repeats across model depth. A real implementation may fuse boxes, but it must preserve their ordering and correctness contract.

The compute-to-communication ratio worsens at small batches and decode batch-one. Prefill provides larger matrices that amortize collectives; decode may spend a large fraction of time in latency-bound all-reduces. Fast intra-node fabrics help, while crossing ordinary Ethernet can erase the gain.

[Figure: Scaling stops when collectives dominate. Bars: local GEMM work (shrinks per rank) versus collective cost (repeats per layer). Scaling limit: more ranks can add latency after local tiles become too small. Capacity benefit: large weights and KV heads fit across device memory.]
Illustrative balance; topology and batch shape determine the crossover.

How to read this diagram: The bars compare local GEMM work with collective cost on the article's dominant cost axis. Their lengths are explanatory, not universal benchmark values. The design is worthwhile only while the capacity benefit (large weights and KV heads fit across device memory) outweighs the scaling limit (more ranks can add latency once local tiles become too small) under production traffic.

Shard placement must follow topology. Keep a tensor-parallel group inside the fastest fabric domain, pin ranks deterministically, and verify peer-to-peer access. A single slow link delays the collective for every rank. Health checks should include collective microbenchmarks rather than only GPU visibility.

[Figure: Tensor-parallel request synchronization. Ready (all ranks admitted) → Computing (local tiles run) → Collective (ranks exchange data) → Advanced (layer completes). A timeout or failure on one rank fails the coordinated request across the group.]
The serving unit is the rank group, not an individual GPU process.

How to read this diagram: State advances from Ready to Computing, Collective, and finally Advanced. The labels below each state identify what becomes true at that boundary. The governing invariant is: A timeout or failure on one rank fails the coordinated request across the group. Retries and cancellation must preserve the same transition rules.

[Figure: Four dimensions of TP efficiency. Shard axis: heads or hidden width, divisibility constraints. Topology: NVLink or network, rank locality. Workload: prefill and decode shape, concurrent sequences. Runtime: collective overlap, kernel fusion. Benchmark p50 and p99 collective time at deployed message sizes.]
Tensor parallelism is a distributed kernel, not merely split weights.

How to read this diagram: The four panels are independent review axes: Shard axis, Topology, Workload, and Runtime. A design is incomplete when one panel is optimized while another is left implicit. Use the bottom note as the cross-panel operating rule: Benchmark p50 and p99 collective time at deployed message sizes.

[Figure: One degraded link slows every token. Link degrades (one rank transfers slowly, no process crashes) → Collectives wait (all ranks synchronize, GPU compute idles) → Tail explodes (every layer repeats the delay, queue grows). Control: probe fabric latency, drain the rank group. Group-level health must detect performance failure before availability failure.]
Collective tail latency compounds through model depth.

How to read this diagram: This is a causal chain, not a set of unrelated symptoms. A degraded link makes collectives wait, and that wait compounds into exploding tail latency. The green Control box is the intervention that should break the chain before users observe the final failure, and it must be tested under the initiating condition.

The takeaway

Tensor parallelism buys memory and compute with communication. The winning configuration is the smallest TP degree that fits and performs well on the actual topology.