Distributed Machine Learning: How Does it Work?

State-of-the-art AI requires orchestrating large clusters to perform a single synchronised calculation. How does this orchestration work? And how can it be done without incurring expensive communication overheads?

Bruce Mauger
01-01-2025

Many recent advances in AI are owed to training larger and larger models. These neural networks are too big to fit in, and take too long to train on, a single GPU. State-of-the-art AI therefore requires orchestrating vast clusters of GPUs to perform a single synchronised calculation, with two objectives:

  1. Reducing memory footprint, so we can fit larger models.
  2. Increasing the degree of parallelism, so we can use lots of compute in parallel to speed up training.

There’s no such thing as a free lunch: by distributing training we incur a communication overhead when GPUs have to talk to each other. Meta’s breakthrough Llama 3 model was trained1 on a cluster of 24K GPUs, but with a per-GPU utilisation well below 50% [1].

We’ll be taking an in-depth look at the wide variety of parallelism paradigms that have enabled training of gigantic models. As we’ll see, it’s relatively easy to design parallelism techniques that achieve both of the previous objectives; doing so without incurring prohibitive communication overheads, however, is a very difficult engineering and research challenge. The communication aspects of these designs will be the focus of this post.

We assume only basic prior knowledge of neural networks and gradient descent. The first two sections provide background on, respectively, the PyTorch framework (along with the transformer model that underpins modern LLMs) and collective communication primitives. The next section presents a deep-dive into all of the major parallelism techniques, including how PyTorch non-intrusively integrates these into its inherently local execution model. We’ll conclude by looking at how parallelisms are composed together in practice, in particular for Llama 3.

1. PyTorch mental model

We’ll be using PyTorch to illustrate model distribution throughout (though everything remains largely applicable to other frameworks). It’s useful to first get an understanding of how data flows during the PyTorch model training process, in order to spot opportunities for parallelisation.

How do you train a PyTorch neural network?

PyTorch’s fundamental data structure is the Tensor, a multi-dimensional matrix (think NumPy’s ndarray) used to store a model’s parameters and encode inputs/outputs. In PyTorch, a neural network is a Module composed by stitching other modules (layers) and functions together. For example, here’s a simple network with two linear layers and a ReLU activation function in-between:

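from torch import nn

class NeuralNetwork(nn.Module):
  def __init__(self):
    super().__init__()
    self.linear1 = nn.Linear(2, 4)  # 2 input features -> 4 hidden units
    self.activation_fn = nn.ReLU()  # elementwise, no trainable parameters
    self.linear2 = nn.Linear(4, 1)  # 4 hidden units -> 1 output
  
  def forward(self, x):
    # x has shape [batch, 2]; the output has shape [batch, 1]
    x = self.linear1(x)
    x = self.activation_fn(x)
    x = self.linear2(x)
    return x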

Unsurprisingly, forward defines the network’s forward pass: how inputs are mapped to outputs. Here a 2-dimensional input is mapped to a 1-dimensional output, with a 4-dimensional “hidden” layer. Taking the first Linear submodule as an example, it holds weight and bias tensors of shapes [4,2] and [4] respectively. Adding the second linear layer’s parameters (the activation function doesn’t have any), we can see the network has a total of 17 trainable parameters.
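The parameter count can be checked by hand. The pure-Python sketch below (no PyTorch needed) relies only on the fact that an nn.Linear(in, out) layer holds an [out, in] weight and an [out] bias; `linear_params` is a hypothetical helper introduced for illustration.

```python
def linear_params(in_features, out_features):
    # weight of shape [out, in] plus a bias of shape [out]
    return out_features * in_features + out_features

# NeuralNetwork above: Linear(2, 4) -> ReLU (no parameters) -> Linear(4, 1)
layers = [(2, 4), (4, 1)]
counts = [linear_params(i, o) for i, o in layers]
total = sum(counts)
print(counts, total)  # [12, 5] 17
```

On a real model, `sum(p.numel() for p in model.parameters())` gives the same total.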

This is all well and good, but we can’t actually train the network yet! For that we need a basic training loop:

import torch

model = NeuralNetwork()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
x = torch.ones(10, 2)  # input batch tensor
y = torch.zeros(10, 1)  # expected output (target) batch tensor

epochs = 10
for t in range(epochs):
  # Compute prediction and loss
  pred = model(x)
  loss = torch.nn.functional.cross_entropy(pred, y)
  
  # Update parameters
  loss.backward()  # populate each parameter's .grad
  optimizer.step()  # apply the Adam update
  optimizer.zero_grad()  # reset gradients for the next iteration

We train our model for 10 epochs (iterations) over a single batch of \(B=10\) (identical) data samples2. In each epoch:

  1. With model(x), we call the forward method defined earlier to obtain predictions for the entire input batch. The outputs of each layer (“activations”) are cached for use in the backward pass.
  2. We compute the (cross entropy) loss for these predictions, averaged over the batch into the scalar loss tensor.
  3. We calculate the derivative of the loss with respect to each parameter (“gradients”) with loss.backward(). PyTorch’s autograd does this automatically by building a computational graph in the forward pass, and then applying backpropagation starting from the output layer in the backward pass. It accumulates gradients in each tensor’s .grad attribute3.
  4. The optimizer defines how parameters are updated from gradients. optimizer.step() performs this adjustment, and optimizer.zero_grad() resets gradients so they don’t accumulate in the next pass.
Pebble graph for a four layer network illustrating how cached activations are built up in the forward pass, and used to calculate gradients in the backward pass (graphic inspiration).

This process is known as stochastic gradient descent: we iteratively update parameters using gradients calculated over a random subset (batch) of the entire dataset.

For each parameter in our network, we also need to store its gradient and relevant optimizer state. The popular Adam optimizer tracks momentum and variance, exponential averages of the first and second moments respectively of each parameter’s gradient [2]. The result is that each parameter can end up needing at least 16 bytes of memory, mostly attributable to high-precision optimizer state (assuming fp16/32 mixed precision4) [3]. For larger models such as Meta’s Llama 405B that’s a 6.5TB memory requirement, which makes distributing model parameters over several GPUs a necessity.
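The 16-byte figure can be reproduced with a short calculation. The breakdown below (fp16 parameters and gradients, plus fp32 master parameters, momentum and variance) is the one commonly used for mixed-precision Adam, for example in the ZeRO paper [6]; treat it as an assumption, since other precision setups differ, and note that activations are not included.

```python
# Bytes per parameter for Adam with fp16/fp32 mixed precision
# (assumed breakdown; other setups use different precisions).
BYTES_PER_PARAM = {
    "fp16 parameters": 2,
    "fp16 gradients": 2,
    "fp32 master parameters": 4,
    "fp32 Adam momentum": 4,
    "fp32 Adam variance": 4,
}

def training_memory_tb(num_params):
    """Memory in TB (1 TB = 1e12 bytes) for parameters, gradients and optimizer state."""
    return num_params * sum(BYTES_PER_PARAM.values()) / 1e12

per_param = sum(BYTES_PER_PARAM.values())
optimizer_share = (4 + 4 + 4) / per_param  # fp32 state as a fraction of the total
print(per_param, optimizer_share, round(training_memory_tb(405e9), 2))  # 16 0.75 6.48
```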

PyTorch offers two execution models: eager mode and graph mode. In eager mode (the default), operators are immediately executed as they are encountered – effectively, we can’t “look ahead”. Graph mode synthesises operators into a graph, which is then compiled and executed as a whole. As of PyTorch 2.5, most of the parallelisms offered only exist in eager mode – which, as we’ll see, can often lead to suboptimal sequences of operations.

The Transformer architecture

[WIP] As the largest models trained today are transformers, most of the distributed training literature has evolved around this architecture. Transformers usually consist of a stack of identically-sized transformer blocks, which makes them very convenient to split across devices.

2. Communication Primitives

Before going into distribution strategies, we need to discuss the primitives we have available for communicating data between GPUs.

Let’s start with a simple model: two GPUs (or “ranks”) with a point-to-point (p2p) connection – this could be a fast NVLink interconnect if they’re within the same host, or a slower InfiniBand or Ethernet network (perhaps with several hops) if they’re not.

All primitives operate over a single tensor at each rank. The simplest thing we can do is to send a tensor from one rank and receive it on the other:

Now let’s suppose we want to synchronise tensors distributed over a group (or “collective”) of GPUs. One way to do this is with an AllToAll collective, a complete graph of p2p sends:

This isn’t very bandwidth efficient: a world-size of \(W\) ranks synchronising \(D\)-sized tensors results in \(D(W-1)\) per-GPU traffic, some of which may be contending for the same underlying network links. Moreover, we often only need an aggregate of the distributed tensors – for example we might want to average some parameters we’ve replicated across the ranks. So how might we accomplish this with less bandwidth? If each rank reduces (applying an associative5 operator e.g. sum, min, max, etc.) the tensor it receives with its own local tensor, before passing the result onto the next rank, we obtain a ring-based Reduce collective:

After completing one loop around the ring, we’ve reduced all of the tensors into a single tensor – but this result is only held in the last rank. We need to complete another loop so that each rank holds a replica of the resulting tensor. This is the Broadcast collective:

Notice that, in the latter two collectives, only one rank/link at a time is busy, with the rest idle. We can use pipelining to get better throughput: we split the tensor into \(W\) chunks, and have the \(r^\text{th}\) rank act as the start (or root) of the ring for the \(r^\text{th}\) chunk, so that every link carries a different chunk at each step. The pipelined analogs of Reduce and Broadcast are ReduceScatter and AllGather respectively. Sequencing the two together results in the composite AllReduce collective:

The ReduceScatter and AllGather collectives correspond to the first and second loops in the above animation. Notice we obtain the same result we would have had with an AllToAll followed by local reductions at each rank. However, by using a ring, AllReduce drastically reduces communication overhead: each GPU sends a \(\frac{D}{W}\)-size data chunk \(W-1\) times for the ReduceScatter and \(W-1\) times for the AllGather, for a total per-GPU traffic of \(2(W-1)\frac{D}{W}\), compared with \(D(W-1)\) for the AllToAll. Crucially, this is bounded by \(2D\) regardless of the number of GPUs in the collective!
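To make the two loops concrete, here is a pure-Python simulation of a sum Ring AllReduce, in which plain lists stand in for GPU tensors and "sending" is a list copy. `ring_all_reduce` is a hypothetical helper rather than a library API, and it assumes \(D\) is divisible by \(W\).

```python
def ring_all_reduce(tensors):
    """Simulate a sum Ring AllReduce (ReduceScatter, then AllGather) over W ranks.

    tensors[r] is rank r's local tensor: a list of D numbers, D divisible by W
    (an assumption of this sketch). Returns every rank's final tensor and the
    number of elements each rank sent.
    """
    W, D = len(tensors), len(tensors[0])
    assert D % W == 0, "sketch assumes D is divisible by W"
    c = D // W  # chunk size D/W
    data = [list(t) for t in tensors]
    sent = [0] * W

    def step(chunk_of, combine):
        # All ranks send at once, so snapshot every outgoing chunk first.
        outgoing = [(r, chunk_of(r), data[r][chunk_of(r) * c:(chunk_of(r) + 1) * c])
                    for r in range(W)]
        for r, k, vals in outgoing:
            dst = (r + 1) % W  # each rank only talks to its ring neighbour
            for i, v in enumerate(vals):
                idx = k * c + i
                data[dst][idx] = combine(data[dst][idx], v)
            sent[r] += c

    # ReduceScatter: at step s, rank r passes its partial sum of chunk (r - s) onwards.
    for s in range(W - 1):
        step(lambda r: (r - s) % W, lambda mine, theirs: mine + theirs)
    # Rank r now holds the fully reduced chunk (r + 1) % W.
    # AllGather: at step s, rank r forwards the finished chunk (r + 1 - s).
    for s in range(W - 1):
        step(lambda r: (r + 1 - s) % W, lambda mine, theirs: theirs)
    return data, sent

tensors = [[10 * r + i for i in range(8)] for r in range(4)]  # W=4 ranks, D=8
result, sent = ring_all_reduce(tensors)
print(result[0])  # [60, 64, 68, 72, 76, 80, 84, 88]
print(sent)       # [12, 12, 12, 12], i.e. 2(W-1)D/W
```

Every rank ends with the element-wise sum, and the traffic counter matches the \(2(W-1)\frac{D}{W}\) formula.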

Though Ring AllReduce is bandwidth optimal, its end-to-end latency scales linearly with the number of ranks, since its \(2(W-1)\) steps must run one after another. A lower latency, tree-based alternative will be discussed in another post.

3. Parallelism Paradigms

Now that we have an understanding of how data flows during the PyTorch model training process and the primitives we have available for communicating this data between GPUs, let’s look at the various techniques that have emerged for distributing training. In order to have some notion of correctness for these techniques, we’ll define a distributed algorithm to be locally consistent if it is mathematically equivalent to local training.

Distributed Data Parallel (DDP)

As its name implies, DDP splits our dataset across ranks (each with an identical copy of the model), with periodic synchronisation to ensure model replicas stay consistent. DDP is useful when our model is still small enough to fit on a single GPU, but we’d like to speed up training by having several GPUs work on a single batch in parallel.

We described local training in Section 1: at each iteration we load the next batch, perform a forward pass while caching each layer’s activations, and calculate the loss. Then we run the backward pass to calculate gradients, before our optimizer updates parameters.

Local training example on a batch of 6 MNIST data samples (image credit).

DDP duplicates the model across \(W\) ranks, splitting batches into \(W\) different \(\beta=\frac{B}{W}\)-size chunks for each rank to process:

Data parallel training example with W=2 ranks (image credit).

Ignoring communication overhead, this would result in a linear \(W\times\) speedup: the forward and backward passes are independent sample-wise calculations6, and hence each rank can process its chunk without any communication.

To achieve local consistency, we need to synchronise our gradients before the optimizer step so that the weight updates at each rank are the same. Conveniently, the most commonly used loss functions are means over the sample-wise losses in the batch:

\[ \text{loss}(\text{batch}) = \frac{1}{B} \sum_{j=1}^{B} \text{loss}(\text{fwd}(\text{input}_j), \text{target}_j) \]

Because the gradient of a sum is the sum of the gradients of each term, each rank can compute the gradient of the mean loss over its own \(\beta\)-sample chunk independently. Since all chunks have the same size, averaging these local gradients recovers the gradient over the entire batch:

\[ \nabla \theta = \frac{1}{W}\sum_{r=1}^{W} \nabla \theta_r^\text{local} \]

This can be done efficiently using the previously discussed AllReduce collective (summing the local gradients, then dividing by \(W\)), along with a single Broadcast from the root rank after model construction to synchronise initial parameters.
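As a numerical sanity check of this identity, the sketch below uses a toy one-parameter model \(\hat{y} = wx\) with a mean squared error loss; `grad_mse` and `ddp_gradient` are hypothetical helpers, and the data is made up. Averaging the per-chunk gradients reproduces the full-batch gradient as long as chunks are equal-sized.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def ddp_gradient(w, xs, ys, world_size):
    """Each rank differentiates the mean loss over its own chunk; an
    AllReduce(sum) followed by division by W gives every rank the average."""
    assert len(xs) % world_size == 0, "sketch assumes equal-size chunks"
    beta = len(xs) // world_size
    local = [grad_mse(w, xs[r * beta:(r + 1) * beta], ys[r * beta:(r + 1) * beta])
             for r in range(world_size)]
    return sum(local) / world_size

xs = [float(i) for i in range(1, 9)]  # toy batch of B=8 samples
ys = [2 * x + 1 for x in xs]
full = grad_mse(0.5, xs, ys)
print(full, ddp_gradient(0.5, xs, ys, world_size=4))  # equal up to rounding
```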

Lastly, it’s worth noting that batchsize limits the maximum degree of DDP parallelism. We need to keep chunk sizes above a minimum batchsize per GPU (\(\beta_\text{min}\)), below which compute intensity/utilisation decreases substantially. On the other hand, stochastic gradient descent is maximally efficient when overall batchsize is well below an empirical value known as the critical batch size7, \(B \ll B_\text{crit}\) [4]. This small batchsize regime is unattainable for large clusters, where \(B \geq W\beta_\text{min} > B_\text{crit}\).

DDP in PyTorch

PyTorch’s distribution API is designed for non-intrusive scaling out from local training. Applying DDP to a model is as simple as wrapping our local model with the DistributedDataParallel nn.Module class: nn.parallel.DistributedDataParallel(model, process_group=...). The process group (PyTorch’s abstraction for a group of processes that run collectives together) allows us to specify which communication backend to use, and which ranks to distribute over. Care should be taken to ensure the batches processed by each rank are different (e.g. by passing a DistributedSampler to the DataLoader).
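The pure-Python sketch below mimics the core idea behind DistributedSampler: every rank shuffles with the same seed, pads the index list to a multiple of \(W\), then takes every \(W^\text{th}\) index starting at its own rank, so ranks see disjoint samples (apart from a few padding duplicates). `shard_indices` is a hypothetical helper; the real sampler seeds a torch.Generator with seed + epoch, which is why DistributedSampler.set_epoch should be called every epoch.

```python
import random

def shard_indices(num_samples, rank, world_size, epoch=0, seed=0):
    """Strided split of a shuffled index list across ranks (simplified sketch)."""
    assert num_samples >= world_size
    indices = list(range(num_samples))
    # Every rank uses the same seed, so all ranks agree on one permutation.
    random.Random(seed + epoch).shuffle(indices)
    total = -(-num_samples // world_size) * world_size  # round up to a multiple of W
    indices += indices[: total - num_samples]  # pad so each rank gets equally many
    return indices[rank:total:world_size]

shards = [shard_indices(10, r, world_size=4) for r in range(4)]
print(shards)  # 4 shards of 3 indices each, covering all 10 samples
```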

A naive implementation of DDP would synchronise gradients only once the full forward and backward passes have finished, just before calling optimizer.step(). This is suboptimal as it divides training into two distinct phases: one where we’re waiting for backpropagation to finish computing while the network is idle, and another where the network is communicating as fast as possible while our expensive GPUs are doing (almost) nothing:

Naive DDP implementation with non-overlapping computation and communication (image credit).

Notice in the above that gradients for later layers are already available while we’re still computing the backward pass of earlier layers. For example, the gradients of Layer3 are ready while we’re backpropagating through Layer2. This allows us to overlap computation with (non-blocking) communication, speeding up the complete iteration:

Faster DDP implementation with overlapping computation and communication (image credit).

Collective communications are more efficient on large tensors. Therefore, in practice, rather than launching a dedicated AllReduce as soon as each layer’s gradient tensor is ready, we use Gradient Bucketing: gradients are grouped into fixed-size buckets (25MB by default, configurable via DDP’s bucket_cap_mb argument), and a single AllReduce is launched per bucket.

To non-intrusively integrate with its eager execution model, PyTorch implements DDP by registering one autograd hook (a callback) with each parameter tensor, which fires after the corresponding gradients are updated (during the loss.backward() call). Once all hooks in a bucket8 have fired, an asynchronous AllReduce is triggered. PyTorch’s DDP paper [5] shows that this overlap brings significant performance gains, particularly when using the recommended NCCL communication backend:

Per-iteration normalised latency breakdown, comparing non-overlapping vs overlapping communication; training on 32 GPUs across 4 machines. Figure from [5].
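The bucket-and-hook mechanism can be sketched in pure Python. In this deliberately simplified model, sizes are counted in elements rather than bytes, parameters are bucketed greedily in reverse order (gradients become ready roughly back to front), and each "hook" marks one parameter as ready; `build_buckets` and `simulate_backward` are hypothetical helpers, not DDP internals.

```python
def build_buckets(param_sizes, cap):
    """Greedily group parameter indices into buckets of at most `cap` elements,
    walking parameters in reverse order. A parameter larger than `cap` gets a
    bucket of its own."""
    buckets, current, current_size = [], [], 0
    for idx in reversed(range(len(param_sizes))):
        size = param_sizes[idx]
        if current and current_size + size > cap:
            buckets.append(current)
            current, current_size = [], 0
        current.append(idx)
        current_size += size
    if current:
        buckets.append(current)
    return buckets

def simulate_backward(param_sizes, cap):
    """Fire one 'autograd hook' per parameter in backward order and record when
    each bucket completes, i.e. when its asynchronous AllReduce could launch."""
    buckets = build_buckets(param_sizes, cap)
    pending = [set(b) for b in buckets]
    launches = []
    for idx in reversed(range(len(param_sizes))):  # gradient of param idx is ready
        for b, remaining in enumerate(pending):
            if idx in remaining:
                remaining.discard(idx)
                if not remaining:
                    launches.append((idx, b))  # bucket b complete after param idx
    return buckets, launches

buckets, launches = simulate_backward([4, 4, 8, 2, 6], cap=10)
print(buckets)   # [[4, 3], [2], [1, 0]]
print(launches)  # [(3, 0), (2, 1), (0, 2)]
```

Bucket 0 can start communicating while parameters 2, 1 and 0 are still being differentiated, which is exactly the overlap shown in the figure.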

Amortised communication overhead can be further reduced with Gradient Accumulation: rather than synchronising gradients every iteration, we accumulate (via the no_sync context manager) the gradients of \(n\) local training iterations before synchronising gradients globally and updating parameters. [5] claims this enables near-linear scaling for smaller GPU clusters, with “negligible accuracy penalty”:

Per-iteration latencies (left), and final training loss (right) for n iterations of gradient accumulation. Figure from [5].
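With a real DDP model, forward-backward passes run inside the no_sync() context manager only accumulate gradients locally, and the first forward-backward pass outside it triggers the AllReduce. Because a real DDP model needs an initialised process group, the sketch below uses a hypothetical `ToyDDP` stand-in that only mimics this control flow and counts synchronisations; `train` is likewise illustrative.

```python
from contextlib import contextmanager

class ToyDDP:
    """Hypothetical stand-in for a DDP-wrapped model: it only tracks whether a
    backward pass would synchronise gradients."""

    def __init__(self):
        self.require_sync = True
        self.local_grad = 0.0
        self.allreduce_calls = 0

    @contextmanager
    def no_sync(self):
        prev = self.require_sync
        self.require_sync = False
        try:
            yield
        finally:
            self.require_sync = prev

    def backward(self, grad):
        self.local_grad += grad  # gradients accumulate locally, as in .grad
        if self.require_sync:
            self.allreduce_calls += 1  # a real DDP model would AllReduce here

def train(model, grads, n):
    """Run len(grads) local iterations, synchronising only every n-th one."""
    steps = 0
    for start in range(0, len(grads), n):
        group = grads[start:start + n]
        with model.no_sync():
            for g in group[:-1]:
                model.backward(g)
        model.backward(group[-1])  # outside no_sync: triggers the AllReduce
        steps += 1  # optimizer.step() and optimizer.zero_grad() would go here
        model.local_grad = 0.0
    return steps

model = ToyDDP()
print(train(model, [1.0] * 12, n=4), model.allreduce_calls)  # 3 3
```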

Fully-Sharded Data Parallel (FSDP)

DDP speeds up training by distributing our dataset across multiple ranks, but what happens when our model can’t fit within a single GPU? DDP’s newer alternative, FSDP, addresses this by also splitting model parameters. FSDP is a PyTorch-native implementation of DeepSpeed’s ZeRO [6], with some further optimisations.

FSDP reduces memory footprint by sharding model parameters: the model is split horizontally so that each rank only holds a subset (“shard”) of the parameters (and associated gradients and optimizer state) in any given layer. The naive approach to guaranteeing local consistency is to compute the partial activations of a layer corresponding to the local shard, and then to communicate these activations with the other ranks before proceeding onto the next layer:

Naive FSDP forward pass: activations are data-dependent and therefore appear on the critical path.

The obvious problem with this approach is that communication appears on the critical path: we can’t compute the forward pass for a given layer until we’ve received the complete activations of the previous layer.

Instead of communicating activations, FSDP’s approach is to communicate parameters. FSDP fully materialises parameters before computations, just as in local training, thus removing any data dependency. However, materialising every parameter on a single GPU at once would eliminate our memory savings! FSDP’s simple solution is to partition the model into groups of layers called units, and to only materialise one unit at a time, on demand.

So what does this look like in practice? Let’s look at a simple six layer model (illustrated below), which we’ve decided to decompose into three units: [layer0, layer3], [layer1, layer2] and [layer4, layer5]. Consider what happens to unit1 consisting of [layer1, layer2]:

  1. Just before the forward pass through layer1, we materialise the parameters in unit1 by gathering shards from peer ranks. We can do this with an AllGather (equivalent to each rank Broadcasting its own shard).
  2. After completing local forward computation, we free peer shards (but keep activations).
  3. Before the backward pass through layer2, we AllGather the shards again.
  4. After gradients are calculated, we free peer shards and then ReduceScatter to sum up and shard gradients (equivalent to each rank Reducing the gradients in its shard).
  5. Finally, after completing full forward & backward passes through all units, we update our shard of the parameters in the optimizer step9.
FSDP example with three units, fully sharded over two ranks. Figure from [7].
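The five steps can be written down as a per-unit event trace. This is a simplified model: `fsdp_trace` is a hypothetical helper, `unit_order` is the order in which units are first used in the forward pass, and prefetching, compute/communication overlap and the root-unit special case are all ignored.

```python
def fsdp_trace(unit_order):
    """Per-unit FSDP events for one training iteration (simplified sketch)."""
    events = []
    for u in unit_order:  # forward pass
        events += [("AllGather", u), ("forward", u), ("free peer shards", u)]
    for u in reversed(unit_order):  # backward pass visits units in reverse order
        events += [("AllGather", u), ("backward", u),
                   ("free peer shards", u), ("ReduceScatter", u)]
    events.append(("optimizer step on local shard", None))
    return events

events = fsdp_trace([0, 1, 2])
print([name for name, unit in events if unit == 1])
# ['AllGather', 'forward', 'free peer shards', 'AllGather', 'backward', 'free peer shards', 'ReduceScatter']
print(sum(name == "AllGather" for name, _ in events))  # 6: two per unit
```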

In effect, FSDP decomposes DDP’s AllReduce into a ReduceScatter and an AllGather in the backward and forward passes respectively – the only extra communication incurred is when we AllGather parameters again during backpropagation.

Sharding Strategies

FSDP enables fine-grained trade-offs between memory footprint and communication overhead via the sharding factor \(F\): the number of ranks over which parameters are sharded. By setting \(F=W\) (i.e. the global world size), FSDP fully shards the model with each rank holding only \(\frac{1}{W}\) of the model (as in the above example, with \(F=W=2\)).

Hybrid sharding, with a sharding factor strictly between \(1\) and \(W\), combines both sharding and replication. We end up with sharding groups \(S_1, \ldots, S_\frac{W}{F}\), each consisting of \(F\) ranks over which parameters are sharded, and replication groups \(R_1, \ldots, R_F\) (directly corresponding to these shards), each consisting of \(\frac{W}{F}\) ranks (one from each sharding group) over which shards are replicated.

The AllGather+AllGather+ReduceScatter collectives, previously over all ranks, are now collectives within each sharding group, followed by an AllReduce within each replication group to synchronise gradient shards (as in DDP). This is effectively the decomposition: \[ \nabla \theta = \frac{1}{W}\sum_{r=1}^W \nabla \theta_r^\text{local} = \frac{1}{W}\sum_{i=1}^{W/F}\sum_{r \in S_i}\nabla \theta_r^\text{local} \] For example, with \(W=16\) ranks and \(F=8\) hybrid sharding, the \(r=9\) rank would AllGather parameters and ReduceScatter its gradient shard with peers in the \(S_2\) sharding group, before AllReducing the gradient shard with its peer in the \(R_2\) replication group:

FSDP Hybrid Sharding (F=8) example with W=16 ranks.
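The group membership in this example can be computed directly. The sketch assumes sharding groups are made of consecutive ranks, with ranks numbered from 0 and groups from 1, which reproduces the \(r=9\) example above; other layouts are possible. `hybrid_groups` and `groups_of` are hypothetical helpers.

```python
def hybrid_groups(world_size, shard_factor):
    """Sharding groups S_1..S_{W/F} (consecutive ranks) and replication groups
    R_1..R_F (same position within each sharding group); assumed layout."""
    assert world_size % shard_factor == 0
    sharding = {i + 1: list(range(i * shard_factor, (i + 1) * shard_factor))
                for i in range(world_size // shard_factor)}
    replication = {j + 1: list(range(j, world_size, shard_factor))
                   for j in range(shard_factor)}
    return sharding, replication

def groups_of(rank, shard_factor):
    """(sharding group index, replication group index) of a 0-indexed rank."""
    return rank // shard_factor + 1, rank % shard_factor + 1

S, R = hybrid_groups(16, 8)
print(groups_of(9, 8), S[2], R[2])  # (2, 2) [8, 9, 10, 11, 12, 13, 14, 15] [1, 9]
```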

You might’ve spotted that setting \(F=1\) results in a single replication group (with no memory savings) – this simplifies to vanilla DDP using AllReduce for gradient synchronisation. It’s worth noting that with any sharding strategy, ranks are expected to have distinct input batch chunks (otherwise we’d simply be duplicating gradient calculations). Though we can keep reducing memory overhead with larger sharding groups, our degree of compute parallelism will encounter the same batchsize limitations as DDP.

Using our traffic calculations from Section 2, the per-GPU communication of an \(M\)-size model is \(2(\frac{W}{F}-1)(\frac{M}{W})\) for the replication group, and \(3(F-1)(\frac{M}{F})\) for the sharding group. Because communication within the sharding group is more expensive (and intertwined with the critical path), we usually try to minimise the number of hops between the ranks in a sharding group – sometimes we may even use smaller sharding factors to ensure they’re within the same host.
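Both traffic formulas are easy to sanity-check numerically, including the two limiting cases \(F=W\) (no replication traffic) and \(F=1\) (DDP’s AllReduce). `hybrid_fsdp_traffic` is a hypothetical helper, and \(M=1600\) elements is an arbitrary toy size.

```python
def hybrid_fsdp_traffic(model_size, world_size, shard_factor):
    """Per-GPU traffic (elements) per iteration for hybrid-sharded FSDP, using the
    ring costs from Section 2: two AllGathers and one ReduceScatter of the model in
    the sharding group, plus one AllReduce of a 1/F shard in the replication group."""
    W, F, M = world_size, shard_factor, model_size
    sharding = 3 * (F - 1) * M / F
    replication = 2 * (W / F - 1) * M / W
    return sharding, replication

print(hybrid_fsdp_traffic(1600, 16, 8))   # (4200.0, 200.0)
print(hybrid_fsdp_traffic(1600, 16, 16))  # full sharding: (4500.0, 0.0)
print(hybrid_fsdp_traffic(1600, 16, 1))   # F=1 reduces to DDP: (0.0, 3000.0)
```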

FSDP in PyTorch

Just like DDP, the FSDP API is designed as a thin nn.Module wrapper class: sharded_model = FSDP(model, process_group=...)10. Sharding strategy is set with the sharding_strategy arg: FULL_SHARD, NO_SHARD and HYBRID_SHARD correspond to aforementioned fully sharded, fully replicated and hybrid strategies respectively11. Before going into all the other levers that FSDP exposes to the user, let’s first get a quick understanding of how it’s implemented under the hood.

The communication backends (e.g. NCCL) that provide collective implementations usually require AllGather and ReduceScatter to have the same input tensor size at each rank. Moreover, for a fixed communication volume issuing fewer, larger collectives reduces communication overheads (as discussed in DDP’s Gradient Bucketing). Thus, during construction FSDP concatenates all parameters (and gradients) within a unit into a single flattened 1-D FlatParameter tensor, along with the padding necessary to ensure equal-sized shards at each rank in the sharding group12. The FlatParameter tensor has the exact data layout expected by AllGather and ReduceScatter, allowing us to call the collectives directly without copying any tensors.

FlatParameter example for a fully sharded (W=F=16) FSDP unit, consisting of one \(4 \times 3\) nn.Linear layer. Figure from [7].
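The flatten-pad-shard idea fits in a few lines of pure Python. This is a simplified sketch: `flatten_and_shard` is a hypothetical helper, plain lists stand in for tensors, and the real FlatParameter also keeps per-parameter offsets so the original tensors can be viewed back out. The example uses an nn.Linear(4, 3) layer (12 weight and 3 bias elements) over 16 ranks; it is our own example and is not meant to reproduce the figure exactly.

```python
def flatten_and_shard(tensors, shard_factor):
    """Concatenate flattened tensors into one 1-D buffer, zero-pad it to a
    multiple of shard_factor, and split it into equal per-rank shards."""
    flat = [x for t in tensors for x in t]
    padding = (-len(flat)) % shard_factor
    flat += [0.0] * padding
    size = len(flat) // shard_factor
    shards = [flat[r * size:(r + 1) * size] for r in range(shard_factor)]
    return shards, padding

# nn.Linear(4, 3): weight of shape [3, 4] -> 12 elements, bias of shape [3] -> 3
weight = [float(i) for i in range(12)]
bias = [100.0, 101.0, 102.0]
shards, padding = flatten_and_shard([weight, bias], shard_factor=16)
print(len(shards), len(shards[0]), padding)  # 16 1 1
```

Because every shard has the same size, an AllGather of the shards is simply their concatenation, and dropping the padding recovers the original parameters.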

For an \(M\)-size model split into \(K\) units with sizes \(M_1, \ldots, M_K\), where \(\sum_{i=1}^K M_i=M\), the maximum memory usage is in \(O(\frac{M}{F} + \max_{i=1}^K M_i)\). More precisely, it is the sum of the sharded parameters, gradients and optimizer state, combined with the largest unsharded unit’s parameters and gradients (but not the more expensive optimizer state, which always remains sharded). Conversely, even though total communication is not affected by the number of units, the number of collectives over which it is spread is \(O(K)\). Therefore the number of units presents yet another memory-communication tradeoff. PyTorch lets the user control this with the auto_wrap_policy argument to FSDP, or by manually wrapping individual submodules rather than a single wrapper around the entire model13.
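To see the tradeoff in numbers, the sketch below evaluates the memory expression for a hypothetical 1B-parameter model fully sharded over 8 GPUs, split into either 4 or 40 equal units. The per-parameter byte counts (2 for parameters, 2 for gradients, 12 for optimizer state) reuse the mixed-precision Adam breakdown from Section 1 and are an assumption; activations are ignored. Both helpers are hypothetical.

```python
def fsdp_peak_memory_gb(unit_sizes, shard_factor,
                        param_bytes=2, grad_bytes=2, optim_bytes=12):
    """Rough per-GPU peak memory (GB, 1 GB = 1e9 bytes), ignoring activations:
    sharded parameters, gradients and optimizer state for the whole model,
    plus the largest unit's unsharded parameters and gradients."""
    model_size = sum(unit_sizes)
    sharded = model_size / shard_factor * (param_bytes + grad_bytes + optim_bytes)
    largest_unit = max(unit_sizes) * (param_bytes + grad_bytes)
    return (sharded + largest_unit) / 1e9

def collectives_per_iteration(num_units):
    """Two AllGathers and one ReduceScatter per unit (no hybrid AllReduce)."""
    return 3 * num_units

coarse = [250e6] * 4  # K = 4 units
fine = [25e6] * 40    # K = 40 units
print(round(fsdp_peak_memory_gb(coarse, 8), 2), collectives_per_iteration(len(coarse)))  # 3.0 12
print(round(fsdp_peak_memory_gb(fine, 8), 2), collectives_per_iteration(len(fine)))      # 2.1 120
```

Smaller units lower peak memory, but spread the same communication volume over ten times as many collectives.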

As with DDP’s Gradient Bucketing, FSDP tries to overlap communication and computation as much as possible. Here’s what that looks like for our previous three unit, six layer example:

Full forward & backward pass for previous 3 unit, 6 layer FSDP example. Compute & communication CUDA streams (below), and broken up by unit (above).

In DDP’s backward pass, we were able to compute gradients and then asynchronously AllReduce them afterwards. This isn’t possible for FSDP’s forward: we need to AllGather parameters before computing, and (because of eager execution) we don’t know which FlatParameter to gather next – thus we can’t reorder the async AllGather of the next unit before the synchronous computation of the current unit. The solution, implicit forward prefetching (always enabled), is to use a separate stream (queue of device instructions) for communication, bypassing the false dependency on the default compute stream.

You may have noticed the poor compute-communication overlap in the backward pass: the ReduceScatter for the current unit blocks the AllGather for the next, which in turn blocks the next gradient computation14. Explicit backward prefetching issues the AllGather for the next unit before the ReduceScatter for the current one. To know which FlatParameter to gather next, FSDP records the reverse forward execution order of modules each iteration. Two variants exist: backward_prefetch=BACKWARD_PRE which overlaps the next AllGather with the current gradient computation, and BACKWARD_POST which waits until the current parameters are freed (using less memory but reducing overlap). By default FSDP limits the rate at which prefetch AllGathers are issued to ensure memory usage of at most two consecutive units.
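The effect of the three options on issue order and memory can be sketched with a toy event model. This is a simplification of our own, not FSDP’s implementation: one rank, no real streams or timing, and only unsharded parameters are counted. `backward_events` is a hypothetical helper; the option names mirror FSDP’s BackwardPrefetch values.

```python
def backward_events(unit_order, prefetch=None):
    """Order in which one rank issues backward work, plus the peak number of units
    whose unsharded parameters are materialised at the same time."""
    order = list(reversed(unit_order))  # backward visits units in reverse forward order
    events, live, peak = [], set(), 0

    def gather(u):
        nonlocal peak
        events.append(("AllGather", u))
        live.add(u)
        peak = max(peak, len(live))

    gather(order[0])
    for i, u in enumerate(order):
        nxt = order[i + 1] if i + 1 < len(order) else None
        if prefetch == "BACKWARD_PRE" and nxt is not None:
            gather(nxt)  # prefetch before computing the current unit's gradients
        events.append(("backward", u))
        live.discard(u)  # free the current unit's unsharded parameters
        if prefetch == "BACKWARD_POST" and nxt is not None:
            gather(nxt)  # prefetch only once the current parameters are freed
        events.append(("ReduceScatter", u))
        if prefetch is None and nxt is not None:
            gather(nxt)  # no prefetch: the AllGather queues behind the ReduceScatter
    return events, peak

events, peak = backward_events([0, 1, 2], prefetch="BACKWARD_PRE")
print(events)
print(peak)  # 2: the current unit plus the prefetched one
```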

FSDP makes one final optimisation: it assumes the root unit (wrapping the outermost module) holds the last layer’s parameters, and does not free the root unit’s parameters after the forward pass (with the intention that they are immediately re-used in the backward pass). Because this is a heuristic workaround for eager execution’s inability to look ahead, it doesn’t always work. In our example, it’s actually unit2 (the third unit) that holds the last layer, and we end up AllGathering parameters we already have and are about to free!

Lastly, we should note that with hybrid sharding there would also be an async AllReduce (on yet another communication stream) for each unit after their ReduceScatter is done.

PyTorch’s FSDP experiments [7] show near-linear compute scaling, though this regresses substantially for larger clusters where “a near-perfect overlap between communication and computation is no longer attainable”:

Fully-sharded training of the T5-11B transformer; TFLOPS per GPU for batchsizes 8 and 16; A100 80GB GPUs with 2Tb/s RoCE interconnects. Figure from [7].

Pipeline Parallel (PP)

Like FSDP, pipeline parallelism aims to train models too large to fit within a single GPU. However, rather than sharding the model horizontally, we partition it vertically along its depth. Each partition of consecutive layers is referred to as a stage. Returning to our four layer network from Section 1, we could partition it evenly across two ranks and send intermediate activations/gradients at partition boundaries between stages:

Pebble graph of a four layer network, partitioned across two ranks into two stages of two layers each (image credit).

This naive approach of passing a single batch from rank to rank (often referred to as “model parallelism”), results in severe GPU under-utilisation: only one GPU works on the batch at any given moment, so each rank is busy at most \(\frac{1}{W}\) of the time. To illustrate, here’s what the same naive schedule would look like with a pipeline depth of four stages:

Naive model parallelism with d=4 stages; FWD/BWD are over entire stages rather than only a single layer.

These dead zones in our schedule where GPUs are idle are called pipeline bubbles. They are caused by dependencies between operations: for example, rank 2 cannot start the 2nd forward stage until it has received 1st stage intermediate outputs from rank 1.

GPipe [8] reduces bubbles by splitting a batch into microbatches and adding up each of their gradients to get back the gradient over the entire batch (the same as DDP gradient accumulation), thus allowing more than one rank to do useful work at the same time. Here’s the same four stage example, with 4-way batch-splitting:

GPipe schedule with d=4 stages, m=4 microbatches; \(F_i(j)\) denotes the \(i^\text{th}\) stage forward computation over the \(j^\text{th}\) microbatch.

In this example, GPipe reduces the bubble fraction by three quarters. To see why, consider a pipeline with \(d\) evenly-partitioned stages and \(m\) evenly-divided microbatches: during the forward pass, a given stage spends \(m\) forward timesteps doing useful work and \(d-1\) forward timesteps waiting for new work to arrive. The same applies to the backward pass, therefore the overall fraction of ideal computation time spent in bubbles is:

\[ \text{Bubble}(d,m)=\frac{d-1}{m} \] For the bubble time fraction to be small we need \(m \gg d\). Our naive model parallelism example (\(m=1, d=4\)) has a bubble ratio of \(3\), compared to \(0.75\) for GPipe (\(m=4, d=4\)).
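The formula is simple enough to tabulate directly. The sketch below (hypothetical helpers `bubble` and `utilisation`) also converts the bubble fraction into the fraction of wall-clock time a rank spends computing, which for naive model parallelism recovers the \(\frac{1}{W}\) utilisation mentioned earlier.

```python
def bubble(d, m):
    """Idle time as a fraction of ideal compute time: d stages, m microbatches."""
    return (d - 1) / m

def utilisation(d, m):
    """Fraction of wall-clock time each rank spends computing (no communication)."""
    return 1 / (1 + bubble(d, m))

print(bubble(4, 1), bubble(4, 4), bubble(4, 16))  # 3.0 0.75 0.1875
print(utilisation(4, 1))                          # 0.25, i.e. 1/d for naive model parallelism
```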

So far we’ve ignored communication overheads. Unlike other parallelisation paradigms, pipelining does not require any collective communication primitives; we simply asynchronously send (p2p) intermediates as soon as they’re ready. Here’s what our GPipe example looks like once we include communication:

GPipe schedule with d=4 stages and m=4 microbatches, communication included.

For illustration purposes, here sending a microbatch takes longer than computing a stage (pipelining is typically internode so this is not uncommon), reducing our compute efficiency. Perfect compute-communication overlap is impossible for pipeline parallelism because necessarily we can’t start working on the first microbatch until the previous stage has finished processing and then sent the same microbatch.

Notably, pipeline parallelism is orthogonal to DDP and both can be combined to obtain a 2D parallelism similar to hybrid FSDP. In practice, this is implemented with the pipeline as the inner dimension and with bucketed AllReduces in the outer dimension (interleaved with the backward pass of the final microbatch).

DDP and pipeline 2D parallelism example over W=4 ranks, d=2 stages (image credit).
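One way to lay out this 2D grid is to let consecutive ranks form a pipeline (the inner dimension), while ranks holding the same stage in different pipelines form a data-parallel group that AllReduces gradients. The layout is an assumption for illustration, and `dp_pp_groups` is a hypothetical helper.

```python
def dp_pp_groups(world_size, pp_degree):
    """Split ranks into pipelines of consecutive ranks (inner dimension) and
    data-parallel groups of ranks holding the same stage (outer dimension)."""
    assert world_size % pp_degree == 0
    dp_degree = world_size // pp_degree
    pipelines = [list(range(k * pp_degree, (k + 1) * pp_degree)) for k in range(dp_degree)]
    # ranks holding the same stage in different replicas AllReduce their gradients
    dp_groups = [list(range(s, world_size, pp_degree)) for s in range(pp_degree)]
    return pipelines, dp_groups

print(dp_pp_groups(4, 2))  # ([[0, 1], [2, 3]], [[0, 2], [1, 3]])
```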

Activation Checkpointing

While we’ve discussed ways of reducing memory demand, you may have spotted another easy target: activations. With GPipe, stages need to cache activations for each microbatch from the start of its forward to the end of its corresponding backward. For an \(\ell\)-layer network (assuming each layer is roughly equal in size) with batchsize \(B\), the peak per-stage memory demand for caching activations is15:

\[ O\left(B \frac{\ell}{d}\right) \]

With activation checkpointing16 (aka gradient checkpointing) [9], we only store boundary activations and recompute the forward for each microbatch when it’s time to do its backward. Boundary activations take \(O(B)\) space and we only need to cache activations for a single microbatch at any given moment (while computing its gradient), reducing peak memory demand to:

\[ O\left(B+\frac{B}{m}\frac{\ell}{d}\right) \]

Why can we get away with recomputing the forward without significantly impacting overall compute efficiency17?

The original GPipe paper [8] claims that, with activation checkpointing, using \(m \geq 4d\) microbatches results in “negligible” bubble overhead.
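Plugging numbers into the two activation-memory expressions makes the saving concrete. The sketch measures memory in units of "one layer’s activations for one sample" and drops constant factors, as the \(O(\cdot)\) expressions do; `activation_memory` is a hypothetical helper and the sizes are arbitrary.

```python
def activation_memory(B, layers, d, m, checkpointing):
    """Peak cached activations per stage, in units of one layer's activations
    for one sample (constant factors dropped)."""
    layers_per_stage = layers / d
    if not checkpointing:
        # every microbatch's activations for every layer of the stage
        return B * layers_per_stage
    # boundary activations for all B samples + full activations of one microbatch
    return B + (B / m) * layers_per_stage

print(activation_memory(32, 32, 4, 8, checkpointing=False))  # 256.0
print(activation_memory(32, 32, 4, 8, checkpointing=True))   # 64.0
```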

Other Schedules

We can further decrease memory demand by reducing the number of “in-flight” microbatches for which we need to cache activations (or checkpoints). If you look at the original GPipe schedule, after completing the forward for the first microbatch, the last stage could instead start the backward pass right away and then discard its activation. PipeDream [11] schedules the last stage backward immediately after the corresponding forward, reducing memory demand compared to GPipe:

1F1B PipeDream-flush schedule, ignoring communication with d=4, m=8. Based on figure from [12].

The schedule consists of three phases:

  1. The warmup where deeper stages are waiting on activations from earlier stages. We limit the number of contiguous microbatches over which we compute a forward pass to the pipeline depth, thus also limiting the number of in-flight microbatches.
  2. The steady state where ranks perform one forward pass followed by one backward pass (known as “1F1B”).
  3. Lastly we flush the pipeline by completing the backward passes for remaining microbatches without scheduling any new ones.
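The three phases can be generated per rank. The sketch below lists the order in which one rank processes microbatches under a non-interleaved 1F1B schedule (warmup of \(d-1-\text{rank}\) forwards, then alternating, then flushing) and counts in-flight microbatches; cross-rank timing and communication are ignored, and all helper names are hypothetical.

```python
def one_f_one_b(rank, d, m):
    """Op order for one rank (0-indexed) of a non-interleaved 1F1B schedule."""
    warmup = min(d - rank - 1, m)
    ops = [("F", j) for j in range(warmup)]  # 1. warmup forwards
    next_f, next_b = warmup, 0
    while next_f < m:  # 2. steady state: one forward, then one backward
        ops.append(("F", next_f))
        next_f += 1
        ops.append(("B", next_b))
        next_b += 1
    while next_b < m:  # 3. flush: drain the remaining backwards
        ops.append(("B", next_b))
        next_b += 1
    return ops

def gpipe(m):
    """All forwards, then all backwards."""
    return [("F", j) for j in range(m)] + [("B", j) for j in range(m)]

def max_in_flight(ops):
    """Largest number of microbatches whose activations are held at once."""
    live = peak = 0
    for kind, _ in ops:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak

print([max_in_flight(one_f_one_b(r, 4, 8)) for r in range(4)])  # [4, 3, 2, 1]
print(max_in_flight(gpipe(8)))                                  # 8
```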

Unlike GPipe, where all microbatches are in flight at some point during the schedule, with PipeDream that number never exceeds the pipeline depth. Since reducing bubble time requires \(m \gg d\), PipeDream can increase \(m\) without increasing its memory footprint. However, for fixed \(m\), the bubble fraction is no different between PipeDream and GPipe (you can see this by shifting all the blue forward passes left). So how can we do better without prohibitively increasing \(m\), and thus the overall batchsize?
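The claim that, for fixed \(m\), 1F1B does not shrink the bubble can be checked with a small event-driven simulation. This pure-Python sketch assumes every forward takes `tf = 1` time unit and every backward `tb = 2` (a common rule of thumb, not a measurement), ignores communication, and lets each rank execute its operation list in order, waiting for the neighbouring rank's activation (forward) or gradient (backward).

```python
def gpipe(rank, d, m):
    """All-forward all-backward order on one rank."""
    return [("F", i) for i in range(m)] + [("B", i) for i in range(m)]

def one_f_one_b(rank, d, m):
    """1F1B order: d - rank - 1 warmup forwards, then alternate, then flush."""
    warmup = min(d - rank - 1, m)
    ops = [("F", i) for i in range(warmup)]
    for i in range(warmup, m):
        ops += [("F", i), ("B", i - warmup)]
    return ops + [("B", i) for i in range(m - warmup, m)]

def makespan(schedule, tf=1, tb=2):
    """Time at which the last operation of a per-rank schedule finishes.

    F(r, i) waits for F(r-1, i); B(r, i) waits for B(r+1, i), or for F(r, i)
    on the last rank. Communication time is ignored."""
    d = len(schedule)
    finish = {}                     # (kind, rank, microbatch) -> finish time
    clock = [0] * d                 # time at which each rank becomes free
    pos = [0] * d                   # next operation index on each rank
    while any(pos[r] < len(schedule[r]) for r in range(d)):
        progressed = False
        for r in range(d):
            if pos[r] == len(schedule[r]):
                continue
            kind, i = schedule[r][pos[r]]
            if kind == "F":
                dep = ("F", r - 1, i) if r > 0 else None
            else:
                dep = ("B", r + 1, i) if r < d - 1 else ("F", r, i)
            if dep is not None and dep not in finish:
                continue            # input not produced yet; retry on the next sweep
            ready = finish[dep] if dep is not None else 0
            clock[r] = max(clock[r], ready) + (tf if kind == "F" else tb)
            finish[(kind, r, i)] = clock[r]
            pos[r] += 1
            progressed = True
        if not progressed:
            raise RuntimeError("schedule deadlocks")
    return max(clock)
```

For \(d = 3\), \(m = 5\) both schedules finish at \(t = 21\) against an ideal \(5 \times 3 = 15\): a bubble fraction of \(6/15 = (d-1)/m\). 1F1B changes when microbatches are in flight, not how long the pipeline takes to fill and drain.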

Rather than each rank only having a single stage (of consecutive layers), we can loop our pipeline by performing the computation for \(v\) (non-consecutive) stages at each rank and connecting the last rank to the first, forming a coil. For example, in a 16-layer network on 4 ranks, if rank 0 previously had 4 layers (layers 0-3), with \(v=2\) loops it would now have 2 stages of 2 layers each (layers 0-1 and 8-9). Interleaved 1F1B [12] is the looped version of the PipeDream schedule:

Interleaved 1F1B (DFS) schedule; W=4 ranks, d=16 stages, v=4 loops, m=8 microbatches. Data-parallel AllReduce illustrated on odd rows, pipeline-parallel communication omitted. Figure from [13].
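The layer-to-rank assignment in the example above follows a simple round-robin rule. In this hypothetical helper (the names are ours), the network is cut into \(W \cdot v\) equal stages and stage \(s\) is placed on rank \(s \bmod W\); it assumes the layer count divides evenly into stages.

```python
def layers_for_rank(rank, world, loops, num_layers):
    """Layer indices held by `rank` in a looped pipeline (one list per local stage)."""
    num_stages = world * loops
    assert num_layers % num_stages == 0, "assume layers divide evenly into stages"
    per_stage = num_layers // num_stages
    # Stage s lives on rank s % world, so a rank's stages are `world` apart.
    return [list(range(s * per_stage, (s + 1) * per_stage))
            for s in range(rank, num_stages, world)]
```

For 16 layers on 4 ranks, `layers_for_rank(0, 4, 1, 16)` gives `[[0, 1, 2, 3]]`, while `layers_for_rank(0, 4, 2, 16)` gives `[[0, 1], [8, 9]]`: activations now leave each rank twice per forward pass, which is where the extra communication comes from.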

Forward and backward passes for a stage are now shorter by a factor of \(v\), so the bubble time at each rank is also reduced by a factor of \(v\), to \((d-1)/v\) (measured in units of one un-looped microbatch forward plus backward). Writing \(d\) for the number of pipeline ranks (\(W\) in the figures), the overall bubble fraction becomes: \[ \text{Bubble}_\text{looped}(d,m,v) = \frac{d-1}{vm} \] This reduced bubble size does not come for free, however: the total communication volume increases by the same factor of \(v\). Another caveat is that Interleaved 1F1B requires \(m\) to be a multiple of \(W\).
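The formula is easy to evaluate directly. This hypothetical helper treats \(d\) as the number of pipeline ranks, reduces to the un-looped \((d-1)/m\) when \(v = 1\), and enforces the divisibility caveat for \(v > 1\).

```python
def bubble_fraction(d, m, v=1):
    """Pipeline bubble (d - 1) / (v * m): idle time relative to useful compute.

    d: pipeline ranks, m: microbatches, v: loops (v=1 is plain GPipe/1F1B)."""
    if v > 1 and m % d != 0:
        raise ValueError("Interleaved 1F1B needs m to be a multiple of the number of ranks")
    return (d - 1) / (v * m)
```

For \(d = 4\), \(m = 8\): plain 1F1B gives `0.375`, while \(v = 2\) gives `0.1875`, the same as doubling \(m\) to 16 but without doubling the batchsize. The price is twice the pipeline communication. By the same formula, GPipe's suggested \(m = 4d\) gives \((d-1)/(4d)\), just under a quarter.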

1F1B schedules like PipeDream and Interleaved 1F1B can be thought of as depth-first: when choosing between computing the same (forward) stage for the next microbatch and the next (backward) stage for the same microbatch, a rank will choose the latter, sending earlier microbatches as deep as possible down the pipeline. “All-forward all-backward” schedules like GPipe are breadth-first: a rank prioritises completing all microbatches in the earliest unfinished stage. Naturally, you might ask what the looped analogue of GPipe looks like. This is the breadth-first pipeline schedule (BFS) [13]:

Breadth-first pipeline schedule; W=4 ranks, d=16 stages, v=4 loops, m=8 microbatches. Data-parallel AllReduce illustrated on odd rows, pipeline-parallel communication omitted. Notice that a complete iteration finishes faster than Interleaved 1F1B. Figure from [13].
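The depth-first/breadth-first distinction shows up clearly in the order in which a single rank visits its local stages. This simplified pure-Python sketch only lists forward passes; it assumes, as in Interleaved 1F1B, that the depth-first schedule cycles through local stages in groups of \(W\) microbatches, and it ignores how DFS interleaves backwards with forwards in its steady state.

```python
def forward_order(local_stages, m, world, breadth_first):
    """Simplified forward-pass order on one rank of a looped pipeline.

    Returns (local_stage, microbatch) pairs; backward passes are omitted."""
    if breadth_first:
        # BFS: push every microbatch through a stage before moving to the next one.
        return [(s, i) for s in range(local_stages) for i in range(m)]
    # DFS: cycle through the local stages in groups of `world` microbatches.
    assert m % world == 0, "Interleaved 1F1B needs m to be a multiple of W"
    return [(s, i)
            for start in range(0, m, world)
            for s in range(local_stages)
            for i in range(start, start + world)]
```

With 2 local stages, 4 microbatches and 2 ranks, BFS visits stage 0 for microbatches 0-3 and then stage 1, whereas DFS alternates: stage 0 for microbatches 0-1, stage 1 for 0-1, then stage 0 for 2-3, and so on.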

The bubble fraction is exactly the same as for Interleaved 1F1B (“DFS” from now on), although, as with GPipe, peak memory consumption is higher because all microbatches are in flight at some point. In return, BFS achieves much better communication overlap.

More importantly, BFS can make use of FSDP. In a non-looping pipeline, a rank only has a single stage. To achieve any memory savings, FSDP units would have to be smaller than a stage, thus requiring a complete FSDP iteration for each microbatch. With a looping pipeline, ranks hold several stages, each of which can be used as an FSDP unit: we can keep a stage unsharded and accumulate gradients over consecutive microbatches, only resharding/reducing at the end of the microbatch sequence. A BFS stage completes the forward (or backward) pass for all microbatches in one go, avoiding the repeated FSDP iterations that arise from alternating between stages in DFS.
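A rough way to see the saving is to count how often a rank must re-gather a stage's parameters. This sketch assumes a hypothetical policy in which a rank keeps only one local stage unsharded at a time, so every switch to a different stage costs one AllGather in the forward pass (the gradient ReduceScatters in the backward pass follow the same pattern). The DFS order assumes stages are visited in groups of \(W\) microbatches, as in Interleaved 1F1B.

```python
def count_all_gathers(order):
    """Parameter AllGathers on one rank for a (local_stage, microbatch) order.

    Hypothetical policy: only one local stage is unsharded at a time, so every
    switch to a different stage re-gathers that stage's parameters."""
    gathers, current = 0, None
    for stage, _microbatch in order:
        if stage != current:
            gathers += 1
            current = stage
    return gathers

v, m, W = 4, 8, 4   # local stages per rank, microbatches, ranks (as in the figures)
bfs = [(s, i) for s in range(v) for i in range(m)]
dfs = [(s, i) for g in range(0, m, W) for s in range(v) for i in range(g, g + W)]
```

Here BFS needs 4 AllGathers per forward pass (one per local stage) while DFS needs 8; in general DFS needs \(v \cdot m / W\), so the gap grows with the number of microbatches.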

The benefits of combining BFS and FSDP are two-fold: first, FSDP compensates for BFS’ larger memory footprint by sharding stages; second, BFS reduces the size of FSDP sharding groups (and their expensive collectives) while maintaining the same overall degree of parallelism.

Like DDP, FSDP is typically the outermost parallelism when combined with pipelining. Outer dimensions may spread across a multi-hop network with higher communication latency and lower bandwidth; FSDP has fewer, larger asynchronous collectives that can better absorb these communication delays.
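“Outermost” has a concrete meaning in terms of rank layout. This pure-Python sketch assumes ranks are arranged row-major in a 2D mesh of shape (FSDP, PP), and that consecutive rank numbers tend to share a node or nearby switch (common with typical launchers, but an assumption). Pipeline groups are then contiguous blocks of ranks, while each FSDP group strides across them, potentially spanning the slower, multi-hop part of the network.

```python
def mesh_groups(world_size, pp_size):
    """Split ranks into pipeline (inner) and FSDP (outer) groups.

    Row-major layout: rank = fsdp_index * pp_size + pp_index."""
    assert world_size % pp_size == 0
    fsdp_size = world_size // pp_size
    # Inner dimension: contiguous ranks form one pipeline.
    pp_groups = [list(range(f * pp_size, (f + 1) * pp_size)) for f in range(fsdp_size)]
    # Outer dimension: ranks holding the same pipeline stage shard it together.
    fsdp_groups = [list(range(p, world_size, pp_size)) for p in range(pp_size)]
    return pp_groups, fsdp_groups
```

With 8 ranks and 4 pipeline stages, the pipelines are `[0, 1, 2, 3]` and `[4, 5, 6, 7]`, and each FSDP group pairs ranks 4 apart (`[0, 4]`, `[1, 5]`, ...). PyTorch's `torch.distributed.device_mesh.init_device_mesh` also lays ranks out row-major, so declaring the FSDP dimension first should give the same grouping, but check this against your PyTorch version.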

There are a lot of other pipeline schedules we haven’t covered here; a good reference is the PyTorch documentation. More advanced scheduling strategies, such as those used to train Llama 3 [1], include hybrid schedules that combine the memory savings of DFS with the communication efficiency of BFS.

PP in PyTorch

[WIP]

Tensor Parallel (TP)

[Further Reading]

Tensor parallelism is similar to FSDP: we split our model horizontally to reduce memory footprint and increase our degree of parallelism. However, TP shards are much more granular, with splits within a single layer rather than across units of several layers.

The computational bottleneck for most modern models (such as the transformer) is the general matrix multiply (GEMM): multiplying an activation batch matrix \(X\) with a large weight matrix \(A\). The example we used in Section 1, a multi-layer perceptron (MLP), consists of a GEMM followed by a pointwise nonlinear activation function \(\sigma(\cdot)\): \[ Y = \sigma(XA) \] One way to parallelise the GEMM would be to split the weight matrix \(A\) along its rows, and the input \(X\) along its columns: \[ X = \begin{bmatrix} X_1 & X_2 \end{bmatrix}, A= \begin{bmatrix} A_1 \\ A_2 \end{bmatrix} \] Each entry of \(XA\) is a dot product between a row of \(X\) and a column of \(A\); splitting both matrices along their shared inner dimension splits every such dot product into two partial sums. It’s therefore possible to compute the partial products on different ranks and sum up the results: \[ Y=X_1A_1+X_2A_2 \] However, because \(\sigma\) is nonlinear, in general \(\sigma(X_1 A_1 + X_2A_2) \neq \sigma(X_1 A_1) + \sigma(X_2 A_2)\). This approach would therefore require a synchronisation point before the activation function. If instead we only split the weights along their columns, \(A=\begin{bmatrix} A_1 & A_2 \end{bmatrix}\), we end up concatenating rather than adding the outputs: \[ Y=\begin{bmatrix} Y_1 & Y_2 \end{bmatrix}=\begin{bmatrix} \sigma(X A_1) & \sigma(X A_2) \end{bmatrix} \] By removing the pre-activation synchronisation, we can stack a second MLP layer before we have to synchronise again: \[ Z = \sigma_1(YB)=\sigma_1(Y_1B_1 + Y_2B_2) \] The second layer’s weights are split row-wise, \(B= \begin{bmatrix} B_1 \\ B_2 \end{bmatrix}\), and we synchronise before the second activation function \(\sigma_1\) (as in our original attempt). We end up sharding our two MLP layers with a single AllReduce synchronisation before the final activation:

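Tensor parallelism over W=2 ranks applied to two consecutive MLP layers. Here \(\sigma_0\) and \(\sigma_1\) are GeLU and dropout functions respectively. f is the identity in the forward pass (and an AllReduce in the backward pass); g is an AllReduce in the forward pass (and the identity in the backward pass). Figure from [14].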
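The column-then-row split can be checked numerically. This pure-Python sketch (plain nested lists, no tensor library) computes \(\text{GeLU}(XA)B\) once on a single “rank” and once sharded over two hypothetical ranks, where \(A\) is split by columns and \(B\) by rows; summing the two partial outputs plays the role of g. The final \(\sigma_1\) (dropout) is omitted, and the matrices are arbitrary illustrative values.

```python
import math

def matmul(X, A):
    """Plain nested-list matrix product X @ A."""
    return [[sum(x * a for x, a in zip(row, col)) for col in zip(*A)] for row in X]

def gelu(v):
    """tanh approximation of GeLU."""
    return 0.5 * v * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * v ** 3)))

def elementwise(f, M):
    return [[f(v) for v in row] for row in M]

def add(M, N):
    return [[a + b for a, b in zip(r, s)] for r, s in zip(M, N)]

def mlp(X, A, B):
    """Two MLP layers without the final sigma_1: gelu(X A) B."""
    return matmul(elementwise(gelu, matmul(X, A)), B)

def mlp_tensor_parallel(X, A, B):
    """Same computation sharded over two hypothetical ranks."""
    k = len(A[0]) // 2
    A_shards = [[row[:k] for row in A], [row[k:] for row in A]]  # column split of A
    B_shards = [B[:k], B[k:]]                                    # row split of B
    # Each rank computes Y_r B_r independently: no communication needed (f).
    partials = [mlp(X, A_r, B_r) for A_r, B_r in zip(A_shards, B_shards)]
    # g: the AllReduce sums the partial outputs.
    return add(partials[0], partials[1])

X = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
A = [[0.1, -0.2, 0.3, 0.4], [0.5, 0.6, -0.7, 0.8], [-0.9, 1.0, 0.2, -0.3]]
B = [[1.0, -0.5], [0.25, 0.75], [-1.25, 0.5], [0.3, -0.2]]
```

Both versions agree up to floating-point rounding. Moving GeLU after the sum (the row-split first attempt) would not, since `gelu(1.0 + 2.0)` differs from `gelu(1.0) + gelu(2.0)`.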

Because TP requires an AllReduce per layer (or every two layers), it is almost exclusively used intranode. You can think of it as agglomerating the ranks in a node into a single, much larger rank (over which we can apply other parallelisms).
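A back-of-the-envelope estimate shows why. This hypothetical helper assumes a ring AllReduce, in which each rank sends roughly \(2(W-1)/W\) times the message size (NCCL may choose other algorithms), and an activation message of shape tokens × hidden in a 2-byte format such as bf16. The sizes below are illustrative, not taken from any particular model.

```python
def allreduce_bytes_per_rank(tokens, hidden, world, bytes_per_elem=2):
    """Bytes each rank sends in one ring AllReduce of a (tokens x hidden) activation."""
    message = tokens * hidden * bytes_per_elem
    return 2 * (world - 1) * message // world
```

With 8192 tokens, a hidden size of 8192 and \(W = 8\), each AllReduce moves 224 MiB per rank, and TP needs one or two of these per layer in both the forward and backward passes. That volume is only practical over a node's high-bandwidth interconnect (e.g. NVLink).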

TP in PyTorch

[WIP]

Context Parallel (CP)

[WIP]

CP in PyTorch

4. Parallelism in Practice

[WIP] Large models use FSDP, gigantic models use all of the above.

1. AI @ Meta Llama Team, The Llama 3 herd of models. (2024).
2. D. P. Kingma & J. Ba, Adam: A method for stochastic optimization. (2017).
3. P. Micikevicius, S. Narang, J. Alben, G. Diamos, E. Elsen, D. Garcia, B. Ginsburg, M. Houston, O. Kuchaiev, G. Venkatesh, & H. Wu, Mixed precision training. (2018).
4. S. McCandlish, J. Kaplan, D. Amodei, & OpenAI Dota Team, An empirical model of large-batch training. (2018).
5. S. Li, Y. Zhao, R. Varma, O. Salpekar, P. Noordhuis, T. Li, A. Paszke, J. Smith, B. Vaughan, P. Damania, & S. Chintala, PyTorch distributed: Experiences on accelerating data parallel training. (2020).
6. S. Rajbhandari, J. Rasley, O. Ruwase, & Y. He, ZeRO: Memory optimizations toward training trillion parameter models. (2020).
7. Y. Zhao, A. Gu, R. Varma, L. Luo, C.-C. Huang, M. Xu, L. Wright, H. Shojanazeri, M. Ott, S. Shleifer, A. Desmaison, C. Balioglu, P. Damania, B. Nguyen, G. Chauhan, Y. Hao, A. Mathews, & S. Li, PyTorch FSDP: Experiences on scaling fully sharded data parallel. (2023).
8. Y. Huang, Y. Cheng, A. Bapna, O. Firat, M. X. Chen, D. Chen, H. Lee, J. Ngiam, Q. V. Le, Y. Wu, & Z. Chen, GPipe: Efficient training of giant neural networks using pipeline parallelism. (2019).
9. T. Chen, B. Xu, C. Zhang, & C. Guestrin, Training deep nets with sublinear memory cost. (2016).
10. V. A. Korthikanti, J. Casper, S. Lym, L. McAfee, M. Andersch, M. Shoeybi, & B. Catanzaro, Reducing activation recomputation in large transformer models. In D. Song, M. Carbin, & T. Chen, eds., Proceedings of machine learning and systems (Curran, 2023), pp. 341–353.
11. A. Harlap, D. Narayanan, A. Phanishayee, V. Seshadri, N. Devanur, G. Ganger, & P. Gibbons, PipeDream: Fast and efficient pipeline parallel DNN training. (2018).
12. D. Narayanan, M. Shoeybi, J. Casper, P. LeGresley, M. Patwary, V. A. Korthikanti, D. Vainbrand, P. Kashinkunti, J. Bernauer, B. Catanzaro, A. Phanishayee, & M. Zaharia, Efficient large-scale language model training on GPU clusters using Megatron-LM. (2021).
13. J. Lamy-Poirier, Breadth-first pipeline parallelism. (2023).
14. M. Shoeybi, M. Patwary, R. Puri, P. LeGresley, J. Casper, & B. Catanzaro, Megatron-LM: Training multi-billion parameter language models using model parallelism. (2020).

  1. In particular we’re referring to the pre-training phase for LLMs.↩︎

  2. In practice an epoch will loop over an entire training set consisting of several batches (each with their own parameter updates), potentially followed by evaluation on separate validation batches.↩︎

  3. Some tensors in a module don’t have gradients, for example fixed transforms with static parameters.↩︎

  4. FP16 weights and gradients + FP32 master copy of weights + FP32 momentum and variance.↩︎

  5. Floating point addition is not associative, but in practice the difference is small enough to be safely ignored.↩︎

  6. Except for batch-wise operations like BatchNorm, which won’t be locally consistent (unless we use their expensive synchronised implementations like SyncBatchNorm).↩︎

  7. A batch is used to approximate the gradient over the entire training set; for large batches the approximation is already very good, and further increasing batchsize provides negligible improvement (i.e. no longer reducing number of training steps), wasting compute. Moreover, large batchsizes can harm out-of-sample performance by reducing stochasticity.↩︎

  8. Bucket size is user-configurable. Larger buckets lower communication overhead but reduce overlap with compute. Buckets are allocated heuristically during model construction, in the reverse order of model.parameters().↩︎

  9. Note that because the optimizer step will only operate on the sharded parameters, any optimizer that depends on global state over all parameters won’t be locally consistent.↩︎

  10. The actual class is distributed.fsdp.FullyShardedDataParallel. Note that the optimizer should be initialised afterwards, using the sharded module.↩︎

  11. There’s also SHARD_GRAD_OP, which keeps parameters unsharded during the entire forward-backward computation.↩︎

  12. Before forward computation, FSDP replaces the original parameters with views into their unsharded FlatParameter so that autograd behaves correctly. Keeping the original parameters registered requires using the recently added use_orig_params flag.↩︎

  13. e.g. for a transformer model we’ll usually wrap each transformer block, with a final wrapper around the root module sharding the initial embedding and final linear layers.↩︎

  14. We can’t get around this with an extra stream: PyTorch only uses one internal NCCL stream for a given process group.↩︎

  15. A stage has \(\frac{\ell}{d}\) layers, each of which caches \(O(B)\) of activations.↩︎

  16. Not to be confused with model checkpointing, where we periodically save the entire model to disk, usually at the end of an epoch.↩︎

  17. Recomputing activations still slows down training; instead, the state of the art is to selectively recompute only some activations [10].↩︎

References

Citation

For attribution, please cite this work as

Mauger (2025, Jan. 1). Bruce Mauger: Distributed Machine Learning: How Does it Work?. Retrieved from brrm.io/posts/distributed-ml/

BibTeX citation

@misc{mauger2025distributed,
  author = {Mauger, Bruce},
  title = {Bruce Mauger: Distributed Machine Learning: How Does it Work?},
  url = {brrm.io/posts/distributed-ml/},
  year = {2025}
}