Spaces:
Running
Running
add TP
Browse files- blog-export.md +161 -135
- dist/bibliography.bib +9 -0
- dist/index.html +355 -5
- src/bibliography.bib +9 -0
- src/index.html +355 -5
blog-export.md
CHANGED
@@ -28,7 +28,7 @@ An overview of the over 4000 experiments across all Llama architectures where ea
|
|
28 |
|
29 |
As you can see, there’s a lot of ground to be covered. Before getting into the trenches of distributed training let’s take a quick high level look on we’ll cover in the post.
|
30 |
|
31 |
-
# TL;DR
|
32 |
|
33 |
This book is very extensive so we decide to start with a very general overview of how you can think about distributed training. At a high level, the key challenge in scaling LLM training is to make a training step (forward/backward/optimizer step) with a large batch size the fastest possible.
|
34 |
|
@@ -55,7 +55,7 @@ But let’s not get too much ahead of our self and scale progressively. To guide
|
|
55 |
|
56 |
Now that we nailed a few key concept and terms let’s get started by revisiting the basic training steps of an LLM!
|
57 |
|
58 |
-
# First Steps: Training on one GPU
|
59 |
|
60 |
Let’s start by quickly reviewing the very basics of model training before we start to scale to many GPUs. When a model is trained on a single GPU, the training typically consists of three steps:
|
61 |
|
@@ -101,7 +101,7 @@ A sweet spot for recent LLM training is typically on the order of 4-60 million t
|
|
101 |
|
102 |
Let’s start by quickly understanding what led to our out-of-memory issue in the first place. This will help us gain some useful intuitions for later.
|
103 |
|
104 |
-
## Memory usage in Transformers
|
105 |
|
106 |
When training a neural network model, one store several items in memory:
|
107 |
|
@@ -143,13 +143,13 @@ For a simple transformer LLM the number of parameters is given by the [following
|
|
143 |
|
144 |
$$
|
145 |
|
146 |
-
N = h*v + L * (12 * h^2 + 13*h) + 2*h
|
147 |
$$
|
148 |
|
149 |
> Note: we excluded the positional embedding count as rotary embeddings are not learned.
|
150 |
>
|
151 |
|
152 |
-
In that equation, $h$ is the hidden dimension, $v$ the vocabulary size, and $L$ the number of layers in the model.
|
153 |
|
154 |
Memory requirements for the parameters and gradients are simply the number of parameters multiplied by the number of bytes per parameter. In good old-fashioned full precision (FP32) training both parameters and gradients require 4 bytes while the optimizer, if we use Adam, requires the momentum and variance to be stored, which adds another two 4 bytes per parameter. In summary:
|
155 |
|
@@ -173,7 +173,7 @@ m_{params\_fp32} = 4 * N \\
|
|
173 |
m_{opt} = (4+4) * N
|
174 |
$$
|
175 |
|
176 |
-
> Some
|
177 |
>
|
178 |
|
179 |
Interestingly, mixed precision itself doesn’t save overall memory as it just distributes the memory differently across the three components, and in fact adds another 4 bytes over full precision training if we accumulate gradients in FP32. It’s still advantageous as having the model which does the forward/backward in half precision it allows us to (1) use optimized lower precision operations on the GPU which are faster and (2) reduces the activation memory requirements during the forward pass.
|
@@ -203,7 +203,7 @@ m_{act} = L* seq * bs * h * (34 + \frac{5*n_{heads}*seq}{h})
|
|
203 |
|
204 |
$$
|
205 |
|
206 |
-
Here L is the number of layers, $seq$ the sequence length, $bs$ the batch size in samples, $h$ the hidden dimension of the model and $n_{heads}$ the number of heads.
|
207 |
|
208 |
For the exact numbers derivation, you can follow this [NVIDIA pape](https://arxiv.org/pdf/2205.05198)r on recomputation, it essentially requires you to do some accounting of all the sizes of intermediate activations between each operation.
|
209 |
|
@@ -219,7 +219,7 @@ Is there a way to tame this “activation explosion”? Good question, reader!
|
|
219 |
|
220 |
It’s time to explain our first technique – called ***activation recomputation**–* ****which will help us cap activation memory footprint. An essential tool in today’s large model training toolbox.
|
221 |
|
222 |
-
## **Activation recomputation**
|
223 |
|
224 |
The general idea behind ***activation recomputation** –*also called ***gradient checkpointing*** or ***rematerialization**– *****is to discard some activations during the forward pass to save memory and spend some extra compute to recompute these on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g. FF, LayerNorm etc.), such that we can use them during the backward pass to compute gradients. When we use recomputation we typically will only store activations at a few key points along the model architecture, discard the rest of activations and recompute them on the fly during the backward pass from the nearest saved activations, basically performing again a sub-part of the forward pass to trade of memory for compute. It generally looks like this:
|
225 |
|
@@ -252,7 +252,7 @@ Now that we’ve learned about recomputation, we can tame the activations memory
|
|
252 |
|
253 |
However, activations still bears a linear dependance on the batch size and all our profiles in the barplots above were using `bs=1` so as we move to larger batch sizes it might become an issue again. Do not despair as we have a second tool in our box - ***gradient accumulation*** to the rescue!
|
254 |
|
255 |
-
## Gradient accumulation
|
256 |
|
257 |
Now that we’ve used activation recomputation to fit our model with a small batch size on a single GPU, we still need to reach our target batch size, let’s say 1M tokens (see our earlier discussion on optimal batch size). Gradient accumulation is a very straightforward method to avoid memory explosion when doing this.
|
258 |
|
@@ -281,9 +281,9 @@ But if you’ve carefully followed, you probably noticed that the forward/backwa
|
|
281 |
|
282 |
Let’s get a larger workstation 🖥️ with a couple of GPUs and start investigating our first scaling technique called ***data parallelism** which is just a parallel version of gradient accumulation*.
|
283 |
|
284 |
-
TODO: intro for
|
285 |
|
286 |
-
## torch.profiler
|
287 |
|
288 |

|
289 |
|
@@ -297,7 +297,7 @@ In this naive approach we see a long AllReduce operation (stream 28) happening t
|
|
297 |
|
298 |
**Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done**
|
299 |
|
300 |
-
# Data Parallelism
|
301 |
|
302 |
The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replica's “model instances”) and run forward and backward passes on different micro batches of data in parallel for each GPU, hence the name Data Parallelism.
|
303 |
|
@@ -310,11 +310,6 @@ This involves our first “distributed communication” primitive: [**All-Reduce
|
|
310 |
> If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in the Appendix [TODO Link].
|
311 |
>
|
312 |
|
313 |
-
TODO: bucket grads to avoid multiple comms
|
314 |
-
TODO: show comms overlap
|
315 |
-
|
316 |
-
TODO: any comms requires at least a contiguous buffer to do comms → TIP: make sure tensors that’ll be communicated are contiguous in memory to avoid redundant memory copies
|
317 |
-
|
318 |
TODO: embed naive DP: [https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60](https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60)
|
319 |
|
320 |
TODO: embed bucket DP: [https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171](https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171)
|
@@ -327,7 +322,7 @@ Instead we should try to overlap communication and computation whenever possible
|
|
327 |
|
328 |
Let’s see three optimizations that are done in practice for this!
|
329 |
|
330 |
-
### **First optimization:** Overlap gradient synchronization with backward pass
|
331 |
|
332 |
The main drawback of the naive DDP approach we’ve just described is that after the backward pass (*computation*), we have to wait for gradient synchronization (*communication*) before updating the parameters. Could we overlap this communication with our computation? The answer is yes!
|
333 |
|
@@ -352,7 +347,7 @@ Overlapping computation and communication reduces the time spent waiting for gra
|
|
352 |
|
353 |
This is our first example of “*overlapping computation and communication*” which we will discuss several times in this blog post and is an essential technique to maximal scaling efficiency.
|
354 |
|
355 |
-
### **Second optimization:** Bucketing gradients
|
356 |
|
357 |
But we can even go further. For a given number of parameters to synchronize, GPU operations like collective communications are often more efficient when performing few calls on large tensors rather than many calls on smaller tensors. Therefore, instead of performing independent all-reduce for each gradient, we can group gradients into buckets and launch a single all-reduce for all the gradients within the same bucket. Think of it like packing items into boxes before shipping—it's more efficient to send a few big boxes than many small ones. By performing a single all-reduce operation for each bucket, we can significantly reduce communication overhead and speed up the communication operation.
|
358 |
|
@@ -362,7 +357,7 @@ The selected bucket size will be a key factor in determining the efficiency of D
|
|
362 |
|
363 |
[TODO: benchmark all reduce with different size / bucket size results ?]
|
364 |
|
365 |
-
### **Third optimization: I**nterplay with gradient accumulation
|
366 |
|
367 |
As we’ve seen before, gradient accumulation works by performing multiple forward and backward passes before updating the parameters with `optimizer.step()`. When combining gradient accumulation with data parallelism, we should be careful when we want to synchronize gradients.
|
368 |
|
@@ -370,7 +365,11 @@ In a naive version, an all-reduce operation is automatically triggered after eac
|
|
370 |
|
371 |
In PyTorch, this is typically solved by adding a [`model.no_sync()`](https://github.com/pytorch/pytorch/blob/5ea67778619c31b13644914deef709199052ee55/torch/nn/parallel/distributed.py#L1408-L1435) decorator, which disables gradient synchronization, on the backward passes which don’t need reduction.
|
372 |
|
373 |
-
|
|
|
|
|
|
|
|
|
374 |
|
375 |
Let’s update our batch size equation with our newly learned Data Parallelism and Gradient Accumulation parameters:
|
376 |
|
@@ -387,7 +386,7 @@ Given a targeted global batch size, we can thus trade gradient accumulation step
|
|
387 |
|
388 |
Being able to distribute the training over different samples gives us a first dimension of parallelization, thus making this 1D parallelism (we’ll progressively cover 3 more dimensions).
|
389 |
|
390 |
-
## Our journey up to now
|
391 |
|
392 |
Let’s quickly summarize what we’ve seen up to now and how to setup our first 1D parallel training with a draft recipe for an optimal data-parallel setup:
|
393 |
|
@@ -406,26 +405,31 @@ If the gradient accumulation ratio is lower than one, i.e. we have too many GPUs
|
|
406 |
|
407 |
Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k. This means our batch size will be 1024 samples (we pick powers of two). We observe that a single GPU can only fit MBS=2 in memory and we have 128 GPUs available for training. This means with 4 gradient accumulation steps we’ll achieve our goal of 1024 samples or 4M tokens per training step. Now what if we suddenly have 512 GPUs available? We can achieve the same GBS and thus identical training by keeping MBS=2 and setting gradient accumulation steps to 1 and achieve faster training!
|
408 |
|
409 |
-
> Bear in mind that at the 512GPUs scale, depending on the network used, the communication operations will start to be bound by ring latency which means we can no longer fully overlap the DP communications. This will decrease our compute efficiency and hit our throughput. In this case we should start exploring other dimensions to parallelize on.
|
410 |
>
|
411 |
|
412 |
TODO: We’re gaining overall throughput but losing efficiency as we scale DP too much
|
413 |
|
414 |

|
415 |
|
416 |
-
|
|
|
|
|
417 |
|
418 |
-
|
419 |
|
420 |
-
This is not always the case! As we
|
421 |
|
422 |

|
423 |
|
|
|
|
|
|
|
424 |
Do we have other options for these larger models? We do have some solutions thankfully. They will involve either move some of these tensors to the CPU or split the weights/gradients/optimizer-states tensors across GPUs devices!
|
425 |
|
426 |
There are two main approaches to splitting: parallelism (tensor, context, or pipeline parallelism) and sharing (DeepSpeed Zero or PyTorch FSDP). Both approaches are somewhat orthogonal and can actually be combined! The sharing paradigm is closely related to DP so we’ll have a look at it first by investigating the ZeRO method!
|
427 |
|
428 |
-
## ZeRO (**Ze**ro **R**edundancy **O**ptimizer)
|
429 |
|
430 |
In this section we will introduce DeepSpeed ZeRO (**Ze**ro **R**edundancy **O**ptimizer), a memory optimization technology designed to reduce memory redundancies in LLM training.
|
431 |
|
@@ -447,7 +451,7 @@ ZeRO-3 (also called FSDP for “Fully-Sharded Data Parallelism”): optimizer st
|
|
447 |
|
448 |
Let’s have a closer look how much we can save with the partitioning of each ZeRO stage!
|
449 |
|
450 |
-
### Memory usage revisited
|
451 |
|
452 |
Let’s first recap the memory usage of optimizer states, gradients, and parameters during a standard training. Let’s define the number of our model's parameters as $\Psi$ (previously N but here we use the original ZeRO notation). In mixed-precision training with the Adam optimizer, the memory usage for each item we need to store is:
|
453 |
|
@@ -466,7 +470,7 @@ Memory consumption of DP and three stages of Zero-DP. $\Psi$ denotes number of p
|
|
466 |
|
467 |
Let’s explain this graph and it’s values by exploring how each ZeRO stage works. We’ll start with ZeRO-1.
|
468 |
|
469 |
-
### ZeRO-1: Partitioning Optimizer States
|
470 |
|
471 |
In vanilla DP, all ranks gather the same gradients after the backward pass and simultaneously perform identical optimizer steps. This seems like a lot of duplicated work. Can we avoid it and reduce memory usage at the same time?
|
472 |
|
@@ -478,24 +482,26 @@ This explains the memory formula of $2\Psi + 2\Psi + \frac{k\Psi}{N_d}$ that we
|
|
478 |
|
479 |
- Forward pass with all bf16 parameters (but different microbatches across DP ranks)
|
480 |
- Backward pass with all gradients (but different microbatches across DP ranks)
|
481 |
-
- Perform an reduce
|
482 |
- Each replica perform an optimizer step (has only 1/$N_d$ optimizer states) updates only on 1/$N_d$ of fp32 parameters, and then 1/$N_d$ of bf16 parameters
|
483 |
- [New operation in ZeRO, not in vanilla DP] Perform an all-gather of bf16 parameters to send missing slices back to each replica
|
484 |
|
485 |

|
486 |
|
|
|
|
|
487 |
If you've been following along, you'll recall from vanilla DP that we can overlap the all-reduce gradient communication with the backward pass computation. In ZeRO-1, we can also investigate how to efficiently overlap the newly added all-gather of bf16 parameters. There are two main strategies for this:
|
488 |
|
489 |
-
1.
|
490 |
-
2.
|
491 |
|
492 |
But unfortunately these techniques are not as evident to implement as they seem and require sophisticated use of hooks / bucketing. In practice we can just use Zero3 / FSDP implementation where the FSDPUnit is the entire model, more details about this later..
|
493 |
|
494 |
-
### ZeRO-2: Adding **Gradient Partitioning**
|
495 |
|
496 |
In ZeRO-1 the optimizer states have been partitioned, which means that each replica only updates $\frac{1}{N_d}$ of the optimizer states. The keen reader must have noticed that there is no real need to have all gradients on all DP ranks in the first place since only a subset is needed for the optimization step.
|
497 |
|
498 |
-
→ During the backward pass, instead of performing an all-reduce over the gradients, we
|
499 |
|
500 |
> In case of FP32 gradient accumulation, we only need to keep $\frac{1}{N_d}$ fp32_grads where we accumulate the bf16 grads coming from the reduce-scatter. And in the optimizer step we use the $\frac{1}{N_d}$ fp32_grads.
|
501 |
>
|
@@ -504,14 +510,14 @@ In ZeRO-1 the optimizer states have been partitioned, which means that each repl
|
|
504 |
|
505 |
It’s easy to see now that sharding the gradients leads to to $2\Psi + \frac{2\Psi+k\Psi}{N_d}$ and as $N_d$ is increased we can save up to 8x memory over the baseline. In terms of communication the same process applies as for ZeRO-1, with the only difference that we communicate and release on the fly. In total, ZeRO-2 is thus also equivalent to vanilla DP training w.r.t. communication.
|
506 |
|
507 |
-
 \cdot peak_{flops}}{2 \cdot seq \cdot mbs \cdot peak_{bw}}
|
537 |
$$
|
538 |
|
|
|
|
|
|
|
|
|
|
|
539 |
Overall it may sound like we significantly increase communication overhead, but thanks to **prefetching** we can start all-gathering weights for Layer n+1 while we do the current forward for Layer n which usually overlaps communication and computation as long as we don’t scale DP too much (as a rule of thumb: DP<512).
|
540 |
|
541 |
In terms of memory we can see that our equation now reached it’s final form of $\frac{2\Psi +2\Psi+k\Psi}{N_d}$ which means we can drive memory usage down indefinitely if we can increase the DP rank, at least for the model related parameters. Notice how it doesn’t specifically help with the intermediate activations that we discussed in the previous chapter. ZeRO is an orthogonal technique to the activation checkpointing and gradient accumulation we discussed in other chapters.
|
@@ -547,7 +558,7 @@ In terms of memory we can see that our equation now reached it’s final form of
|
|
547 |
|
548 |
However, there is a limit here, DP only works if a layer of the model fits in a single GPU and ZeRO can only reduce the parameters, gradients, and optimizer states, but not the activation memory! Recall from the activation memory discussion that it scales with sequence length and batch size. Naturally we could just limit those, but in practice we don’t want to be limited by hardware to train with e.g. short sequence length.
|
549 |
|
550 |
-
 multiplying each colu
|
|
582 |
|
583 |
In practice a small example of the operation looks like this:
|
584 |
|
585 |
-
: We'll copy the complete input matrices to each worker, requiring an operation called [***broadcast***](https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21), and split the weight matrix into columns. The inputs are then multiplied with the partial weight matrices, and the results are finally combined using an [***all-gather](https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21)*** operation*.*
|
590 |
|
591 |
-
: As the attentive reader might guess, row-linear means that we split the weight matrix into chunks of rows. However, this also requires us to split the inputs, which needs a ***scatter*** operation rather than a broadcast as used in column-linear sharding. The results on each worker are already in the right shape but need to be summed for the final result, thus requiring an all-reduce operation in this scenario.
|
594 |
|
595 |
We see here our fourth distributed primitive: ***s[catter](https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21)***!
|
596 |
|
597 |
-
.
|
608 |
|
@@ -610,25 +621,29 @@ We can generally follow a similar approach where Q, K, and V matrices are split
|
|
610 |
|
611 |
It's also worth noting that the tensor parallelism degree should not exceed the number of Q/K/V heads because we need intact heads per TP rank. And in case we’re using GQA, TP degree should be below number of K/V heads, otherwise it requires additional comms to keep them in sync. For instance, LLaMA-3 8B has 8 Key/Value heads, so the tensor parallelism degree should be less than or equal to 8, otherwise if TP=16 for example, we need to duplicate each K/V head and make sure they stay in sync.
|
612 |
|
613 |
-
, we can better understand the tradeoffs involved. In the forward of each decoder layer, we hit a synchronization point with the AllReduce operation that cannot be overlapped with computation. This *exposed communication* overhead is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied.
|
620 |
|
621 |
-
|
|
|
|
|
622 |
|
623 |
Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training.
|
624 |
|
625 |
In practice, the communication overhead of tensor parallelism becomes particularly noticeable as we scale beyond 8 GPUs. While tensor parallelism within a single node can leverage fast NVLink interconnects, going across nodes requires slower network connections. As shown in the throughput plot above, we observe significant drops when moving from TP=8 to TP=16, and an even steeper decline from TP=16 to TP=32. This illustrates how communication costs can dominate at higher degrees of parallelism.
|
626 |
|
627 |
-
However, tensor parallelism provides important benefits for memory usage by distributing model parameters, gradients
|
628 |
|
629 |
-
, we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.
|
683 |
|
@@ -733,41 +748,35 @@ And for the embedding layer
|
|
733 |
s: unchanged | h: full (weight_out is full + **reduce-scatter** for correctness)
|
734 |
s: **reduce-scatter** to sharded |
|
735 |
|
736 |
-
Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward of a vanilla TP we had two all-reduce per transformer block, and in SP we have two all-gather and two reduce-scatter per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into to an all-gather + reduce-scatter (see in [TODO: Appendix link]) they’re actually equivalent in terms of communication. Same reasoning for backward as we just use the conjugate of each operation (no-op ↔ allreduce and allgather ↔ reducescatter).
|
737 |
-
|
738 |
You can find an example of implementation of both column and row linear TP in picotron:
|
739 |
[https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py](https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py)
|
740 |
|
741 |
-
|
742 |
-
|
743 |
-
If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops **IN EACH LAYER** (2 for Attention and 2 for MLP), as shown here for the MLP region:
|
744 |
|
745 |

|
746 |
|
747 |
-
|
748 |
|
749 |
-
|
750 |
-
[https://github.com/huggingface/nanotron/blob/9055c664c28a3b430b4e53bfcb5a074068c90f2a/src/nanotron/parallel/tensor_parallel/functional.py#L169-L262](https://github.com/huggingface/nanotron/blob/9055c664c28a3b430b4e53bfcb5a074068c90f2a/src/nanotron/parallel/tensor_parallel/functional.py#L169-L262)
|
751 |
-
and you can find more tricks [here](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487).
|
752 |
-
>
|
753 |
|
754 |
-
|
755 |
|
756 |
-
|
757 |
|
758 |
-
|
759 |
|
760 |
-
-
|
761 |
|
762 |
-
|
763 |
|
764 |
-
|
|
|
765 |
|
766 |
-
|
767 |
|
768 |
-
 on model performance and memory utilization: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.
|
771 |
|
772 |
Let’s summarize our observations:
|
773 |
|
@@ -775,8 +784,6 @@ Let’s summarize our observations:
|
|
775 |
- the memory savings in activations when using TP with SP helps us fit far bigger batches than TP alone
|
776 |
- the Torch memory fragmentation makes it hard for us to predict the exact peak reserved memory. For more details check memory_viz tool section. [TODO: add link]
|
777 |
|
778 |
-
TODO (outro): TP can help sharding activs (sometimes on hidden_dim, sometimes on seq_dim) by sharding the big linears across ranks, but what if we want to scale sequence_length, our activs will still blow up in TP region. → Context parallelism
|
779 |
-
|
780 |
**We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.**
|
781 |
|
782 |
However, there are two limits to TP and SP: 1) if we scale the sequence length the activation memory will still blow up in the TP region and 2) if the model is too big to fit with TP=8 then we will see a massive slow-down due to the inter-node connectivity.
|
@@ -791,7 +798,7 @@ With Tensor Parallelism and Sequence Parallelism, we can reduce the memory requi
|
|
791 |
|
792 |
Even if we use full recomputation of the activations, which comes at a heavy compute overhead (30%), we still need to hold in memory some activations at the layer boundaries which scale linearly with sequence length:
|
793 |
|
794 |
-
 and in this new arrangement, the attention mask will show an even distribution of computation but if you count the number of colored squares, you’ll see that the computation is now balanced across all GPUs.
|
844 |
|
845 |
-
 or we gather them one-by-one from each GPU to each GPU as needed:
|
850 |
|
851 |
-
 and deepspeed(All2All) implementations
|
|
864 |
|
865 |
In the TP section we saw that if we try to scale Tensor parallelism past the number of GPUs per single node (typically 4 or 8) we hit a lower bandwidth network called “inter-node connection” which can quite strongly impair our performances. We can see this clearly on e.g. the all-reduce operation when we perform it across several nodes:
|
866 |
|
867 |
-
 and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.
|
870 |
|
871 |
Sequence and context parallelism can help for long sequences but don’t help much if sequence length is not the root cause of our memory issues but rather the size of the model itself. For large model (70B+), the size of the weights alone can already push past the limits of the 4-8 GPUs on a single node. We can solve this issue by summoning the fourth (and last) parallelism dimension: “pipeline parallelism”.
|
872 |
|
873 |
-
** as the middle/steady state involves alternatively performing one forward and one backward pass. The general idea is to start performing the backward pass as soon as possible. The schedule looks like this:
|
965 |
|
966 |
-
 or we prioritize to first complete the forward passes of all microbatches in the queue before going over to backward passes (so called “breadth-first” i.e. prioritizing filling in the pipeline as much as possible). This is explained in details in [https://arxiv.org/abs/2211.05953](https://arxiv.org/pdf/2211.05953).
|
1073 |
|
1074 |
You now have all the elements to understand the pipeline parallelism approach in Llama 3.1 which is using a one-forward-one-backward setup with interleaved stages and a priority setting tuneable between depth-first and bread-first.
|
1075 |
|
1076 |
-
 work which is a precursor to DualPipe. The base observation of ZeroBubble is that a backward through a matrix multiplication involve actually two separated operations: backward for the inputs (B) and the backward for the weights (W):
|
1085 |
|
1086 |
-
 schedule with zero bubble taking advantage for this fine-grained decomposition.
|
1093 |
|
1094 |
DeepSeek’s DualPipe propose an extension of this decomposed approach to the case of two stream propagating from both sides of the PP ranks and being interleaved to minimize even further idle time in the GPUs are displayed in the following scheduling graph
|
1095 |
|
1096 |
-
 paper for a discussion of the heuristics and algorithms to perform such a scheduling.
|
1099 |
|
@@ -1105,7 +1117,7 @@ Mixture-of-expert models have gained some traction with models such as Mixtral o
|
|
1105 |
|
1106 |
So whereas Context parallelism
|
1107 |
|
1108 |
-
](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%
|
1109 |
|
1110 |
[https://arxiv.org/pdf/2407.06204](https://arxiv.org/pdf/2407.06204)
|
1111 |
|
@@ -1145,7 +1157,7 @@ Combining ZeRO-3 and TP doesn’t raise any specific issues except how to organi
|
|
1145 |
|
1146 |
# How to Find the Best Training Configuration
|
1147 |
|
1148 |
-
. However, in scenarios with long contexts, the primary memory usage will tend to shifts from model weights, gradients, and optimizer states to activation values. In such cases, context parallelism becomes more beneficial than pipeline parallelism. Note that this is not an exact recipe and you should think of this more as a starting point of hyperparameters to run your own benchmarks. For instance sometimes TP mixed with PP can be more efficient, even if TP<8 and ZeRO-1/2 can make sense to mix in with 4D parallelism as well.
|
1173 |
|
@@ -1191,13 +1203,13 @@ Generally, GPUs have a very hierarchical organization. In this primer we’ll ke
|
|
1191 |
|
1192 |
On the compute side, GPUs consist of an array of compute units called **Streaming Multiprocessors** (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see [https://resources.nvidia.com/en-us-tensor-core](https://resources.nvidia.com/en-us-tensor-core) for details), each capable of handling multiple threads simultaneously.
|
1193 |
|
1194 |
-
.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%
|
1195 |
|
1196 |
Original figure from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing).
|
1197 |
|
1198 |
The memory side is also highly hierarchical with several layers of cache and memory: **Registers** are the smallest units and are private to the threads during executions, **Shared Memory** and **L1 cache are** shared between the threads running on a single SM, higher up is the **L2 cache** shared by all SMs, finally there is the **Global Memory** which is the largest memory on the GPU (the advertised 80 GB for a H100 for instance) but also the slowest to access and query.
|
1199 |
|
1200 |
-
](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%
|
1201 |
|
1202 |
Original figure from [https://www.youtube.com/watch?v=ZQKMZIP3Fzg](https://www.youtube.com/watch?v=ZQKMZIP3Fzg)
|
1203 |
|
@@ -1207,11 +1219,11 @@ A piece of code running on a core of the GPU is called a **kernel**. It can be w
|
|
1207 |
|
1208 |
To run the kernel, you will also need a specific code part (called **host code**) which is executed on the **CPU**/host and will take care of preparing data allocations and loading data and code.
|
1209 |
|
1210 |
-

|
1213 |
|
1214 |
-

|
1217 |
|
@@ -1250,7 +1262,7 @@ def elu(x, alpha=1.0):
|
|
1250 |
|
1251 |
The distinction between the compiled and non-compiled versions is striking, especially given that we only added a single decorator. This remarkable difference is illustrated in the graph below (N is the number of columns) :
|
1252 |
|
1253 |
-
` provides a unique block ID, that we use to determine wh
|
|
1309 |
|
1310 |
When we benchmark the generated kernel using `triton.testing.Benchmark` we have the following performance :
|
1311 |
|
1312 |
-
 :
|
1351 |
|
1352 |
-
` and `(1, 0)` (which will end up in the same warp) will both load from the same column of matrix `B` but different rows of matrix `A`. Since matrix elements are stored in row-major order (meaning each row's elements are in consecutive memory addresses, as shown in the figure below), in the first iteration with `i = 0`, thread `(0, 0)` will load $A_{0,0}$, and thread `(1, 0)` will load $A_{1,0}$. These elements are not stored close to each other in memory, and this misalignment repeats across all iterations along the shared dimension, preventing memory accesses from being coalesced.
|
1361 |
|
1362 |
-
 and a tile of matrix B (of size `BLOCK_SIZE_K` by `BLOCK_SIZE_N`). Once the tiles are in shared memory, the threads perform matrix multiplication on these tiles, enabling efficient computation since all necessary data is quickly accessible. The results of the tile multiplication are stored in an accumulation matrix that holds intermediate results. After each iteration, the results from the current tile multiplication are added to this accumulation matrix, continuing until all tiles from both matrices have been processed.
|
1396 |
|
1397 |
-
](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%
|
1398 |
|
1399 |
From [https://cnugteren.github.io/tutorial/pages/page4.html](https://cnugteren.github.io/tutorial/pages/page4.html)
|
1400 |
|
@@ -1438,7 +1450,7 @@ When benchmarking this kernel using ncu, we noticed that the memory throughput i
|
|
1438 |
|
1439 |
The tiling technique has significantly improved the performance of our kernel. However, when analyzing the warp states which quantify how many cycles were spent in each state, we observe the following:
|
1440 |
|
1441 |
-
, specifically in the **Warp Stall Reasons** section. There we can read that :
|
1444 |
|
@@ -1463,13 +1475,13 @@ Flash attention is a technique pioneered by [Tri Dao](https://tridao.me) that op
|
|
1463 |
|
1464 |
A basic implementation of the attention mechanism involve a lot of transfer between memory and workers. It requires materializing the S and P matrices in HBM which means that the results need to be sent to HBM and then back to SRAM for the next computations:
|
1465 |
|
1466 |
-
, the attention matrix.
|
1471 |
|
1472 |
-
)](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%
|
1473 |
|
1474 |
From the FLASH-ATTENTION paper ([https://arxiv.org/pdf/2205.14135](https://arxiv.org/pdf/2205.14135))
|
1475 |
|
@@ -1527,13 +1539,13 @@ The principle of floating point numbers can be easily illustrated by recalling t
|
|
1527 |
|
1528 |
Reducing the total number of bits comes at a price (no free lunch here either), but we have some control over how to pay. Either we can sacrifice more bits on the mantissa or exponent. For this reason there exist also two float8 formats, named according to exponent and mantissa, to flexibly choose the most appropriate format. We can look at the possible range of numbers for each format:
|
1529 |
|
1530 |
-
, [torchao](https://github.com/pytorch/ao/tree/main/torchao/float8#torchaofloat8), and [DeepSeek-V3](https://arxiv.org/abs/2412.19437) - has demonstrated the potential of FP8 training for large-scale models. Still, FP8 pretraining introduces a significant challenge: stability. At lower precision, numerical instability often leads to loss divergence, making it difficult to match the accuracy of higher-precision training.
|
1565 |
|
1566 |
-
 observed, instability increases as learning rates rise for a fixed model size, making FP8 pretraining particularly tricky.
|
1569 |
|
@@ -1701,7 +1713,7 @@ Throughout this blogpost we’ll scale LLM training from one to hundreds of GPUs
|
|
1701 |
|
1702 |
The general setup is that we have a number of independent nodes which could be CPU cores, GPUs, or compute nodes. Each performs some computation and then we want to communicate the result or parts of it to the other nodes for the next computation step (t+1).
|
1703 |
|
1704 |
-
, it determines how many workers (aka nodes) exists and assigns a rank to each one (which we can get with `dist.get_rank`). Finally, it establishes a connection between the workers.
|
1715 |
|
@@ -1754,7 +1766,7 @@ Great, seems like it works as expected. Note that the rank messages can be print
|
|
1754 |
|
1755 |
Reduce patterns are among the most fundamental patterns in distributed data processing. The idea is that you want to combine the data present on each node through a function `f()` which can be for instance summation or averaging. In the Reduce paradigm the result is sent to the root node only, whereas in the AllReduce case the result is broadcasted to all nodes:
|
1756 |
|
1757 |
-
 or gather all data on all nodes (in the case of AllGather). A picture being worth 1000 words, let’s take a look:
|
1851 |
|
1852 |
-
.
|
1855 |
|
@@ -1923,7 +1935,7 @@ As the name subtly suggests, the goal of the Scatter operation is to take data o
|
|
1923 |
|
1924 |
The ReduceScatter pattern is slightly more complex: imagine you apply an operation like in the Reduce case but instead of moving the result to just one node we also distribute it evenly to all nodes:
|
1925 |
|
1926 |
-
.table(sort_by="cuda_time_total", row_limit=8))
|
|
2111 |
|
2112 |
This would print aggregated profiling results sorted by the total CUDA time, and the output would be:
|
2113 |
|
2114 |
-
 with `aten::layer_norm`, progressing to `aten::native_layer_norm`, and then transitioning to `cudaLaunchKernel`. From there, we move on to the GPU, where the `vectorized_layer_norm_kernel` kernel is called.
|
2126 |
|
@@ -2141,7 +2153,7 @@ ncu --set full -o output python layer_norm.py
|
|
2141 |
|
2142 |
and open the file `output.ncu-rep` with Nsight Compute, you will have a view that looks like this :
|
2143 |
|
2144 |
-
 depends directly on the output (Y). This equation is telling us that to get the gradient of the loss with respect to our input (dL/dX), we multiply the gradient of the loss with respect to the output (dL/dY) by our weight matrix (W).
|
2225 |
|
2226 |
-

|
2311 |
```
|
2312 |
|
2313 |
-
:
|
|
2461 |
|
2462 |
### Interconnect
|
2463 |
|
2464 |
-
** created by PyTorch.
|
2499 |
|
2500 |
-
, since each head can operate independently from others, we can apply ring attention within each TP rank](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%
|
2627 |
|
2628 |
TP=0 has GPU0 and GPU2 whereas CP=0 has GPU0 and GPU1
|
2629 |
TP/SP shards the Q/K/V heads across TP ranks (in this example GPU0 and GPU2 get QKV_green, and GPU2 and GPU3 get QKV_blue), since each head can operate independently from others, we can apply ring attention within each TP rank
|
@@ -2636,7 +2648,7 @@ In fact, given an activation value of shape$[ \text{batch\_size}, \text{sequence
|
|
2636 |
|
2637 |
However, through extensive experimentation, we identified two effective training recipes that allowed us to **fully pretrain a 1B LLaMA model in FP8**, covering both the forward and backward passes, while using an FP8 optimizer. More importantly, our approach successfully matched LLaMA-2’s pretraining learning rate. The result?
|
2638 |
|
2639 |
-
. We successfully tested this to train a 1B LLaMA up to 100B tokens and a 7B LLaMA up to 25B tokens.
|
2642 |
|
@@ -2674,14 +2686,28 @@ Let’s take a moment to look better at this fundamental tool for distributed tr
|
|
2674 |
|
2675 |
**Non-overlapping:** If we don't overlap the communication and computation, each computation (represented by the purple block) can only begin after the communication (green block) is complete and total time is the sum of communication and computation.
|
2676 |
|
2677 |
-
 is launched immediately, one after the other. In this case the total time is *only* the sum of computations.
|
2680 |
|
2681 |
-
 which the time has come to explore now.
|
2686 |
|
2687 |
-
[TODO: comment from Nouamane on comms overlapping with DP 512]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
As you can see, there’s a lot of ground to be covered. Before getting into the trenches of distributed training let’s take a quick high level look on we’ll cover in the post.
|
30 |
|
31 |
+
# ✅ TL;DR
|
32 |
|
33 |
This book is very extensive so we decide to start with a very general overview of how you can think about distributed training. At a high level, the key challenge in scaling LLM training is to make a training step (forward/backward/optimizer step) with a large batch size the fastest possible.
|
34 |
|
|
|
55 |
|
56 |
Now that we nailed a few key concept and terms let’s get started by revisiting the basic training steps of an LLM!
|
57 |
|
58 |
+
# ✅ First Steps: Training on one GPU
|
59 |
|
60 |
Let’s start by quickly reviewing the very basics of model training before we start to scale to many GPUs. When a model is trained on a single GPU, the training typically consists of three steps:
|
61 |
|
|
|
101 |
|
102 |
Let’s start by quickly understanding what led to our out-of-memory issue in the first place. This will help us gain some useful intuitions for later.
|
103 |
|
104 |
+
## ✅ Memory usage in Transformers
|
105 |
|
106 |
When training a neural network model, one store several items in memory:
|
107 |
|
|
|
143 |
|
144 |
$$
|
145 |
|
146 |
+
N = 2*h*v + L * (12 * h^2 + 13*h) + 2*h
|
147 |
$$
|
148 |
|
149 |
> Note: we excluded the positional embedding count as rotary embeddings are not learned.
|
150 |
>
|
151 |
|
152 |
+
In that equation, $h$ is the hidden dimension, $v$ the vocabulary size, and $L$ the number of layers in the model. The first term is the parameter count for the word embedding and LM head. When they are tied (meaning we use the same parameters for both), it would be 1. This is beneficial for small models, as the vocabulary size is generally much larger than the hidden dimension. We don’t want the number of parameters in an LLM to be dominated by the embedding layer.
|
153 |
|
154 |
Memory requirements for the parameters and gradients are simply the number of parameters multiplied by the number of bytes per parameter. In good old-fashioned full precision (FP32) training both parameters and gradients require 4 bytes while the optimizer, if we use Adam, requires the momentum and variance to be stored, which adds another two 4 bytes per parameter. In summary:
|
155 |
|
|
|
173 |
m_{opt} = (4+4) * N
|
174 |
$$
|
175 |
|
176 |
+
> Some libraries store grads in fp32 which would require an additional $m_{params\_fp32} = 4 * N$ memory. This is done for example in nanotron, because `bf16` is lossy for smaller values and we always prioritize stability. See https://github.com/microsoft/DeepSpeed/issues/1773 for more information.
|
177 |
>
|
178 |
|
179 |
Interestingly, mixed precision itself doesn’t save overall memory as it just distributes the memory differently across the three components, and in fact adds another 4 bytes over full precision training if we accumulate gradients in FP32. It’s still advantageous as having the model which does the forward/backward in half precision it allows us to (1) use optimized lower precision operations on the GPU which are faster and (2) reduces the activation memory requirements during the forward pass.
|
|
|
203 |
|
204 |
$$
|
205 |
|
206 |
+
Here $L$ is the number of layers, $seq$ the sequence length, $bs$ the batch size in samples, $h$ the hidden dimension of the model and $n_{heads}$ the number of heads.
|
207 |
|
208 |
For the exact numbers derivation, you can follow this [NVIDIA pape](https://arxiv.org/pdf/2205.05198)r on recomputation, it essentially requires you to do some accounting of all the sizes of intermediate activations between each operation.
|
209 |
|
|
|
219 |
|
220 |
It’s time to explain our first technique – called ***activation recomputation**–* ****which will help us cap activation memory footprint. An essential tool in today’s large model training toolbox.
|
221 |
|
222 |
+
## ✅ **Activation recomputation**
|
223 |
|
224 |
The general idea behind ***activation recomputation** –*also called ***gradient checkpointing*** or ***rematerialization**– *****is to discard some activations during the forward pass to save memory and spend some extra compute to recompute these on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g. FF, LayerNorm etc.), such that we can use them during the backward pass to compute gradients. When we use recomputation we typically will only store activations at a few key points along the model architecture, discard the rest of activations and recompute them on the fly during the backward pass from the nearest saved activations, basically performing again a sub-part of the forward pass to trade of memory for compute. It generally looks like this:
|
225 |
|
|
|
252 |
|
253 |
However, activations still bears a linear dependance on the batch size and all our profiles in the barplots above were using `bs=1` so as we move to larger batch sizes it might become an issue again. Do not despair as we have a second tool in our box - ***gradient accumulation*** to the rescue!
|
254 |
|
255 |
+
## ✅ Gradient accumulation
|
256 |
|
257 |
Now that we’ve used activation recomputation to fit our model with a small batch size on a single GPU, we still need to reach our target batch size, let’s say 1M tokens (see our earlier discussion on optimal batch size). Gradient accumulation is a very straightforward method to avoid memory explosion when doing this.
|
258 |
|
|
|
281 |
|
282 |
Let’s get a larger workstation 🖥️ with a couple of GPUs and start investigating our first scaling technique called ***data parallelism** which is just a parallel version of gradient accumulation*.
|
283 |
|
284 |
+
TODO: intro for torch.profiler section
|
285 |
|
286 |
+
## ✅ torch.profiler
|
287 |
|
288 |

|
289 |
|
|
|
297 |
|
298 |
**Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done**
|
299 |
|
300 |
+
# 🚧 Data Parallelism
|
301 |
|
302 |
The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replica's “model instances”) and run forward and backward passes on different micro batches of data in parallel for each GPU, hence the name Data Parallelism.
|
303 |
|
|
|
310 |
> If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in the Appendix [TODO Link].
|
311 |
>
|
312 |
|
|
|
|
|
|
|
|
|
|
|
313 |
TODO: embed naive DP: [https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60](https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60)
|
314 |
|
315 |
TODO: embed bucket DP: [https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171](https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171)
|
|
|
322 |
|
323 |
Let’s see three optimizations that are done in practice for this!
|
324 |
|
325 |
+
### 🚧 **First optimization:** Overlap gradient synchronization with backward pass
|
326 |
|
327 |
The main drawback of the naive DDP approach we’ve just described is that after the backward pass (*computation*), we have to wait for gradient synchronization (*communication*) before updating the parameters. Could we overlap this communication with our computation? The answer is yes!
|
328 |
|
|
|
347 |
|
348 |
This is our first example of “*overlapping computation and communication*” which we will discuss several times in this blog post and is an essential technique to maximal scaling efficiency.
|
349 |
|
350 |
+
### 🚧 **Second optimization:** Bucketing gradients
|
351 |
|
352 |
But we can even go further. For a given number of parameters to synchronize, GPU operations like collective communications are often more efficient when performing few calls on large tensors rather than many calls on smaller tensors. Therefore, instead of performing independent all-reduce for each gradient, we can group gradients into buckets and launch a single all-reduce for all the gradients within the same bucket. Think of it like packing items into boxes before shipping—it's more efficient to send a few big boxes than many small ones. By performing a single all-reduce operation for each bucket, we can significantly reduce communication overhead and speed up the communication operation.
|
353 |
|
|
|
357 |
|
358 |
[TODO: benchmark all reduce with different size / bucket size results ?]
|
359 |
|
360 |
+
### 🚧 **Third optimization: I**nterplay with gradient accumulation
|
361 |
|
362 |
As we’ve seen before, gradient accumulation works by performing multiple forward and backward passes before updating the parameters with `optimizer.step()`. When combining gradient accumulation with data parallelism, we should be careful when we want to synchronize gradients.
|
363 |
|
|
|
365 |
|
366 |
In PyTorch, this is typically solved by adding a [`model.no_sync()`](https://github.com/pytorch/pytorch/blob/5ea67778619c31b13644914deef709199052ee55/torch/nn/parallel/distributed.py#L1408-L1435) decorator, which disables gradient synchronization, on the backward passes which don’t need reduction.
|
367 |
|
368 |
+
> When performing communication operations, tensors must be contiguous in memory. To avoid redundant memory copies during communication, ensure that tensors that will be communicated are stored contiguously in memory.
|
369 |
+
Sometimes we need to allocate additional continuous buffers of the size of activations or model parameters specifically for communication, which contributes to the peak memory usage during training.
|
370 |
+
>
|
371 |
+
|
372 |
+
## 🚧 Revisit global batch size
|
373 |
|
374 |
Let’s update our batch size equation with our newly learned Data Parallelism and Gradient Accumulation parameters:
|
375 |
|
|
|
386 |
|
387 |
Being able to distribute the training over different samples gives us a first dimension of parallelization, thus making this 1D parallelism (we’ll progressively cover 3 more dimensions).
|
388 |
|
389 |
+
## 🚧 Our journey up to now
|
390 |
|
391 |
Let’s quickly summarize what we’ve seen up to now and how to setup our first 1D parallel training with a draft recipe for an optimal data-parallel setup:
|
392 |
|
|
|
405 |
|
406 |
Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k. This means our batch size will be 1024 samples (we pick powers of two). We observe that a single GPU can only fit MBS=2 in memory and we have 128 GPUs available for training. This means with 4 gradient accumulation steps we’ll achieve our goal of 1024 samples or 4M tokens per training step. Now what if we suddenly have 512 GPUs available? We can achieve the same GBS and thus identical training by keeping MBS=2 and setting gradient accumulation steps to 1 and achieve faster training!
|
407 |
|
408 |
+
> Bear in mind that at the 512GPUs scale, depending on the network used, the communication operations will start to be bound by *ring latency* (time required for a signal to propagate once around the ring) **which means we can no longer fully overlap the DP communications. This will decrease our compute efficiency and hit our throughput. In this case we should start exploring other dimensions to parallelize on.
|
409 |
>
|
410 |
|
411 |
TODO: We’re gaining overall throughput but losing efficiency as we scale DP too much
|
412 |
|
413 |

|
414 |
|
415 |
+
$$
|
416 |
+
t_{comm}/t_{compute} = \frac{\text{num\_params}}{\text{2 num\_tokens}} \cdot \left(\frac{DP-1}{DP}\right) \cdot \frac{\text{peak\_flops}}{\text{peak\_bw}} \leq 1
|
417 |
+
$$
|
418 |
|
419 |
+
**We’ve explored data parallelism, our first (simple) strategy to scale training across more GPUs. It works like gradient accumulation but parallelizes the forward and backward passes on micro batches, thus increasing throughput!**
|
420 |
|
421 |
+
The keen reader have already probably notes however that this assumes that we can fit at least one input sample forward pass (mbs*=1)* into our GPU memory. This is not always the case! As we can see, larger models don’t fit into a single GPU, even with activation recomputation activated.
|
422 |
|
423 |

|
424 |
|
425 |
+
> Tip: you can quickly eyeball the minimal memory required for your model’s parameters by multiplying by 2 e.g. 70B → 140GB (=133GiB)
|
426 |
+
>
|
427 |
+
|
428 |
Do we have other options for these larger models? We do have some solutions thankfully. They will involve either move some of these tensors to the CPU or split the weights/gradients/optimizer-states tensors across GPUs devices!
|
429 |
|
430 |
There are two main approaches to splitting: parallelism (tensor, context, or pipeline parallelism) and sharing (DeepSpeed Zero or PyTorch FSDP). Both approaches are somewhat orthogonal and can actually be combined! The sharing paradigm is closely related to DP so we’ll have a look at it first by investigating the ZeRO method!
|
431 |
|
432 |
+
## 🚧 ZeRO (**Ze**ro **R**edundancy **O**ptimizer)
|
433 |
|
434 |
In this section we will introduce DeepSpeed ZeRO (**Ze**ro **R**edundancy **O**ptimizer), a memory optimization technology designed to reduce memory redundancies in LLM training.
|
435 |
|
|
|
451 |
|
452 |
Let’s have a closer look how much we can save with the partitioning of each ZeRO stage!
|
453 |
|
454 |
+
### 🚧 Memory usage revisited
|
455 |
|
456 |
Let’s first recap the memory usage of optimizer states, gradients, and parameters during a standard training. Let’s define the number of our model's parameters as $\Psi$ (previously N but here we use the original ZeRO notation). In mixed-precision training with the Adam optimizer, the memory usage for each item we need to store is:
|
457 |
|
|
|
470 |
|
471 |
Let’s explain this graph and it’s values by exploring how each ZeRO stage works. We’ll start with ZeRO-1.
|
472 |
|
473 |
+
### 🚧 ZeRO-1: Partitioning Optimizer States
|
474 |
|
475 |
In vanilla DP, all ranks gather the same gradients after the backward pass and simultaneously perform identical optimizer steps. This seems like a lot of duplicated work. Can we avoid it and reduce memory usage at the same time?
|
476 |
|
|
|
482 |
|
483 |
- Forward pass with all bf16 parameters (but different microbatches across DP ranks)
|
484 |
- Backward pass with all gradients (but different microbatches across DP ranks)
|
485 |
+
- Perform an reduce-scatter **[ADD link!]** on the gradients (reduce-scatter is 2 times faster than all reduce! *yay, a third communication primitive!*)
|
486 |
- Each replica perform an optimizer step (has only 1/$N_d$ optimizer states) updates only on 1/$N_d$ of fp32 parameters, and then 1/$N_d$ of bf16 parameters
|
487 |
- [New operation in ZeRO, not in vanilla DP] Perform an all-gather of bf16 parameters to send missing slices back to each replica
|
488 |
|
489 |

|
490 |
|
491 |
+

|
492 |
+
|
493 |
If you've been following along, you'll recall from vanilla DP that we can overlap the all-reduce gradient communication with the backward pass computation. In ZeRO-1, we can also investigate how to efficiently overlap the newly added all-gather of bf16 parameters. There are two main strategies for this:
|
494 |
|
495 |
+
1. During optimizer step: We can initiate the all-gather immediately after the optimizer updates part of the parameters. This allows the communication to potentially overlap with other parameters update.
|
496 |
+
2. During forward: We can overlap the all-gather of each layer’s parameters with the forward pass.
|
497 |
|
498 |
But unfortunately these techniques are not as evident to implement as they seem and require sophisticated use of hooks / bucketing. In practice we can just use Zero3 / FSDP implementation where the FSDPUnit is the entire model, more details about this later..
|
499 |
|
500 |
+
### 🚧 ZeRO-2: Adding **Gradient Partitioning**
|
501 |
|
502 |
In ZeRO-1 the optimizer states have been partitioned, which means that each replica only updates $\frac{1}{N_d}$ of the optimizer states. The keen reader must have noticed that there is no real need to have all gradients on all DP ranks in the first place since only a subset is needed for the optimization step.
|
503 |
|
504 |
+
→ During the backward pass, instead of performing an all-reduce over the gradients, we only perform a ***reduce-scatter*** operation! **Where we only spread the $\frac{1}{N_d}$ gradients needed in memory, thus saving more memory compared to ZeRO-1
|
505 |
|
506 |
> In case of FP32 gradient accumulation, we only need to keep $\frac{1}{N_d}$ fp32_grads where we accumulate the bf16 grads coming from the reduce-scatter. And in the optimizer step we use the $\frac{1}{N_d}$ fp32_grads.
|
507 |
>
|
|
|
510 |
|
511 |
It’s easy to see now that sharding the gradients leads to to $2\Psi + \frac{2\Psi+k\Psi}{N_d}$ and as $N_d$ is increased we can save up to 8x memory over the baseline. In terms of communication the same process applies as for ZeRO-1, with the only difference that we communicate and release on the fly. In total, ZeRO-2 is thus also equivalent to vanilla DP training w.r.t. communication.
|
512 |
|
513 |
+

|
514 |
|
515 |
> Note: You might notice that there is no real overhead of using ZeRO-2 over ZeRO-1 and indeed ZeRO-2 is usually the best option. The reason some distributed training frameworks don’t support it is that gradient sharding may interfere with and make more complex other parallel strategies we discussed later.
|
516 |
>
|
517 |
|
518 |
Now that we’ve sharded gradients as well, we are we done? Or can we keep getting away with this? Well, sort of. We would like to reduce the memory of the parameters as well, and we’ve seen that we don’t need to wait for the entire all-gather to start the forward, we can already start the forward once we get the first layer.. here comes ZeRO-3!
|
519 |
|
520 |
+
### 🚧 ZeRO-3: Adding Parameter **Partitioning**
|
521 |
|
522 |
For Stage 3 we extend the above approach of sharding tensors over DP replicas up to sharding the model’s parameters.
|
523 |
|
|
|
526 |
|
527 |
So how do we do a forward or backward pass in practice if all parts of the model are distributed? Quite simply we gather them on-demand when we need them. In the forward pass this looks as follows:
|
528 |
|
529 |
+

|
530 |
|
531 |
So as we perform the forward pass and sequentially go through the layers we retrieve the necessary parameters on demand and immediately flush them from memory when we don’t need them anymore. The backward pass works the same way just inverted in flow and we produce the gradient shards:
|
532 |
|
533 |
+

|
534 |
|
535 |
During the forward pass we do all-gather operations for the parameters when we need them, so a $\Psi$ communication tax. Since we discard the parameters immediately after we needed them in the forward pass we need one more all-gather during the backward pass as well incurring another $\Psi$ in communication tax. Finally we need the same ***reduce-scatter*** as in ZeRO-2 for the gradients which costs also $\Psi$ in communication and we arrive at a total communication cost of $3\Psi$, compared to $2\Psi$ for Zero-2.
|
536 |
|
537 |
The other issue is that we need to do these all-gathers continuously throughout the forward and backward step, which amounts to `2 * num_layers - 1` additional all-gathers in a training step compared to Zero-2 as we can see in the following figure:
|
538 |
|
539 |
+

|
540 |
|
541 |
$$
|
542 |
\frac{t_{comm}}{t_{compute}} = \frac{(DP-1) \cdot peak_{flops}}{2 \cdot seq \cdot mbs \cdot peak_{bw}}
|
543 |
$$
|
544 |
|
545 |
+
$$
|
546 |
+
|
547 |
+
t_{comm}^{FSDP}/t_{compute} = \frac{\text{seq} \cdot \text{mbs}}{2} \cdot (DP-1) \cdot \frac{\text{peak\_flops}}{\text{peak\_bw}} \leq 1
|
548 |
+
$$
|
549 |
+
|
550 |
Overall it may sound like we significantly increase communication overhead, but thanks to **prefetching** we can start all-gathering weights for Layer n+1 while we do the current forward for Layer n which usually overlaps communication and computation as long as we don’t scale DP too much (as a rule of thumb: DP<512).
|
551 |
|
552 |
In terms of memory we can see that our equation now reached it’s final form of $\frac{2\Psi +2\Psi+k\Psi}{N_d}$ which means we can drive memory usage down indefinitely if we can increase the DP rank, at least for the model related parameters. Notice how it doesn’t specifically help with the intermediate activations that we discussed in the previous chapter. ZeRO is an orthogonal technique to the activation checkpointing and gradient accumulation we discussed in other chapters.
|
|
|
558 |
|
559 |
However, there is a limit here, DP only works if a layer of the model fits in a single GPU and ZeRO can only reduce the parameters, gradients, and optimizer states, but not the activation memory! Recall from the activation memory discussion that it scales with sequence length and batch size. Naturally we could just limit those, but in practice we don’t want to be limited by hardware to train with e.g. short sequence length.
|
560 |
|
561 |
+

|
562 |
|
563 |
As model grow bigger and even a single layer may not fit in GPU, we need more tool in our distributed training toolbox to scale more.
|
564 |
|
|
|
593 |
|
594 |
In practice a small example of the operation looks like this:
|
595 |
|
596 |
+

|
597 |
|
598 |
Let’s see how we can parallelise this operation! In tensor parallelism, tensors will be split into N shards along a particular dimension and distributed across N GPUs. Matrices can be split either on the column part or row part leading to row and column parallelism. One thing we’ll see in the following is that choosing row or column sharding will require different communications primitives.
|
599 |
|
600 |
Our first option is to use column-wise sharding (also called ***column-linear***): We'll copy the complete input matrices to each worker, requiring an operation called [***broadcast***](https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21), and split the weight matrix into columns. The inputs are then multiplied with the partial weight matrices, and the results are finally combined using an [***all-gather](https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21)*** operation*.*
|
601 |
|
602 |
+

|
603 |
|
604 |
The second option is called row-wise sharding (also called ***row-linear***): As the attentive reader might guess, row-linear means that we split the weight matrix into chunks of rows. However, this also requires us to split the inputs, which needs a ***scatter*** operation rather than a broadcast as used in column-linear sharding. The results on each worker are already in the right shape but need to be summed for the final result, thus requiring an all-reduce operation in this scenario.
|
605 |
|
606 |
We see here our fourth distributed primitive: ***s[catter](https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21)***!
|
607 |
|
608 |
+

|
609 |
|
610 |
## Tensor Parallelism in a Transformer Block
|
611 |
|
|
|
613 |
|
614 |
The Feedforward part can be parallelized by having a “Column linear” followed by a “Row Linear” which amounts to a broadcast to copy the input and an all-reduce in forward. Note that the broadcast isn’t needed in actual training where we can make sure inputs are already synced across TP ranks.
|
615 |
|
616 |
+

|
617 |
|
618 |
Now that we’ve found the most efficient schema for the Feedforward part of the transformer, let’s take a look at the multi-head attention block (MHA).
|
619 |
|
|
|
621 |
|
622 |
It's also worth noting that the tensor parallelism degree should not exceed the number of Q/K/V heads because we need intact heads per TP rank. And in case we’re using GQA, TP degree should be below number of K/V heads, otherwise it requires additional comms to keep them in sync. For instance, LLaMA-3 8B has 8 Key/Value heads, so the tensor parallelism degree should be less than or equal to 8, otherwise if TP=16 for example, we need to duplicate each K/V head and make sure they stay in sync.
|
623 |
|
624 |
+

|
625 |
|
626 |
Finally note that there is a tradeoff in terms of communication as we’ve added several distributed communication primitive directly in the computation path of our model. At the difference of ZeRO where we could prefetch, it can be harder to make these communication fully overlap with computations.
|
627 |
|
628 |
+

|
629 |
+
|
630 |
+
Forward pass in Tensor Parallelism
|
631 |
|
632 |
+
Looking at the timeline of operations in tensor-parallel MLP (same applies for Attention), we can better understand the tradeoffs involved. In the forward of each decoder layer, we hit a synchronization point with the AllReduce operation that cannot be overlapped with computation. This *exposed communication* overhead is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied.
|
633 |
|
634 |
+
Tensor parallelism does help reduce activation memory for the matrix multiplications since the intermediate activations are sharded across GPUs. However, we still need to gather the full activations for operations like LayerNorm, which means we're not getting the full memory benefits we could. Additionally, it introduces significant communication requirements that heavily depend on the network infrastructure. The inability to hide this particular AllReduce behind computation means it directly adds to the critical path of forward propagation.
|
635 |
+
|
636 |
+

|
637 |
|
638 |
Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training.
|
639 |
|
640 |
In practice, the communication overhead of tensor parallelism becomes particularly noticeable as we scale beyond 8 GPUs. While tensor parallelism within a single node can leverage fast NVLink interconnects, going across nodes requires slower network connections. As shown in the throughput plot above, we observe significant drops when moving from TP=8 to TP=16, and an even steeper decline from TP=16 to TP=32. This illustrates how communication costs can dominate at higher degrees of parallelism.
|
641 |
|
642 |
+
However, tensor parallelism provides important benefits for memory usage by distributing model parameters, gradients, optimizer states and activations (to some extent) across GPUs. Let's examine this effect on a 70B parameter model:
|
643 |
|
644 |
+

|
645 |
|
646 |
+
As we can see, increasing tensor parallelism reduces the memory needed for model parameters, gradients and optimizer states on each GPU. While tensor parallelism does help reduce activation memory in attention and feedforward layers by sharding the matrix multiplications across GPUs, we don't get the full memory benefits we could. This is because operations like layer normalization and dropout still require gathering the full activations on each GPU, partially negating the memory savings. We can do better by finding ways to parallelize these remaining operations as well.
|
647 |
|
648 |
> One interesting note about layer normalization in tensor parallel training - since each TP rank sees the same activations after the all-gather, the layer norm weights don't actually need an all-reduce to sync their gradients after the backward pass. They naturally stay in sync across ranks. However, for dropout operations, we must make sure to sync the random seed across TP ranks to maintain deterministic behavior.
|
649 |
>
|
|
|
674 |
|
675 |

|
678 |
|
679 |
in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter
|
680 |
in backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather
|
|
|
692 |
- "f*" is a no-op because gradients are already duplicated across ranks
|
693 |
- "f" is an all-reduce to synchronize gradients
|
694 |
|
695 |
+
These operations "f" and "f*" are called **conjugate** pairs because they complement each other - when one is a no-op in forward, the other is an all-reduce in backward, and vice versa.
|
696 |
|
697 |
For sequence parallelism (SP), we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.
|
698 |
|
|
|
748 |
s: unchanged | h: full (weight_out is full + **reduce-scatter** for correctness)
|
749 |
s: **reduce-scatter** to sharded |
|
750 |
|
|
|
|
|
751 |
You can find an example of implementation of both column and row linear TP in picotron:
|
752 |
[https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py](https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py)
|
753 |
|
754 |
+
By using sequence parallelism, we can achieve even greater activation memory savings, allowing us to push our batch size and sequence length further than what would be possible with tensor parallelism alone. Let's see what that means for our previous 70B model example:
|
|
|
|
|
755 |
|
756 |

|
757 |
|
758 |
+
Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward of a vanilla TP we had two all-reduce per transformer block, and in SP we have two all-gather and two reduce-scatter per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into to an all-gather + reduce-scatter (see in [TODO: Appendix link]) they’re actually equivalent in terms of communication. Same reasoning for backward as we just use the conjugate of each operation (no-op ↔ allreduce and allgather ↔ reducescatter).
|
759 |
|
760 |
+
If you’ve been paying close attention, you’ll notice that we’re talking about **4 comms ops in each layer** (2 for Attention and 2 for MLP). This is how the MLP profiling looks like when using Tensor + Sequence Parallelism:
|
|
|
|
|
|
|
761 |
|
762 |
+

|
763 |
|
764 |
+
Forward pass in Tensor + Sequence Parallelism
|
765 |
|
766 |
+
Besides the fact that TP requires communications in each layer, it also can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. This is why TP is usually done only within a node (TP≤8)
|
767 |
|
768 |
+
> Note: Overlapping communication with computation for TP is an [active area of research](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487), with recent work like Domino [TODO: cite domino paper] exploring novel techniques to maximize this overlap. For example, Megatron-LM/Nanotron implement a partial overlapping of all-gather with FC1 computation, and we expect to see more innovations in this space as the field continues to evolve.
|
769 |
|
770 |
+
$$
|
771 |
|
772 |
+
t_{comm}/t_{compute} = \frac{1}{24h} \cdot (TP-1) \cdot \frac{\text{peak\_flops}}{\text{peak\_bw}} \leq 1
|
773 |
+
$$
|
774 |
|
775 |
+
As you might expect, this communication overhead becomes increasingly problematic as we scale up tensor parallelism. To illustrate this, let’s check throughput as we scale TP with SP for a 3B model:
|
776 |
|
777 |
+

|
778 |
|
779 |
+
Impact of combined Tensor and Sequence Parallelism (TP/SP) on a 3B model’s performance and memory utilization with 4096 seqlen: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.
|
780 |
|
781 |
Let’s summarize our observations:
|
782 |
|
|
|
784 |
- the memory savings in activations when using TP with SP helps us fit far bigger batches than TP alone
|
785 |
- the Torch memory fragmentation makes it hard for us to predict the exact peak reserved memory. For more details check memory_viz tool section. [TODO: add link]
|
786 |
|
|
|
|
|
787 |
**We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.**
|
788 |
|
789 |
However, there are two limits to TP and SP: 1) if we scale the sequence length the activation memory will still blow up in the TP region and 2) if the model is too big to fit with TP=8 then we will see a massive slow-down due to the inter-node connectivity.
|
|
|
798 |
|
799 |
Even if we use full recomputation of the activations, which comes at a heavy compute overhead (30%), we still need to hold in memory some activations at the layer boundaries which scale linearly with sequence length:
|
800 |
|
801 |
+

|
802 |
|
803 |
Can we apply similar ideas to our sequence parallelism approach but inside in the modules where we apply Tensor Parallelism already, thereby also reducing the effect of sequence length? Yes, it’s time to talk about Context Parallelism, which you will find quite intuitive after all we’ve already convered.
|
804 |
|
|
|
806 |
|
807 |
The idea of Context Parallelism is quite simple; just like Sequence Parallelism, we’ll split the input along the sequence dimension but we now apply this splitting along the full model, instead of only the sequence parallel regions of the model. Our focus here will be to reduce the activation memory footprint by splitting the long sequences, complementing parallelism strategies like TP which target the hidden dimension of the model.
|
808 |
|
809 |
+

|
810 |
|
811 |
Splitting the sequence doesn't affect most modules like MLP and LayerNorm, where each token is processed independently. It also doesn’t require expensive communication like TP, as only the inputs are split and not the weight matrices. Just as in data parallelism, after computing the gradients, an all-reduce operation is initiated to synchronize the gradients across the context parallelism group.
|
812 |
|
|
|
839 |
|
840 |
There is one big problem though which is that a naive implementation of Ring Attention lead to some strong imbalance between GPU streaming from the shape of the causal attention matrix. Let’s take a real look at what is happening in the SoftMax computation by considering the attention score matrix with the causal attention mask:
|
841 |
|
842 |
+

|
843 |
|
844 |
The SoftMax is computed row-wise, which means whenever a GPU has received all the tokens of a row it can be computed. We see that GPU1 can immediately compute it as it starts with tokens 1-4 and GPU1 actually doesn’t need to receive any information from any other GPUs. However, GPU2 will need to wait for the second round to also receive 1-4 and thus have all values for tokens 1-8. Also, GPU1 seems to perform much less work than all the other GPUs.
|
845 |
|
|
|
849 |
|
850 |
We need a better way to distribute the input sequences. This can be achieved by assigning the tokens not purely sequential to the GPUs but by mixing the ordering a bit such that we have a good mix of early and late tokens on each GPU. This approach is called [Zig-Zag attention](https://arxiv.org/pdf/2311.09431) and in this new arrangement, the attention mask will show an even distribution of computation but if you count the number of colored squares, you’ll see that the computation is now balanced across all GPUs.
|
851 |
|
852 |
+

|
853 |
|
854 |
At the same time we’ll also see that in order to complete all rows, each GPU will need information from all the other GPUs.
|
855 |
|
856 |
We have two general ways to overlap computation and communication, either by performing a general all-gather, regrouping all the KV on each GPUs at the same time (in a Zero-3 type of way) or we gather them one-by-one from each GPU to each GPU as needed:
|
857 |
|
858 |
+

|
859 |
|
860 |
Context Parallelism using AllGather implementation
|
861 |
|
862 |
+

|
863 |
|
864 |
Context Parallelism using All-to-All implementation
|
865 |
|
|
|
871 |
|
872 |
In the TP section we saw that if we try to scale Tensor parallelism past the number of GPUs per single node (typically 4 or 8) we hit a lower bandwidth network called “inter-node connection” which can quite strongly impair our performances. We can see this clearly on e.g. the all-reduce operation when we perform it across several nodes:
|
873 |
|
874 |
+

|
875 |
|
876 |
Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.
|
877 |
|
878 |
Sequence and context parallelism can help for long sequences but don’t help much if sequence length is not the root cause of our memory issues but rather the size of the model itself. For large model (70B+), the size of the weights alone can already push past the limits of the 4-8 GPUs on a single node. We can solve this issue by summoning the fourth (and last) parallelism dimension: “pipeline parallelism”.
|
879 |
|
880 |
+

|
881 |
|
882 |
Pipeline Parallelism is conceptually very simple –we’ll simply spread the layers of our model across GPUs – but the devil lies in implementing it efficiently. Let’s dive in it!
|
883 |
|
|
|
891 |
|
892 |
Indeed reader! The main challenge in pipeline parallelism will be how to efficiently circumvent the sequential nature of PP to keep our GPU busy at all times and avoid having one GPU computing while the others are waiting. Here is how our GPU utilization is looking when doing a naive and simple forward and backward pass through the model where the numbers indicate the model layers:
|
893 |
|
894 |
+

|
895 |
|
896 |
An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.
|
897 |
|
|
|
911 |
|
912 |
Let’s take a first tool out of our toolbox and think about splitting our batch into smaller bit-sized portions which can be processed in parallel or almost, like we did before in data parallel for instance. Now when the second GPU is busy processing micro-batch 1, the first GPU can already start processing micro-batch 2. Here is a schedule using 8 micro-batches:
|
913 |
|
914 |
+

|
915 |
|
916 |
> Note: before the numbers indicated the layers but in all pipeline parallel plots from now including this one it indicates a microbatch. You can think of each square here to contain several layers as seen in the previous figure.
|
917 |
>
|
|
|
970 |
|
971 |
This schedule is called **one-forward-one-backward** **(1F1B)** as the middle/steady state involves alternatively performing one forward and one backward pass. The general idea is to start performing the backward pass as soon as possible. The schedule looks like this:
|
972 |
|
973 |
+

|
974 |
|
975 |
The bubble still has the same size so our training efficiency is not significantly improved. However we only need to store activations for $p$ micro-batches instead of $m$ which quite reduce the activation memory explosion we had in the AFAB schedule. As a consequence we can add more microbatches which then will actually reduce the bubble.
|
976 |
|
977 |
+

|
978 |
+
|
979 |
+
$$
|
980 |
+
|
981 |
+
t_{comm}^{PP}/t_{compute} = \frac{1}{32h \cdot \text{num\_layers\_in\_next\_pp}} \cdot \frac{\text{peak\_flops}}{\text{peak\_bw}} \leq 1
|
982 |
+
$$
|
983 |
|
984 |
A major complexity of this setup, visible on the above graph is how forward and backward passes are not cleanly consecutive anymore but performed in parallel across devices. This means we will have to schedule the switch from forward to backward passes independently on each device instead of in a simple and common central training loop as usual.
|
985 |
|
|
|
1068 |
|
1069 |
This can be seen in general as a kind of “looping pipeline” where a micro-batch will move in circles from one GPU to the next as it goes through the forward pass through the model.
|
1070 |
|
1071 |
+

|
1072 |
|
1073 |
As a consequence we see additional communications happening as the model goes several times through each GPU for the same computation that previously just took one pass. However, each forward and backward pass is divided by a factor of $v$, where $v$ is the number of stages or model chunks per GPUs as we are able to better interleave forward and backward passes.
|
1074 |
|
|
|
1079 |
|
1080 |
So we can now decrease the bubble by adding microbatches and interleaved stages, but note that quantitatively, the amount of communication also increases by 𝑣 so it’s a trade off. In the following plot you can see several configurations for a PP setup with $p=8$, where the special case of $m=1, v=1$ corresponds to naive pipeline parallelism and the configurations with $v=1$ are AFAB or 1F1B setups and $v \neq 1$ are interleaved configurations.
|
1081 |
|
1082 |
+

|
1083 |
|
1084 |
Scheduling also becomes more complex here as we need to decide on a GPU whether we are prioritizing at a given moment earlier micro-batches meaning that we close the forward and backward loops as fast as possible (so called “depth-first”, i.e. prioritizing getting batches out of the model as fast as possible) or we prioritize to first complete the forward passes of all microbatches in the queue before going over to backward passes (so called “breadth-first” i.e. prioritizing filling in the pipeline as much as possible). This is explained in details in [https://arxiv.org/abs/2211.05953](https://arxiv.org/pdf/2211.05953).
|
1085 |
|
1086 |
You now have all the elements to understand the pipeline parallelism approach in Llama 3.1 which is using a one-forward-one-backward setup with interleaved stages and a priority setting tuneable between depth-first and bread-first.
|
1087 |
|
1088 |
+

|
1089 |
|
1090 |
However, we haven’t reached the end of possible pipeline schedules and recently some methods have been proposed to reduce the bubble to virtually zero! Peaked your curiosity? Let’s have a look!
|
1091 |
|
|
|
1095 |
|
1096 |
Let’s very quickly see how this can work by detailing briefly the [ZeroBubble](https://arxiv.org/abs/2401.10241) work which is a precursor to DualPipe. The base observation of ZeroBubble is that a backward through a matrix multiplication involve actually two separated operations: backward for the inputs (B) and the backward for the weights (W):
|
1097 |
|
1098 |
+

|
1099 |
|
1100 |
|
1101 |
|
1102 |
+

|
1103 |
|
1104 |
While the output of B, the backward pass for the input, is necessary for performing the backward pass of the lower layers, the backward pass of the weights, W, is not necessary for the rest of the backward pass and generally only need to be performed before the optimiser step. This means W can be flexibly scheduled anywhere after the corresponding B of the same stage. This allows for strategic placement of W to fill the pipeline bubbles. The ZB-H2 schedule on the top right is an example of (theoretical) schedule with zero bubble taking advantage for this fine-grained decomposition.
|
1105 |
|
1106 |
DeepSeek’s DualPipe propose an extension of this decomposed approach to the case of two stream propagating from both sides of the PP ranks and being interleaved to minimize even further idle time in the GPUs are displayed in the following scheduling graph
|
1107 |
|
1108 |
+

|
1109 |
|
1110 |
The ZeroBubble and DualPipe schedules are a bit too complex for us to give here code snippets but you should start to have a general idea of the concepts involved. In practice, optimizing these schedules requires careful measurements of the time for each operations followed by a scheduling algorithm able to find the most optimal allocation of time given the constrains. See for instance in the [ZeroBubble](https://arxiv.org/abs/2401.10241) paper for a discussion of the heuristics and algorithms to perform such a scheduling.
|
1111 |
|
|
|
1117 |
|
1118 |
So whereas Context parallelism
|
1119 |
|
1120 |
+
](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2053.png)
|
1121 |
|
1122 |
[https://arxiv.org/pdf/2407.06204](https://arxiv.org/pdf/2407.06204)
|
1123 |
|
|
|
1157 |
|
1158 |
# How to Find the Best Training Configuration
|
1159 |
|
1160 |
+

|
1161 |
|
1162 |
We’ve now covered all the parallelism techniques that are actually used to distribute and training larger models. There remain a general question: which ones should we choose and which ones are best combined? We touched a little bit on this at the end of the last section but in this section we will walk through the decision process step by step.
|
1163 |
|
|
|
1179 |
|
1180 |
Let’s try synthesize the decision process into a relatively simple tree structure:
|
1181 |
|
1182 |
+

|
1183 |
|
1184 |
To explain briefly, data parallelism is the most efficient method, and you should always prioritize it when memory is not a concern. If communication is not a concern and you can keep the BS/GPU at a big enough value to make good use of the GPU MatMul, ZeRO is an easy method to remove memory bottlenecks and stay close to a simple DP implementation. However, on larger clusters you’ll probably be able to make efficient use for more 4D parallelism. In this case, starting with tensor parallelism is the most direct way to reduce memory usage and is generally faster than pipeline parallelism within a single node(8 GPUs). However, in scenarios with long contexts, the primary memory usage will tend to shifts from model weights, gradients, and optimizer states to activation values. In such cases, context parallelism becomes more beneficial than pipeline parallelism. Note that this is not an exact recipe and you should think of this more as a starting point of hyperparameters to run your own benchmarks. For instance sometimes TP mixed with PP can be more efficient, even if TP<8 and ZeRO-1/2 can make sense to mix in with 4D parallelism as well.
|
1185 |
|
|
|
1203 |
|
1204 |
On the compute side, GPUs consist of an array of compute units called **Streaming Multiprocessors** (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see [https://resources.nvidia.com/en-us-tensor-core](https://resources.nvidia.com/en-us-tensor-core) for details), each capable of handling multiple threads simultaneously.
|
1205 |
|
1206 |
+
.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2056.png)
|
1207 |
|
1208 |
Original figure from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing).
|
1209 |
|
1210 |
The memory side is also highly hierarchical with several layers of cache and memory: **Registers** are the smallest units and are private to the threads during executions, **Shared Memory** and **L1 cache are** shared between the threads running on a single SM, higher up is the **L2 cache** shared by all SMs, finally there is the **Global Memory** which is the largest memory on the GPU (the advertised 80 GB for a H100 for instance) but also the slowest to access and query.
|
1211 |
|
1212 |
+
](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2057.png)
|
1213 |
|
1214 |
Original figure from [https://www.youtube.com/watch?v=ZQKMZIP3Fzg](https://www.youtube.com/watch?v=ZQKMZIP3Fzg)
|
1215 |
|
|
|
1219 |
|
1220 |
To run the kernel, you will also need a specific code part (called **host code**) which is executed on the **CPU**/host and will take care of preparing data allocations and loading data and code.
|
1221 |
|
1222 |
+

|
1223 |
|
1224 |
Figure 5: Host code for a CUDA kernel for adding two vectors from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing)
|
1225 |
|
1226 |
+

|
1227 |
|
1228 |
Figure 6: Device code containing the definition of the vector addition kernel from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing)
|
1229 |
|
|
|
1262 |
|
1263 |
The distinction between the compiled and non-compiled versions is striking, especially given that we only added a single decorator. This remarkable difference is illustrated in the graph below (N is the number of columns) :
|
1264 |
|
1265 |
+

|
1266 |
|
1267 |
However, if this performance increase is insufficient, you can consider implementing Triton kernels. As a starting point, you can take a look at the triton kernel generated by `@torch.compile` . To do so, you simply need to set the environment variable `TORCH_LOGS` to “output_code” :
|
1268 |
|
|
|
1321 |
|
1322 |
When we benchmark the generated kernel using `triton.testing.Benchmark` we have the following performance :
|
1323 |
|
1324 |
+

|
1325 |
|
1326 |
This standalone kernel demonstrates superior performance with smaller sizes compared to `@torch.compile` but this is likely here just an artifact from the compilation time of torch. compile. In any case, instead of starting from scratch, we can focus on optimizing this generated kernel, saving us time in the process.
|
1327 |
|
|
|
1361 |
|
1362 |
Here’s an excellent visualization of the kernel from this fantastic [blogpost](https://siboehm.com/articles/22/CUDA-MMM) :
|
1363 |
|
1364 |
+

|
1365 |
|
1366 |
However, when profiling this kernel with a tool like `ncu`, we can see issues, including low memory throughput and uncoalesced memory accesses.
|
1367 |
|
1368 |
+

|
1369 |
|
1370 |
+

|
1371 |
|
1372 |
The reason for this is that in this kernel, two threads in the same block with Thread IDs `(0, 0)` and `(1, 0)` (which will end up in the same warp) will both load from the same column of matrix `B` but different rows of matrix `A`. Since matrix elements are stored in row-major order (meaning each row's elements are in consecutive memory addresses, as shown in the figure below), in the first iteration with `i = 0`, thread `(0, 0)` will load $A_{0,0}$, and thread `(1, 0)` will load $A_{1,0}$. These elements are not stored close to each other in memory, and this misalignment repeats across all iterations along the shared dimension, preventing memory accesses from being coalesced.
|
1373 |
|
1374 |
+

|
1375 |
|
1376 |
To improve our kernel we can change the way the coordinates x and y are calculated like the following :
|
1377 |
|
|
|
1392 |
|
1393 |
When we profile our new kernel, we notice that the warning about uncoalesced memory accesses has disappeared, and **the GPU's memory throughput has increased by approximately 10 times**.
|
1394 |
|
1395 |
+

|
1396 |
|
1397 |
We also notice that the execution time of the kernel **decreases by 10x** !
|
1398 |
|
|
|
1406 |
|
1407 |
In the tiling approach, each iteration involves all threads within a block cooperatively loading two tiles—one from matrix A and another from matrix B —into shared memory. Specifically, threads load a tile of matrix A (of size `BLOCK_SIZE_M` by `BLOCK_SIZE_K`) and a tile of matrix B (of size `BLOCK_SIZE_K` by `BLOCK_SIZE_N`). Once the tiles are in shared memory, the threads perform matrix multiplication on these tiles, enabling efficient computation since all necessary data is quickly accessible. The results of the tile multiplication are stored in an accumulation matrix that holds intermediate results. After each iteration, the results from the current tile multiplication are added to this accumulation matrix, continuing until all tiles from both matrices have been processed.
|
1408 |
|
1409 |
+
](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2067.png)
|
1410 |
|
1411 |
From [https://cnugteren.github.io/tutorial/pages/page4.html](https://cnugteren.github.io/tutorial/pages/page4.html)
|
1412 |
|
|
|
1450 |
|
1451 |
The tiling technique has significantly improved the performance of our kernel. However, when analyzing the warp states which quantify how many cycles were spent in each state, we observe the following:
|
1452 |
|
1453 |
+

|
1454 |
|
1455 |
The meaning of the states can be found in the [Profiling Guide](https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference), specifically in the **Warp Stall Reasons** section. There we can read that :
|
1456 |
|
|
|
1475 |
|
1476 |
A basic implementation of the attention mechanism involve a lot of transfer between memory and workers. It requires materializing the S and P matrices in HBM which means that the results need to be sent to HBM and then back to SRAM for the next computations:
|
1477 |
|
1478 |
+

|
1479 |
|
1480 |
Since bandwidth is much lower in HBM this introduces a severe bottleneck in the attention computation. Can we do better? Tri Dao says yes!
|
1481 |
|
1482 |
The key element is to compute the S matrices in small pieces which can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large S matrix all together in favor of keeping only the necessary statistics for computing the normalization factor of the softmax. So we can compute part of $O$ directly in one computation in SRAM rather than moving intermediate results back and forth. In this case, not even do we make use of the shared memory but we also release the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length), the attention matrix.
|
1483 |
|
1484 |
+
)](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2070.png)
|
1485 |
|
1486 |
From the FLASH-ATTENTION paper ([https://arxiv.org/pdf/2205.14135](https://arxiv.org/pdf/2205.14135))
|
1487 |
|
|
|
1539 |
|
1540 |
Reducing the total number of bits comes at a price (no free lunch here either), but we have some control over how to pay. Either we can sacrifice more bits on the mantissa or exponent. For this reason there exist also two float8 formats, named according to exponent and mantissa, to flexibly choose the most appropriate format. We can look at the possible range of numbers for each format:
|
1541 |
|
1542 |
+

|
1543 |
|
1544 |
We can see that float32 spans 80 orders of magnitude and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further where e5e2 can maintain float16 range and e4m3 has an even smaller ranger.
|
1545 |
|
1546 |
How come some format are able to maintain the range and other not? Let’s investigate the resolution by plotting 10,000 points between 1 and 2. Each point will be rounded to the nearest representable number in each format:
|
1547 |
|
1548 |
+

|
1549 |
|
1550 |
We can see here that bfloat16 maintained the range of float32 over float16 but did this with the cost of sacrificing more precision. In case of float8 the situation is even more dire as e4m3 can represent 7 and e5m2 only 3 number on the interval 1-2.
|
1551 |
|
|
|
1575 |
|
1576 |
Recent research - including [FP8-LM](https://arxiv.org/abs/2310.18313), [torchao](https://github.com/pytorch/ao/tree/main/torchao/float8#torchaofloat8), and [DeepSeek-V3](https://arxiv.org/abs/2412.19437) - has demonstrated the potential of FP8 training for large-scale models. Still, FP8 pretraining introduces a significant challenge: stability. At lower precision, numerical instability often leads to loss divergence, making it difficult to match the accuracy of higher-precision training.
|
1577 |
|
1578 |
+

|
1579 |
|
1580 |
As [[Wortsman et al.]](https://arxiv.org/abs/2309.14322) observed, instability increases as learning rates rise for a fixed model size, making FP8 pretraining particularly tricky.
|
1581 |
|
|
|
1713 |
|
1714 |
The general setup is that we have a number of independent nodes which could be CPU cores, GPUs, or compute nodes. Each performs some computation and then we want to communicate the result or parts of it to the other nodes for the next computation step (t+1).
|
1715 |
|
1716 |
+

|
1717 |
|
1718 |
Maybe we need to send the result from one node to all other nodes, or we need to sum all the intermediate results from each node to report the overall result. Usually, there is one node with an elevated status that plays a central role, here denoted with `root` that is the target or source of some operations. Let’s start with one of the simplest primitives: a broadcast operation.
|
1719 |
|
|
|
1721 |
|
1722 |
A very common pattern is that you have some data on Node 1 and you want to share it with all the other nodes so they can do some computation with the data. The broadcast operation does just that:
|
1723 |
|
1724 |
+

|
1725 |
|
1726 |
Collective operations are natively provided by PyTorch so we can easily write a small example that demonstrates how broadcasting works. We first need to initialize a process group with `dist.initi_process_group` which sets up the communication backend (we’ll talk about NCCL later), it determines how many workers (aka nodes) exists and assigns a rank to each one (which we can get with `dist.get_rank`). Finally, it establishes a connection between the workers.
|
1727 |
|
|
|
1766 |
|
1767 |
Reduce patterns are among the most fundamental patterns in distributed data processing. The idea is that you want to combine the data present on each node through a function `f()` which can be for instance summation or averaging. In the Reduce paradigm the result is sent to the root node only, whereas in the AllReduce case the result is broadcasted to all nodes:
|
1768 |
|
1769 |
+

|
1770 |
|
1771 |
Of course no magic “free flying” node that can perform this operation and generally each node does a partial computation in a ring or tree structure of the nodes. Here is a simple example: let’s say we need to compute a sum of numbers on each nodes and our nodes are connected in a ring pattern. The first node sends its number to a neighbour which adds its number to the received number before forwarding it to the next neighbour. At the end of a round along the ring of nodes, the first node will receive the total sum.
|
1772 |
|
|
|
1861 |
|
1862 |
Gather and AllGather are quite similar to the Broadcast in that they allow distributing data among node without modification. The main difference to Broadcast is that there is not one value we need to share from one node to all other nodes but each node has an individual chunk of data that we want to either gather all data on one node (in case of Gather) or gather all data on all nodes (in the case of AllGather). A picture being worth 1000 words, let’s take a look:
|
1863 |
|
1864 |
+

|
1865 |
|
1866 |
Note that the dashed lines indicate that some data actually doesn’t move at all (since it’s already present on the node).
|
1867 |
|
|
|
1935 |
|
1936 |
The ReduceScatter pattern is slightly more complex: imagine you apply an operation like in the Reduce case but instead of moving the result to just one node we also distribute it evenly to all nodes:
|
1937 |
|
1938 |
+

|
1939 |
|
1940 |
The Scatter operation is written in code as the opposite of the Gather: instead of preparing a list of tensors as target we prepare the source data as a list of tensors we want to distribute. We also need to specify the `src`:
|
1941 |
|
|
|
2010 |
|
2011 |
The Barrier is a simple operation to synchronize all nodes. A barrier is not lifted until all nodes have reached it. Then only are they allowed to continue with further computations:
|
2012 |
|
2013 |
+

|
2014 |
|
2015 |
We can easily simulate delayed nodes by setting up a different sleep time on each node and see how long it takes for all of them to pass the barrier:
|
2016 |
|
|
|
2123 |
|
2124 |
This would print aggregated profiling results sorted by the total CUDA time, and the output would be:
|
2125 |
|
2126 |
+

|
2127 |
|
2128 |
You can also try to inspect the trace as we previously mentioned on `chrome://tracing/`
|
2129 |
|
|
|
2132 |
|
2133 |
After zooming in, you can observe the flow of operations when calling `layer_norm` in this trace:
|
2134 |
|
2135 |
+

|
2136 |
|
2137 |
The sequence begins in the CPU (the upper section) with `aten::layer_norm`, progressing to `aten::native_layer_norm`, and then transitioning to `cudaLaunchKernel`. From there, we move on to the GPU, where the `vectorized_layer_norm_kernel` kernel is called.
|
2138 |
|
|
|
2153 |
|
2154 |
and open the file `output.ncu-rep` with Nsight Compute, you will have a view that looks like this :
|
2155 |
|
2156 |
+

|
2157 |
|
2158 |
With clear warnings about compute and memory utilization, and how to make the kernel better in balancing compute and memory and achieve maximal occupancy.
|
2159 |
|
|
|
2235 |
|
2236 |
The chain rule applies here since the loss (L) depends directly on the output (Y). This equation is telling us that to get the gradient of the loss with respect to our input (dL/dX), we multiply the gradient of the loss with respect to the output (dL/dY) by our weight matrix (W).
|
2237 |
|
2238 |
+

|
2239 |
|
2240 |
Likewise, we can use chain rule to compute the gradient w.r.t to the weight:
|
2241 |
|
|
|
2243 |
\frac{dL}{dW} = \frac{dL}{dY} \frac{dY}{dW} = \frac{dL}{dY} X
|
2244 |
$$
|
2245 |
|
2246 |
+

|
2247 |
|
2248 |
Here is a snippet of code to clarify all the concepts above:
|
2249 |
|
|
|
2322 |
example_column_row_linear()
|
2323 |
```
|
2324 |
|
2325 |
+

|
2326 |
|
2327 |
**TODO** add these illustrations somewhere? I found them helpful:
|
2328 |
|
2329 |
+

|
2330 |
|
2331 |
+

|
2332 |
|
2333 |
## A3: ZeRO-R
|
2334 |
|
|
|
2473 |
|
2474 |
### Interconnect
|
2475 |
|
2476 |
+

|
2477 |
|
2478 |
## How to profile your code
|
2479 |
|
|
|
2509 |
|
2510 |
After running this code, you will find `*.trace.json` files under the `profiler_out_dir`. To visualize the results, the easiest way is to open Google Chrome, go to `chrome://tracing/`, and drag the file into it. This will allow you to view the profiling results. To get more details, we invite you to check out the amazing [**tutorial](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html)** created by PyTorch.
|
2511 |
|
2512 |
+

|
2513 |
|
2514 |
## Formulas for compute / comms the balanhe balance
|
2515 |
|
|
|
2622 |
|
2623 |
```
|
2624 |
|
2625 |
+

|
2626 |
|
2627 |
## Integrating Context Parallelism with TP/SP
|
2628 |
|
|
|
2635 |
3. **Replace standard attention with ring attention:** During the forward pass, each TP rank relies on the ring attention to compute the correct attention results during both the forward and backward passes. So all CP ranks within TP=0 for example need to all-gather the full KV sequence and calculate attention, but we store only the KV of a sequence chunk to reduce memory activations by CP.
|
2636 |
|
2637 |

|
2639 |
|
2640 |
TP=0 has GPU0 and GPU2 whereas CP=0 has GPU0 and GPU1
|
2641 |
TP/SP shards the Q/K/V heads across TP ranks (in this example GPU0 and GPU2 get QKV_green, and GPU2 and GPU3 get QKV_blue), since each head can operate independently from others, we can apply ring attention within each TP rank
|
|
|
2648 |
|
2649 |
However, through extensive experimentation, we identified two effective training recipes that allowed us to **fully pretrain a 1B LLaMA model in FP8**, covering both the forward and backward passes, while using an FP8 optimizer. More importantly, our approach successfully matched LLaMA-2’s pretraining learning rate. The result?
|
2650 |
|
2651 |
+

|
2652 |
|
2653 |
A loss curve that perfectly matches mixed-precision bfloat16 (bfloat16 with FP32 master weights as the baseline). We successfully tested this to train a 1B LLaMA up to 100B tokens and a 7B LLaMA up to 25B tokens.
|
2654 |
|
|
|
2686 |
|
2687 |
**Non-overlapping:** If we don't overlap the communication and computation, each computation (represented by the purple block) can only begin after the communication (green block) is complete and total time is the sum of communication and computation.
|
2688 |
|
2689 |
+

|
2690 |
|
2691 |
**Overlapping:** However, if we manage to launch communication and computation in parallel, we eliminate the waiting time! Now we can see that the computation (green block) is launched immediately, one after the other. In this case the total time is *only* the sum of computations.
|
2692 |
|
2693 |
+

|
2694 |
|
2695 |
Context parallelism has helped us going past the intra-node interconnect bottleneck, which blocked us from scaling TP across nodes. However, as you probably noted, it only helps reducing the memory constraints if the activation memory dominates the memory budget due to long sequences. What if we are not working on super long sequences and the model weights alone are too big for a single node?
|
2696 |
|
2697 |
Well it turns out we have an other –quite different– option called pipeline parallelism (PP) which the time has come to explore now.
|
2698 |
|
2699 |
+
[TODO: comment from Nouamane on comms overlapping with DP 512]
|
2700 |
+
|
2701 |
+
## seq parallel profiling
|
2702 |
+
|
2703 |
+
TODO: remove, Profiling:
|
2704 |
+
|
2705 |
+
- TP
|
2706 |
+
|
2707 |
+

|
2708 |
+
|
2709 |
+
- Seq Parall
|
2710 |
+
|
2711 |
+

|
2712 |
+
|
2713 |
+
Allreduce takes almost double the duration (900us) of reducescatter and allgather (500us)
|
dist/bibliography.bib
CHANGED
@@ -403,4 +403,13 @@ url = {https://github.com/meta-llama/llama3/blob/main/MODEL_CARD.md}
|
|
403 |
archivePrefix={arXiv},
|
404 |
primaryClass={cs.LG},
|
405 |
url={https://arxiv.org/abs/2205.05198},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
406 |
}
|
|
|
403 |
archivePrefix={arXiv},
|
404 |
primaryClass={cs.LG},
|
405 |
url={https://arxiv.org/abs/2205.05198},
|
406 |
+
}
|
407 |
+
@misc{wang2024domino,
|
408 |
+
title={Domino: Eliminating Communication in LLM Training via Generic Tensor Slicing and Overlapping},
|
409 |
+
author={Guanhua Wang and Chengming Zhang and Zheyu Shen and Ang Li and Olatunji Ruwase},
|
410 |
+
year={2024},
|
411 |
+
eprint={2409.15241},
|
412 |
+
archivePrefix={arXiv},
|
413 |
+
primaryClass={cs.DC},
|
414 |
+
url={https://arxiv.org/abs/2409.15241},
|
415 |
}
|
dist/index.html
CHANGED
@@ -468,16 +468,125 @@
|
|
468 |
|
469 |
<h2>Data Parallelism</h2>
|
470 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
471 |
<h4><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h4>
|
472 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
473 |
<h4><strong>Second optimization:</strong> Bucketing gradients</h4>
|
474 |
|
475 |
-
<
|
476 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
477 |
<h3>Revisit global batch size</h3>
|
478 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
479 |
<h3>Our journey up to now</h3>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
480 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
481 |
<h3>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h3>
|
482 |
|
483 |
<h4>Memory usage revisited</h4>
|
@@ -488,12 +597,253 @@
|
|
488 |
|
489 |
<h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
|
490 |
|
|
|
491 |
<h2>Tensor Parallelism</h2>
|
492 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
493 |
<h3>Tensor Parallelism in a Transformer Block</h3>
|
494 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
495 |
<h3>Sequence Parallelism</h3>
|
496 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
497 |
<h2>Context Parallelism</h2>
|
498 |
|
499 |
<h3>Introducing Context Parallelism</h3>
|
|
|
468 |
|
469 |
<h2>Data Parallelism</h2>
|
470 |
|
471 |
+
<p>The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replica's “model instances”) and run forward and backward passes on different micro batches of data in parallel for each GPU, hence the name Data Parallelism. </p>
|
472 |
+
|
473 |
+
<p>Using a different micro batch for each GPU means we’ll have different gradients in each GPU, so to keep the model instances in sync across different GPUs, the gradients from the model instances are averaged using an operation called “all-reduce”, which happens during the backward pass, before the optimizer step.</p>
|
474 |
+
|
475 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
476 |
+
|
477 |
+
<p>This involves our first “distributed communication” primitive: <em><strong>all-reduce</em></strong> which handles the synchronization and communication between GPU instances and nodes.</p>
|
478 |
+
|
479 |
+
<aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in the Appendix [TODO Link].</aside>
|
480 |
+
|
481 |
+
<p>TODO: embed naive DP: <a href="https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60">https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60</a></p>
|
482 |
+
|
483 |
+
<p>TODO: embed bucket DP: <a href="https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171">https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171</a></p>
|
484 |
+
|
485 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
486 |
+
|
487 |
+
<p>A naive DP implementation would just wait for the backward pass the finish so that we have all gradients, then it triggers an all-reduce over all DP ranks, to sync these gradients. But such an sequential steps of computation followed by communication is <strong>A BIG NO!</strong> Because we don’t want our GPUs to stay idle while communication is happening.</p>
|
488 |
+
|
489 |
+
<p>Instead we should try to overlap communication and computation whenever possible so that they happen at the same time as much as possible.</p>
|
490 |
+
|
491 |
+
<p>Let’s see three optimizations that are done in practice for this! </p>
|
492 |
+
|
493 |
<h4><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h4>
|
494 |
+
|
495 |
+
<p>The main drawback of the naive DDP approach we’ve just described is that after the backward pass (<em>computation</em>), we have to wait for gradient synchronization (<em>communication</em>) before updating the parameters. Could we overlap this communication with our computation? The answer is yes!</p>
|
496 |
+
|
497 |
+
<p>As shown in the figure above, the gradients (red boxes) for a layer can be gathered and summed even before the gradients from earlier layers (red boxes to the left) have been computed. For example, as soon as the backward pass of the last layer is complete (last box on the right), those gradients can already be gathered and summed while the backward computations continue for earlier layers, moving toward the left.</p>
|
498 |
+
|
499 |
+
<p>This can be achieved in pytorch by attaching an <em>all-reduce hook function</em> to each parameter. An all-reduce operation is triggered as soon as the gradient for that parameter is ready, while gradients for other parameters are still being computed. This approach overlaps most of the all-reduce operations with gradient calculations, thereby improving efficiency. Here's a simple function to attach a hook:</p>
|
500 |
+
|
501 |
+
<d-code block language="python">
|
502 |
+
def register_backward_hook(self, hook):
|
503 |
+
"""
|
504 |
+
Registers a backward hook for all parameters of the model that
|
505 |
+
require gradients.
|
506 |
+
"""
|
507 |
+
for p in self.module.parameters():
|
508 |
+
if p.requires_grad is True:
|
509 |
+
p.register_post_accumulate_grad_hook(hook)</d-code>
|
510 |
+
|
511 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png"/></p>
|
512 |
+
|
513 |
+
<p>Overlapping computation and communication reduces the time spent waiting for gradient synchronization across the entire model. Gradient synchronization can occur (at least partially) in parallel with backward pass, significantly speeding up data parallelism. </p>
|
514 |
+
|
515 |
+
<p>This is our first example of “<em>overlapping computation and communication</em>” which we will discuss several times in this blog post and is an essential technique to maximal scaling efficiency. Let's have a look how we can further improve the DP efficiency!</p>
|
516 |
+
|
517 |
+
|
518 |
<h4><strong>Second optimization:</strong> Bucketing gradients</h4>
|
519 |
|
520 |
+
<p>We can even go further with optimizing DP. For a given number of parameters to synchronize, GPU operations like collective communications are often more efficient when performing few calls on large tensors rather than many calls on smaller tensors. Therefore, instead of performing independent all-reduce for each gradient, we can group gradients into buckets and launch a single all-reduce for all the gradients within the same bucket. Think of it like packing items into boxes before shipping—it's more efficient to send a few big boxes than many small ones. By performing a single all-reduce operation for each bucket, we can significantly reduce communication overhead and speed up the communication operation.</p>
|
521 |
+
|
522 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
523 |
+
|
524 |
+
<h4><strong>Third optimization: </strong>Interplay with gradient accumulation</h4>
|
525 |
+
|
526 |
+
<p>As we’ve seen before, gradient accumulation works by performing multiple forward and backward passes before updating the parameters with <code>optimizer.step()</code>. When combining gradient accumulation with data parallelism, we should be careful when we want to synchronize gradients.</p>
|
527 |
+
|
528 |
+
<p>In a naive version, an all-reduce operation is automatically triggered after each backward pass during the accumulation, which is sub-optimal as a single reduce after the final step would have the same effect while reducing overhead.</p>
|
529 |
+
|
530 |
+
<p>In PyTorch, this is typically solved by adding a <a href="https://github.com/pytorch/pytorch/blob/5ea67778619c31b13644914deef709199052ee55/torch/nn/parallel/distributed.py#L1408-L1435"><code>model.no_sync()</code></a> decorator, which disables gradient synchronization, on the backward passes which don’t need reduction.</p>
|
531 |
+
|
532 |
+
<aside>When performing communication operations, tensors must be contiguous in memory. To avoid redundant memory copies during communication, ensure that tensors that will be communicated are stored contiguously in memory. Sometimes we need to allocate additional continuous buffers of the size of activations or model parameters specifically for communication, which contributes to the peak memory usage during training.</aside>
|
533 |
+
|
534 |
<h3>Revisit global batch size</h3>
|
535 |
+
<p>Let’s update our batch size equation with our newly learned Data Parallelism and Gradient Accumulation parameters:</p>
|
536 |
+
|
537 |
+
<d-math block>
|
538 |
+
bs = gbs = mbs \times grad\_acc
|
539 |
+
</d-math>
|
540 |
+
<p>Where <d-math>grad\_acc</d-math> is the number of gradient accumulation steps and DP is the number of parallel instances used for data parallelism.</p>
|
541 |
+
|
542 |
+
<p>Given a targeted global batch size, we can thus trade gradient accumulation steps for data-parallel processes to speed up training. In practice, people tend to maximize the number of data-parallel nodes (DP) over gradient accumulation as much as possible since it's inherently parallel, unlike the sequential nature of gradient accumulation. Gradient accumulation is then added on top of data parallelism to achieve the target global batch size when scaling data parallelism alone is not sufficient before you run out of GPUs.</p>
|
543 |
+
|
544 |
+
<aside>A good resource for further reading on Data Parallelism is <a href="https://siboehm.com/articles/22/data-parallel-training">https://siboehm.com/articles/22/data-parallel-training</a>.
|
545 |
+
</aside>
|
546 |
+
|
547 |
+
<p>Being able to distribute the training over different samples gives us a first dimension of parallelization, thus making this 1D parallelism (we’ll progressively cover 4 more dimensions).</p>
|
548 |
+
|
549 |
<h3>Our journey up to now</h3>
|
550 |
+
<p>Let’s quickly summarize what we’ve seen up to now and how to setup our first 1D parallel training with a draft recipe for an optimal data-parallel setup:</p>
|
551 |
+
|
552 |
+
<ol>
|
553 |
+
<li>We should first determine the best (global) batch size in tokens (<code>GBST</code>) either by consulting literature or running experiments measuring model convergence.</li>
|
554 |
+
<li>We then select a sequence length for training, again by either consulting literature or running experiments. Generally, 2-8k tokens work reliably well for the evaluations we have today (we won’t dive in training recipes here but teams usually increase the sequence at the end of the training, adding some longer-context data samples in the mix to reach the longer context size of today).</li>
|
555 |
+
<li>We now know the batch size (gbs). We can find the maximum local batch size (mbs) on a single GPU by increasing the local batch size until we run out of memory.</li>
|
556 |
+
<li>Finally, we determine the number of available GPUs for our target DP. The ratio of GBS to DP gives us the remaining number of gradient accumulation steps needed for the desired GBS. </li>
|
557 |
+
</ol>
|
558 |
+
|
559 |
+
<aside>For instance DeepSeek and Llama models are trained with a 4k tokens sequence length during the main pretraining phase.</aside>
|
560 |
+
|
561 |
+
<aside>The reason 2-8k work well for pretraining is that documents that are longer are very rare on the web. See this <a href="https://www.harmdevries.com/post/context-length/">Harm’s blogpost</a> for a detailed analysis.
|
562 |
+
</aside>
|
563 |
+
|
564 |
+
<p>If the gradient accumulation ratio is lower than one, i.e. we have too many GPUs a.k.a GPU-rich 🤑 (!), we can either choose to not use all our GPUs, explore a larger global batch size or test if a lower MBS will speed up training. In the latter case we’ll end up prioritizing throughput over individual GPU compute efficiency, using a smaller MBS than possible in order to speed up training.</p>
|
565 |
+
|
566 |
+
<p>Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k. This means our batch size will be 1024 samples (we pick powers of two). We observe that a single GPU can only fit MBS=2 in memory and we have 128 GPUs available for training. This means with 4 gradient accumulation steps we’ll achieve our goal of 1024 samples or 4M tokens per training step. Now what if we suddenly have 512 GPUs available? We can achieve the same GBS and thus identical training by keeping MBS=2 and setting gradient accumulation steps to 1 and achieve faster training!</p>
|
567 |
+
|
568 |
+
<aside>Bear in mind that at the 512GPUs scale, depending on the network used, the communication operations will start to be bound by <em>ring latency</em> (time required for a signal to propagate once around the ring) **which means we can no longer fully overlap the DP communications. This will decrease our compute efficiency and hit our throughput. In this case we should start exploring other dimensions to parallelize on.
|
569 |
+
</aside>
|
570 |
+
|
571 |
+
<p>While data parallelism cleverly overlaps the all-reduce gradient synchronization with backward computation to save time, this benefit starts to break down at large scales. As we add more and more GPUs (hundreds or thousands), the overhead of coordinating between them grows significantly. The end result? We get less and less efficient returns from each additional GPU we add to the system:</p>
|
572 |
+
|
573 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png"/></p>
|
574 |
+
|
575 |
+
<p>As expected, we can also see that the memory usage per GPU is not affected by adding more DP ranks for training.</p>
|
576 |
+
|
577 |
+
<p><strong>We’ve explored data parallelism, our first (simple) strategy to scale training across more GPUs. It works like gradient accumulation but parallelizes the forward and backward passes on micro batches, thus increasing throughput!</strong></p>
|
578 |
|
579 |
+
<p>The keen reader has already probably noted however that this assumes that we can fit at least one input sample forward pass (mbs<em>=1)</em> into our GPU memory. This is not always the case! As we can see, larger models don’t fit into a single GPU, even with activation recomputation activated: </p>
|
580 |
+
|
581 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
582 |
+
|
583 |
+
<aside>Tip: you can quickly eyeball the minimal memory required for your model’s parameters by multiplying by 2 e.g. 70B → 140GB (=133GiB)</aside>
|
584 |
+
|
585 |
+
<p>Do we have other options for these larger models? We do have some solutions thankfully. They will involve either move some of these tensors to the CPU or split the weights/gradients/optimizer-states tensors across GPUs devices!</p>
|
586 |
+
|
587 |
+
<p>There are two main approaches to splitting: parallelism (tensor, context, or pipeline parallelism) and sharing (DeepSpeed Zero or PyTorch FSDP). Both approaches are somewhat orthogonal and can actually be combined! The sharing paradigm is closely related to DP so we’ll have a look at it first by investigating the ZeRO method!</p>
|
588 |
+
|
589 |
+
|
590 |
<h3>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h3>
|
591 |
|
592 |
<h4>Memory usage revisited</h4>
|
|
|
597 |
|
598 |
<h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
|
599 |
|
600 |
+
|
601 |
<h2>Tensor Parallelism</h2>
|
602 |
+
|
603 |
+
<p>So we have sharded the model’s parameters, gradients and optimizers states with ZeRO but we hit a limit once activation memory overtakes our memory budget. Welcome Tensor Parallelism (TP), a method which shards weights, gradients, and optimizers states as well as activations and without the need to gather them all prior to the computation. Seems like a dream! Let’s first have a look at how Tensor Parallel works with simple matrix multiplications.</p>
|
604 |
+
|
605 |
+
<p>Tensor Parallelism leverages the mathematical properties of matrix multiplication <d-math>A \times B</d-math>. To understand how it works, let's examine two fundamental equations that make this parallelization possible:</p>
|
606 |
+
|
607 |
+
<d-math block>
|
608 |
+
\begin{aligned}
|
609 |
+
&\text{1.} \quad A\cdot B = A \cdot \begin{bmatrix} B_1 & B_2 & \cdots \end{bmatrix} = \begin{bmatrix} AB_1 & AB_2 & \cdots \end{bmatrix} \\
|
610 |
+
&\text{2.} \quad A\cdot B =\begin{bmatrix} A_1 & A_2 & \cdots \end{bmatrix} \begin{bmatrix} B_1 \\ B_2 \\ \vdots \end{bmatrix} = \sum_{i=1}^n A_i B_i
|
611 |
+
\end{aligned}
|
612 |
+
</d-math>
|
613 |
+
|
614 |
+
<p>This means that we can compute matrix product by either 1) multiplying each column of <d-math>B</d-math> individually or 2) multiplying each row individually and combining the results. In a neural network, the matrix multiplication is more often represented in the following format: <d-math>X \times W</d-math>, where:</p>
|
615 |
+
|
616 |
+
<ul>
|
617 |
+
<li>X represents the input or activation values</li>
|
618 |
+
<li>W represents the weight of the <code>nn.Linear</code></li>
|
619 |
+
</ul>
|
620 |
+
|
621 |
+
<p>In practice a small example of the operation looks like this:</p>
|
622 |
+
|
623 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
624 |
+
|
625 |
+
<p>Let’s see how we can parallelise this operation! In tensor parallelism, tensors will be split into N shards along a particular dimension and distributed across N GPUs. Matrices can be split either on the column part or row part leading to row and column parallelism. One thing we’ll see in the following is that choosing row or column sharding will require different communications primitives.</p>
|
626 |
+
|
627 |
+
<p>Our first option is to use column-wise sharding (also called <strong><em>column-linear</em></strong>): We'll copy the complete input matrices to each worker, requiring an operation called <strong><em>broadcast</em></strong>, and split the weight matrix into columns. The inputs are then multiplied with the partial weight matrices, and the results are finally combined using an <strong><em>all-gather</em></strong> operation.</p>
|
628 |
+
|
629 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
630 |
+
|
631 |
+
<p>The second option is called row-wise sharding (also called <strong><em>row-linear</em></strong>): As the attentive reader might guess, row-linear means that we split the weight matrix into chunks of rows. However, this also requires us to split the inputs, which needs a <strong><em>scatter</em></strong> operation rather than a broadcast as used in column-linear sharding. The results on each worker are already in the right shape but need to be summed for the final result, thus requiring an all-reduce operation in this scenario.</p>
|
632 |
+
|
633 |
+
<p>We see here our fourth distributed primitive: <strong><em>scatter</em></strong>!</p>
|
634 |
+
|
635 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
636 |
+
|
637 |
<h3>Tensor Parallelism in a Transformer Block</h3>
|
638 |
|
639 |
+
<p>To come up with a strategy to follow, let’s move from a toy example to a real model building block. A Transformer model is made of two main building blocks : Feedforward layers (MLP) and Multi-Head Attention (MHA). We can apply tensor parallelism to both.</p>
|
640 |
+
|
641 |
+
<p>The Feedforward part can be parallelized by having a “Column linear” followed by a “Row Linear” which amounts to a broadcast to copy the input and an all-reduce in forward. Note that the broadcast isn’t needed in actual training where we can make sure inputs are already synced across TP ranks.</p>
|
642 |
+
|
643 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
644 |
+
|
645 |
+
<p>Now that we’ve found the most efficient schema for the Feedforward part of the transformer, let’s take a look at the multi-head attention block (MHA).</p>
|
646 |
+
|
647 |
+
<p>We can generally follow a similar approach where Q, K, and V matrices are split in a column-parallel fashion, and the output projection is split along the row dimension. With multi-head attention, the column-parallel approach has a very natural interpretation: each worker computes the attention for an individual or a subset of heads. The same approach works as well for <a href="https://arxiv.org/abs/1911.02150"><strong><em>multi-query</em></strong> (MQA)</a> or <a href="https://arxiv.org/abs/2305.13245"><strong><em>grouped query attention</em></strong> (GQA)</a> where key and values are shared between queries. </p>
|
648 |
+
|
649 |
+
<p>It's also worth noting that the tensor parallelism degree should not exceed the number of Q/K/V heads because we need intact heads per TP rank. And in case we’re using GQA, TP degree should be below number of K/V heads, otherwise it requires additional comms to keep them in sync. For instance, LLaMA-3 8B has 8 Key/Value heads, so the tensor parallelism degree should be less than or equal to 8, otherwise if TP=16 for example, we need to duplicate each K/V head and make sure they stay in sync.</p>
|
650 |
+
|
651 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
652 |
+
|
653 |
+
<p>Finally note that there is a tradeoff in terms of communication as we’ve added several distributed communication primitive directly in the computation path of our model. At the difference of ZeRO where we could prefetch, it can be harder to make these communication fully overlap with computations. </p>
|
654 |
+
|
655 |
+
<p><img alt="Forward pass in Tensor Parallelism" src="/assets/images/placeholder.png" /></p>
|
656 |
+
|
657 |
+
<p>Looking at the timeline of operations in tensor-parallel MLP (same applies for Attention), we can better understand the tradeoffs involved. In the forward of each decoder layer, we hit a synchronization point with the AllReduce operation that cannot be overlapped with computation. This <em>exposed communication</em> overhead is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied. </p>
|
658 |
+
|
659 |
+
<p>Tensor parallelism does help reduce activation memory for the matrix multiplications since the intermediate activations are sharded across GPUs. However, we still need to gather the full activations for operations like LayerNorm, which means we're not getting the full memory benefits we could. Additionally, it introduces significant communication requirements that heavily depend on the network infrastructure. The inability to hide this particular AllReduce behind computation means it directly adds to the critical path of forward propagation.</p>
|
660 |
+
|
661 |
+
<p><img alt="Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training." src="/assets/images/placeholder.png" /></p>
|
662 |
+
|
663 |
+
<p>Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training.</p>
|
664 |
+
|
665 |
+
<p>In practice, the communication overhead of tensor parallelism becomes particularly noticeable as we scale beyond 8 GPUs. While tensor parallelism within a single node can leverage fast NVLink interconnects, going across nodes requires slower network connections. As shown in the throughput plot above, we observe significant drops when moving from TP=8 to TP=16, and an even steeper decline from TP=16 to TP=32. This illustrates how communication costs can dominate at higher degrees of parallelism.</p>
|
666 |
+
|
667 |
+
<p>However, tensor parallelism provides important benefits for memory usage by distributing model parameters, gradients, optimizer states and activations (to some extent) across GPUs. Let's examine this effect on a 70B parameter model:</p>
|
668 |
+
|
669 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
670 |
+
|
671 |
+
<p>As we can see, increasing tensor parallelism reduces the memory needed for model parameters, gradients and optimizer states on each GPU. While tensor parallelism does help reduce activation memory in attention and feedforward layers by sharding the matrix multiplications across GPUs, we don't get the full memory benefits we could. This is because operations like layer normalization and dropout still require gathering the full activations on each GPU, partially negating the memory savings. We can do better by finding ways to parallelize these remaining operations as well.</p>
|
672 |
+
|
673 |
+
<aside>One interesting note about layer normalization in tensor parallel training - since each TP rank sees the same activations after the all-gather, the layer norm weights don't actually need an all-reduce to sync their gradients after the backward pass. They naturally stay in sync across ranks. However, for dropout operations, we must make sure to sync the random seed across TP ranks to maintain deterministic behavior.
|
674 |
+
</aside>
|
675 |
+
|
676 |
+
<p>This raises an interesting question - could we extend tensor parallelism to these remaining operations as well? Indeed, it's possible to parallelize layer norm, dropout and other operations too, which we'll explore next.</p>
|
677 |
+
|
678 |
<h3>Sequence Parallelism</h3>
|
679 |
|
680 |
+
<p>In regions where we apply tensor parallelism (TP), like attention and feedforward layers, each GPU only needs to operate on a portion of the hidden dimension since the weights are sharded. However, operations like layer norm or dropout (which is not used a lot anymore in LLM) require access to the full hidden dimension to compute correctly.</p>
|
681 |
+
|
682 |
+
<p>Rather than gathering the full hidden dimension on each GPU (which would defeat the memory benefits of TP), we can instead shard these operations along the sequence length dimension. This approach is called <strong>sequence parallelism (SP)</strong>.</p>
|
683 |
+
|
684 |
+
<aside>Note that the term Sequence Parallelism is a bit overloaded: the Sequence Parallelism in this section is tightly coupled to Tensor Parallelism and applies to dropout and layer norm operation. However, when we will move to longer sequences the attention computation will become a bottleneck, which calls for techniques such as Ring-Attention, which are sometimes also called <em>Sequence Parallelism</em> but we’ll refer to them as <em>Context Parallelism</em> to differentiate the two approaches. So each time you see sequence parallelism, remember that it is used together with tensor parallelism (in contrast to context parallelism, which can be used independently).</aside>
|
685 |
+
|
686 |
+
<p>Sequence parallelism (SP) involves splitting the activations and computations for the parts of the model not handled by tensor parallelism (TP) such as Dropout and LayerNorm, but along the input sequence dimension rather than across hidden dimension. This is needed because these operations require access to the full hidden dimension to compute correctly. For example, LayerNorm needs the full hidden dimension to compute mean and variance:</p>
|
687 |
+
|
688 |
+
<d-math block>
|
689 |
+
\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
|
690 |
+
</d-math>
|
691 |
+
|
692 |
+
<p>where <d-math>\mu = \text{mean}(x)</d-math> and <d-math>\sigma^2 = \text{var}(x)</d-math> are computed across hidden dimension <d-math>h</d-math>.</p>
|
693 |
+
|
694 |
+
<p>So even though these operations are computationally cheap, they still require significant activation memory since they need the complete hidden dimension. SP allows us to shard this <strong>memory</strong> burden across GPUs by splitting along the sequence dimension instead.</p>
|
695 |
+
|
696 |
+
<p>In practice we’ll go from the left diagram to the right:</p>
|
697 |
+
|
698 |
+
<p><img alt=" in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter
|
699 |
+
in backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather
|
700 |
+
SP region needs full hidden_dim" src="/assets/images/placeholder.png" /></p>
|
701 |
+
|
702 |
+
<p>in forward: f = no-op ; f<em> = all-reduce ; g = all-gather ; g</em> = reduce-scatter in backward: f = all-reduce ; f<em> = no-op ; g = reduce-scatter ; g</em> = all-gather SP region needs full hidden_dim</p>
|
703 |
+
|
704 |
+
<p>The diagram shows how we transition between tensor-parallel and sequence-parallel regions using different collective operations (labeled "f" and "g"). The key challenge is managing these transitions efficiently while keeping memory usage low and maintaining correctness.</p>
|
705 |
+
|
706 |
+
<p>In the forward pass:</p>
|
707 |
+
<ul>
|
708 |
+
<li>"f" is a no-op (no operation) because activations are already duplicated across ranks</li>
|
709 |
+
<li>"f*" is an all-reduce to synchronize activations and ensure correctness</li>
|
710 |
+
</ul>
|
711 |
+
<p>In the backward pass:</p>
|
712 |
+
<ul>
|
713 |
+
<li>"f*" is a no-op because gradients are already duplicated across ranks</li>
|
714 |
+
<li>"f" is an all-reduce to synchronize gradients</li>
|
715 |
+
</ul>
|
716 |
+
|
717 |
+
<p>These operations "f" and "f<em>" are called </em><em>conjugate</em>* pairs because they complement each other - when one is a no-op in forward, the other is an all-reduce in backward, and vice versa.</p>
|
718 |
+
|
719 |
+
<p>For sequence parallelism (SP), we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.</p>
|
720 |
+
|
721 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
722 |
+
|
723 |
+
<p>So what is actually happening here? As a famous LLM would say, let’s take it step-by-step:</p>
|
724 |
+
|
725 |
+
<p><strong>Initial LayerNorm (SP Region)</strong></p>
|
726 |
+
<ul>
|
727 |
+
<li>Input tensors X1<em> and X2</em> (b,s/2,h) enter LayerNorm, already split across sequence dimension</li>
|
728 |
+
<li>Each GPU computes LayerNorm independently on its sequence chunk and give Y1<em> and Y2</em></li>
|
729 |
+
</ul>
|
730 |
+
<p><strong>First Transition (SP → TP)</strong></p>
|
731 |
+
<ul>
|
732 |
+
<li>"g" operation (all-gather) combines Y1<em> and Y2</em> back to full sequence length</li>
|
733 |
+
<li> Restores Y (b,s,h) since column linear layer needs full hidden dimension h</li>
|
734 |
+
</ul>
|
735 |
+
<p><strong>First Linear Layer (TP Region)</strong></p>
|
736 |
+
<ul>
|
737 |
+
<li>A1 is a column-linear layer, so it splits Y along the hidden dimension</li>
|
738 |
+
<li>GeLU is applied independently on each GPU</li>
|
739 |
+
<li>Z1* is (b,s,h/2)</li>
|
740 |
+
</ul>
|
741 |
+
<p><strong>Second Linear Layer (TP Region)</strong></p>
|
742 |
+
<ul>
|
743 |
+
<li>B1 is a row-linear layer, so it restores the hidden dimension</li>
|
744 |
+
<li>W1 is (b,s,h)</li>
|
745 |
+
</ul>
|
746 |
+
<p><strong>Final Transition (TP → SP)</strong></p>
|
747 |
+
<ul>
|
748 |
+
<li>"g*" operation (reduce-scatter) which reduces for previous row-linear correctness while scattering along sequence dimension</li>
|
749 |
+
<li>W1* is (b,s/2,h)</li>
|
750 |
+
</ul>
|
751 |
+
|
752 |
+
<p>A key advantage of sequence parallelism is that it reduces the maximum activation size we need to store. In tensor parallelism alone, we had to store activations of shape (b,s,h) at various points. However, with sequence parallelism, the maximum activation size is reduced to <d-math>\frac{b \cdot s \cdot h}{tp}</d-math> since we always either split along the sequence or hidden dimensions.</p>
|
753 |
+
|
754 |
+
<p>It’s a bit difficult to keep track of all the parts that are sharded differently in TP and TP/SP - believe us, we find it hard to map as well so we made this small table to summarize how the activations (aka <code>hidden_states</code> ) shape change across hidden dimension h and sequence dimension s during a forward pass:</p>
|
755 |
+
|
756 |
+
<table>
|
757 |
+
<thead>
|
758 |
+
<tr>
|
759 |
+
<th>Region</th>
|
760 |
+
<th>TP only</th>
|
761 |
+
<th>TP with SP</th>
|
762 |
+
</tr>
|
763 |
+
</thead>
|
764 |
+
<tbody>
|
765 |
+
<tr>
|
766 |
+
<td>Enter TP (Column Linear)</td>
|
767 |
+
<td>h: sharded (weight_out is sharded)<br>s: full</td>
|
768 |
+
<td>h: sharded (weight_out is sharded)<br>s: <strong>all-gather</strong> to full</td>
|
769 |
+
</tr>
|
770 |
+
<tr>
|
771 |
+
<td>TP Region</td>
|
772 |
+
<td>h: sharded<br>s: full</td>
|
773 |
+
<td>h: sharded<br>s: full</td>
|
774 |
+
</tr>
|
775 |
+
<tr>
|
776 |
+
<td>Exit TP (Row Linear)</td>
|
777 |
+
<td>h: full (weight_out is full + <strong>all-reduce</strong> for correctness)<br>s: full</td>
|
778 |
+
<td>h: full (weight_out is full + <strong>reduce-scatter</strong> for correctness)<br>s: <strong>reduce-scatter</strong> to sharded</td>
|
779 |
+
</tr>
|
780 |
+
<tr>
|
781 |
+
<td>SP Region</td>
|
782 |
+
<td>h: full<br>s: full</td>
|
783 |
+
<td>h: full<br>s: sharded</td>
|
784 |
+
</tr>
|
785 |
+
</tbody>
|
786 |
+
</table>
|
787 |
+
|
788 |
+
<p>And for the embedding layer:</p>
|
789 |
+
|
790 |
+
<table>
|
791 |
+
<thead>
|
792 |
+
<tr>
|
793 |
+
<th>Region</th>
|
794 |
+
<th>Vanilla TP</th>
|
795 |
+
<th>TP with SP</th>
|
796 |
+
</tr>
|
797 |
+
</thead>
|
798 |
+
<tbody>
|
799 |
+
<tr>
|
800 |
+
<td>Embedding Layer (Row Linear sharded on vocab)</td>
|
801 |
+
<td>h: full (weight_out is full + <strong>all-reduce</strong> for correctness)<br>s: unchanged</td>
|
802 |
+
<td>h: full (weight_out is full + <strong>reduce-scatter</strong> for correctness)<br>s: <strong>reduce-scatter</strong> to sharded</td>
|
803 |
+
</tr>
|
804 |
+
</tbody>
|
805 |
+
</table>
|
806 |
+
|
807 |
+
<p>You can find an example of implementation of both column and row linear TP in picotron:
|
808 |
+
|
809 |
+
<a href="https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py">https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py</a> </p>
|
810 |
+
|
811 |
+
<p>By using sequence parallelism, we can achieve even greater activation memory savings, allowing us to push our batch size and sequence length further than what would be possible with tensor parallelism alone. Let's see what that means for our previous 70B model example:</p>
|
812 |
+
|
813 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
814 |
+
|
815 |
+
<p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward of a vanilla TP we had two all-reduce per transformer block, and in SP we have two all-gather and two reduce-scatter per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into to an all-gather + reduce-scatter (see in [TODO: Appendix link]) they’re actually equivalent in terms of communication. Same reasoning for backward as we just use the conjugate of each operation (no-op ↔ allreduce and allgather ↔ reducescatter).</p>
|
816 |
+
|
817 |
+
<p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops in each layer (2 for Attention and 2 for MLP). This is how the MLP profiling looks like when using Tensor + Sequence Parallelism:</p>
|
818 |
+
|
819 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
820 |
+
|
821 |
+
<p>Besides the fact that TP requires communications in each layer, it also can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. This is why TP is usually done only within a node (TP≤8).</p>
|
822 |
+
|
823 |
+
|
824 |
+
<aside>Overlapping communication with computation for TP is an active area of research, with recent work like Domino <d-cite bibtex-key="wang2024domino"></d-cite> exploring novel techniques to maximize this overlap. For example, Megatron-LM/Nanotron implement a partial overlapping of all-gather with FC1 computation, and we expect to see more innovations in this space as the field continues to evolve.</aside>
|
825 |
+
|
826 |
+
<p>As you might expect, this communication overhead becomes increasingly problematic as we scale up tensor parallelism. To illustrate this, let’s check throughput as we scale TP with SP for a 3B model:</p>
|
827 |
+
|
828 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
829 |
+
<p>Impact of combined Tensor and Sequence Parallelism (TP/SP) on a 3B model’s performance and memory utilization with 4096 seqlen: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.</p>
|
830 |
+
|
831 |
+
<p>Let’s summarize our observations:</p>
|
832 |
+
|
833 |
+
<ul>
|
834 |
+
<li>for both methods we notice the biggest performance drop when we move from TP=8 to TP=16, because that’s when we move from only communicating within a single node (NVLink), to communicating inter-nodes (EFA)</li>
|
835 |
+
<li>the memory savings in activations when using TP with SP helps us fit far bigger batches than TP alone</li>
|
836 |
+
<li>the memory savings in activations when using TP with SP helps us fit far bigger batches than TP alone</li>
|
837 |
+
</ul>
|
838 |
+
|
839 |
+
<p><strong>We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.</strong></p>
|
840 |
+
|
841 |
+
<p>However, there are two limits to TP and SP: 1) if we scale the sequence length the activation memory will still blow up in the TP region and 2) if the model is too big to fit with TP=8 then we will see a massive slow-down due to the inter-node connectivity.</p>
|
842 |
+
|
843 |
+
<aside>Since LayerNorms in the SP region operate on different portions of the sequence, their gradients will differ across TP ranks. To ensure the weights stay synchronized, we need to allreduce their gradients during the backward pass, similar to how DP ensures weights stay in sync. This is a small communication overhead since LayerNorm has relatively few parameters.</aside>
|
844 |
+
|
845 |
+
<p>We can tackle problem 1) with Context parallelism and problem 2) with Pipeline parallelism. Let’s first have a look at Context parallelism!</p>
|
846 |
+
|
847 |
<h2>Context Parallelism</h2>
|
848 |
|
849 |
<h3>Introducing Context Parallelism</h3>
|
src/bibliography.bib
CHANGED
@@ -403,4 +403,13 @@ url = {https://github.com/meta-llama/llama3/blob/main/MODEL_CARD.md}
|
|
403 |
archivePrefix={arXiv},
|
404 |
primaryClass={cs.LG},
|
405 |
url={https://arxiv.org/abs/2205.05198},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
406 |
}
|
|
|
403 |
archivePrefix={arXiv},
|
404 |
primaryClass={cs.LG},
|
405 |
url={https://arxiv.org/abs/2205.05198},
|
406 |
+
}
|
407 |
+
@misc{wang2024domino,
|
408 |
+
title={Domino: Eliminating Communication in LLM Training via Generic Tensor Slicing and Overlapping},
|
409 |
+
author={Guanhua Wang and Chengming Zhang and Zheyu Shen and Ang Li and Olatunji Ruwase},
|
410 |
+
year={2024},
|
411 |
+
eprint={2409.15241},
|
412 |
+
archivePrefix={arXiv},
|
413 |
+
primaryClass={cs.DC},
|
414 |
+
url={https://arxiv.org/abs/2409.15241},
|
415 |
}
|
src/index.html
CHANGED
@@ -468,16 +468,125 @@
|
|
468 |
|
469 |
<h2>Data Parallelism</h2>
|
470 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
471 |
<h4><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h4>
|
472 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
473 |
<h4><strong>Second optimization:</strong> Bucketing gradients</h4>
|
474 |
|
475 |
-
<
|
476 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
477 |
<h3>Revisit global batch size</h3>
|
478 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
479 |
<h3>Our journey up to now</h3>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
480 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
481 |
<h3>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h3>
|
482 |
|
483 |
<h4>Memory usage revisited</h4>
|
@@ -488,12 +597,253 @@
|
|
488 |
|
489 |
<h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
|
490 |
|
|
|
491 |
<h2>Tensor Parallelism</h2>
|
492 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
493 |
<h3>Tensor Parallelism in a Transformer Block</h3>
|
494 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
495 |
<h3>Sequence Parallelism</h3>
|
496 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
497 |
<h2>Context Parallelism</h2>
|
498 |
|
499 |
<h3>Introducing Context Parallelism</h3>
|
|
|
468 |
|
469 |
<h2>Data Parallelism</h2>
|
470 |
|
471 |
+
<p>The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replica's “model instances”) and run forward and backward passes on different micro batches of data in parallel for each GPU, hence the name Data Parallelism. </p>
|
472 |
+
|
473 |
+
<p>Using a different micro batch for each GPU means we’ll have different gradients in each GPU, so to keep the model instances in sync across different GPUs, the gradients from the model instances are averaged using an operation called “all-reduce”, which happens during the backward pass, before the optimizer step.</p>
|
474 |
+
|
475 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
476 |
+
|
477 |
+
<p>This involves our first “distributed communication” primitive: <em><strong>all-reduce</em></strong> which handles the synchronization and communication between GPU instances and nodes.</p>
|
478 |
+
|
479 |
+
<aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in the Appendix [TODO Link].</aside>
|
480 |
+
|
481 |
+
<p>TODO: embed naive DP: <a href="https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60">https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60</a></p>
|
482 |
+
|
483 |
+
<p>TODO: embed bucket DP: <a href="https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171">https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171</a></p>
|
484 |
+
|
485 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
486 |
+
|
487 |
+
<p>A naive DP implementation would just wait for the backward pass the finish so that we have all gradients, then it triggers an all-reduce over all DP ranks, to sync these gradients. But such an sequential steps of computation followed by communication is <strong>A BIG NO!</strong> Because we don’t want our GPUs to stay idle while communication is happening.</p>
|
488 |
+
|
489 |
+
<p>Instead we should try to overlap communication and computation whenever possible so that they happen at the same time as much as possible.</p>
|
490 |
+
|
491 |
+
<p>Let’s see three optimizations that are done in practice for this! </p>
|
492 |
+
|
493 |
<h4><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h4>
|
494 |
+
|
495 |
+
<p>The main drawback of the naive DDP approach we’ve just described is that after the backward pass (<em>computation</em>), we have to wait for gradient synchronization (<em>communication</em>) before updating the parameters. Could we overlap this communication with our computation? The answer is yes!</p>
|
496 |
+
|
497 |
+
<p>As shown in the figure above, the gradients (red boxes) for a layer can be gathered and summed even before the gradients from earlier layers (red boxes to the left) have been computed. For example, as soon as the backward pass of the last layer is complete (last box on the right), those gradients can already be gathered and summed while the backward computations continue for earlier layers, moving toward the left.</p>
|
498 |
+
|
499 |
+
<p>This can be achieved in pytorch by attaching an <em>all-reduce hook function</em> to each parameter. An all-reduce operation is triggered as soon as the gradient for that parameter is ready, while gradients for other parameters are still being computed. This approach overlaps most of the all-reduce operations with gradient calculations, thereby improving efficiency. Here's a simple function to attach a hook:</p>
|
500 |
+
|
501 |
+
<d-code block language="python">
|
502 |
+
def register_backward_hook(self, hook):
|
503 |
+
"""
|
504 |
+
Registers a backward hook for all parameters of the model that
|
505 |
+
require gradients.
|
506 |
+
"""
|
507 |
+
for p in self.module.parameters():
|
508 |
+
if p.requires_grad is True:
|
509 |
+
p.register_post_accumulate_grad_hook(hook)</d-code>
|
510 |
+
|
511 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png"/></p>
|
512 |
+
|
513 |
+
<p>Overlapping computation and communication reduces the time spent waiting for gradient synchronization across the entire model. Gradient synchronization can occur (at least partially) in parallel with backward pass, significantly speeding up data parallelism. </p>
|
514 |
+
|
515 |
+
<p>This is our first example of “<em>overlapping computation and communication</em>” which we will discuss several times in this blog post and is an essential technique to maximal scaling efficiency. Let's have a look how we can further improve the DP efficiency!</p>
|
516 |
+
|
517 |
+
|
518 |
<h4><strong>Second optimization:</strong> Bucketing gradients</h4>
|
519 |
|
520 |
+
<p>We can even go further with optimizing DP. For a given number of parameters to synchronize, GPU operations like collective communications are often more efficient when performing few calls on large tensors rather than many calls on smaller tensors. Therefore, instead of performing independent all-reduce for each gradient, we can group gradients into buckets and launch a single all-reduce for all the gradients within the same bucket. Think of it like packing items into boxes before shipping—it's more efficient to send a few big boxes than many small ones. By performing a single all-reduce operation for each bucket, we can significantly reduce communication overhead and speed up the communication operation.</p>
|
521 |
+
|
522 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
523 |
+
|
524 |
+
<h4><strong>Third optimization: </strong>Interplay with gradient accumulation</h4>
|
525 |
+
|
526 |
+
<p>As we’ve seen before, gradient accumulation works by performing multiple forward and backward passes before updating the parameters with <code>optimizer.step()</code>. When combining gradient accumulation with data parallelism, we should be careful when we want to synchronize gradients.</p>
|
527 |
+
|
528 |
+
<p>In a naive version, an all-reduce operation is automatically triggered after each backward pass during the accumulation, which is sub-optimal as a single reduce after the final step would have the same effect while reducing overhead.</p>
|
529 |
+
|
530 |
+
<p>In PyTorch, this is typically solved by adding a <a href="https://github.com/pytorch/pytorch/blob/5ea67778619c31b13644914deef709199052ee55/torch/nn/parallel/distributed.py#L1408-L1435"><code>model.no_sync()</code></a> decorator, which disables gradient synchronization, on the backward passes which don’t need reduction.</p>
|
531 |
+
|
532 |
+
<aside>When performing communication operations, tensors must be contiguous in memory. To avoid redundant memory copies during communication, ensure that tensors that will be communicated are stored contiguously in memory. Sometimes we need to allocate additional continuous buffers of the size of activations or model parameters specifically for communication, which contributes to the peak memory usage during training.</aside>
|
533 |
+
|
534 |
<h3>Revisit global batch size</h3>
|
535 |
+
<p>Let’s update our batch size equation with our newly learned Data Parallelism and Gradient Accumulation parameters:</p>
|
536 |
+
|
537 |
+
<d-math block>
|
538 |
+
bs = gbs = mbs \times grad\_acc
|
539 |
+
</d-math>
|
540 |
+
<p>Where <d-math>grad\_acc</d-math> is the number of gradient accumulation steps and DP is the number of parallel instances used for data parallelism.</p>
|
541 |
+
|
542 |
+
<p>Given a targeted global batch size, we can thus trade gradient accumulation steps for data-parallel processes to speed up training. In practice, people tend to maximize the number of data-parallel nodes (DP) over gradient accumulation as much as possible since it's inherently parallel, unlike the sequential nature of gradient accumulation. Gradient accumulation is then added on top of data parallelism to achieve the target global batch size when scaling data parallelism alone is not sufficient before you run out of GPUs.</p>
|
543 |
+
|
544 |
+
<aside>A good resource for further reading on Data Parallelism is <a href="https://siboehm.com/articles/22/data-parallel-training">https://siboehm.com/articles/22/data-parallel-training</a>.
|
545 |
+
</aside>
|
546 |
+
|
547 |
+
<p>Being able to distribute the training over different samples gives us a first dimension of parallelization, thus making this 1D parallelism (we’ll progressively cover 4 more dimensions).</p>
|
548 |
+
|
549 |
<h3>Our journey up to now</h3>
|
550 |
+
<p>Let’s quickly summarize what we’ve seen up to now and how to setup our first 1D parallel training with a draft recipe for an optimal data-parallel setup:</p>
|
551 |
+
|
552 |
+
<ol>
|
553 |
+
<li>We should first determine the best (global) batch size in tokens (<code>GBST</code>) either by consulting literature or running experiments measuring model convergence.</li>
|
554 |
+
<li>We then select a sequence length for training, again by either consulting literature or running experiments. Generally, 2-8k tokens work reliably well for the evaluations we have today (we won’t dive in training recipes here but teams usually increase the sequence at the end of the training, adding some longer-context data samples in the mix to reach the longer context size of today).</li>
|
555 |
+
<li>We now know the batch size (gbs). We can find the maximum local batch size (mbs) on a single GPU by increasing the local batch size until we run out of memory.</li>
|
556 |
+
<li>Finally, we determine the number of available GPUs for our target DP. The ratio of GBS to DP gives us the remaining number of gradient accumulation steps needed for the desired GBS. </li>
|
557 |
+
</ol>
|
558 |
+
|
559 |
+
<aside>For instance DeepSeek and Llama models are trained with a 4k tokens sequence length during the main pretraining phase.</aside>
|
560 |
+
|
561 |
+
<aside>The reason 2-8k work well for pretraining is that documents that are longer are very rare on the web. See this <a href="https://www.harmdevries.com/post/context-length/">Harm’s blogpost</a> for a detailed analysis.
|
562 |
+
</aside>
|
563 |
+
|
564 |
+
<p>If the gradient accumulation ratio is lower than one, i.e. we have too many GPUs a.k.a GPU-rich 🤑 (!), we can either choose to not use all our GPUs, explore a larger global batch size or test if a lower MBS will speed up training. In the latter case we’ll end up prioritizing throughput over individual GPU compute efficiency, using a smaller MBS than possible in order to speed up training.</p>
|
565 |
+
|
566 |
+
<p>Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k. This means our batch size will be 1024 samples (we pick powers of two). We observe that a single GPU can only fit MBS=2 in memory and we have 128 GPUs available for training. This means with 4 gradient accumulation steps we’ll achieve our goal of 1024 samples or 4M tokens per training step. Now what if we suddenly have 512 GPUs available? We can achieve the same GBS and thus identical training by keeping MBS=2 and setting gradient accumulation steps to 1 and achieve faster training!</p>
|
567 |
+
|
568 |
+
<aside>Bear in mind that at the 512GPUs scale, depending on the network used, the communication operations will start to be bound by <em>ring latency</em> (time required for a signal to propagate once around the ring) **which means we can no longer fully overlap the DP communications. This will decrease our compute efficiency and hit our throughput. In this case we should start exploring other dimensions to parallelize on.
|
569 |
+
</aside>
|
570 |
+
|
571 |
+
<p>While data parallelism cleverly overlaps the all-reduce gradient synchronization with backward computation to save time, this benefit starts to break down at large scales. As we add more and more GPUs (hundreds or thousands), the overhead of coordinating between them grows significantly. The end result? We get less and less efficient returns from each additional GPU we add to the system:</p>
|
572 |
+
|
573 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png"/></p>
|
574 |
+
|
575 |
+
<p>As expected, we can also see that the memory usage per GPU is not affected by adding more DP ranks for training.</p>
|
576 |
+
|
577 |
+
<p><strong>We’ve explored data parallelism, our first (simple) strategy to scale training across more GPUs. It works like gradient accumulation but parallelizes the forward and backward passes on micro batches, thus increasing throughput!</strong></p>
|
578 |
|
579 |
+
<p>The keen reader has already probably noted however that this assumes that we can fit at least one input sample forward pass (mbs<em>=1)</em> into our GPU memory. This is not always the case! As we can see, larger models don’t fit into a single GPU, even with activation recomputation activated: </p>
|
580 |
+
|
581 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
582 |
+
|
583 |
+
<aside>Tip: you can quickly eyeball the minimal memory required for your model’s parameters by multiplying by 2 e.g. 70B → 140GB (=133GiB)</aside>
|
584 |
+
|
585 |
+
<p>Do we have other options for these larger models? We do have some solutions thankfully. They will involve either move some of these tensors to the CPU or split the weights/gradients/optimizer-states tensors across GPUs devices!</p>
|
586 |
+
|
587 |
+
<p>There are two main approaches to splitting: parallelism (tensor, context, or pipeline parallelism) and sharing (DeepSpeed Zero or PyTorch FSDP). Both approaches are somewhat orthogonal and can actually be combined! The sharing paradigm is closely related to DP so we’ll have a look at it first by investigating the ZeRO method!</p>
|
588 |
+
|
589 |
+
|
590 |
<h3>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h3>
|
591 |
|
592 |
<h4>Memory usage revisited</h4>
|
|
|
597 |
|
598 |
<h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
|
599 |
|
600 |
+
|
601 |
<h2>Tensor Parallelism</h2>
|
602 |
+
|
603 |
+
<p>So we have sharded the model’s parameters, gradients and optimizers states with ZeRO but we hit a limit once activation memory overtakes our memory budget. Welcome Tensor Parallelism (TP), a method which shards weights, gradients, and optimizers states as well as activations and without the need to gather them all prior to the computation. Seems like a dream! Let’s first have a look at how Tensor Parallel works with simple matrix multiplications.</p>
|
604 |
+
|
605 |
+
<p>Tensor Parallelism leverages the mathematical properties of matrix multiplication <d-math>A \times B</d-math>. To understand how it works, let's examine two fundamental equations that make this parallelization possible:</p>
|
606 |
+
|
607 |
+
<d-math block>
|
608 |
+
\begin{aligned}
|
609 |
+
&\text{1.} \quad A\cdot B = A \cdot \begin{bmatrix} B_1 & B_2 & \cdots \end{bmatrix} = \begin{bmatrix} AB_1 & AB_2 & \cdots \end{bmatrix} \\
|
610 |
+
&\text{2.} \quad A\cdot B =\begin{bmatrix} A_1 & A_2 & \cdots \end{bmatrix} \begin{bmatrix} B_1 \\ B_2 \\ \vdots \end{bmatrix} = \sum_{i=1}^n A_i B_i
|
611 |
+
\end{aligned}
|
612 |
+
</d-math>
|
613 |
+
|
614 |
+
<p>This means that we can compute matrix product by either 1) multiplying each column of <d-math>B</d-math> individually or 2) multiplying each row individually and combining the results. In a neural network, the matrix multiplication is more often represented in the following format: <d-math>X \times W</d-math>, where:</p>
|
615 |
+
|
616 |
+
<ul>
|
617 |
+
<li>X represents the input or activation values</li>
|
618 |
+
<li>W represents the weight of the <code>nn.Linear</code></li>
|
619 |
+
</ul>
|
620 |
+
|
621 |
+
<p>In practice a small example of the operation looks like this:</p>
|
622 |
+
|
623 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
624 |
+
|
625 |
+
<p>Let’s see how we can parallelise this operation! In tensor parallelism, tensors will be split into N shards along a particular dimension and distributed across N GPUs. Matrices can be split either on the column part or row part leading to row and column parallelism. One thing we’ll see in the following is that choosing row or column sharding will require different communications primitives.</p>
|
626 |
+
|
627 |
+
<p>Our first option is to use column-wise sharding (also called <strong><em>column-linear</em></strong>): We'll copy the complete input matrices to each worker, requiring an operation called <strong><em>broadcast</em></strong>, and split the weight matrix into columns. The inputs are then multiplied with the partial weight matrices, and the results are finally combined using an <strong><em>all-gather</em></strong> operation.</p>
|
628 |
+
|
629 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
630 |
+
|
631 |
+
<p>The second option is called row-wise sharding (also called <strong><em>row-linear</em></strong>): As the attentive reader might guess, row-linear means that we split the weight matrix into chunks of rows. However, this also requires us to split the inputs, which needs a <strong><em>scatter</em></strong> operation rather than a broadcast as used in column-linear sharding. The results on each worker are already in the right shape but need to be summed for the final result, thus requiring an all-reduce operation in this scenario.</p>
|
632 |
+
|
633 |
+
<p>We see here our fourth distributed primitive: <strong><em>scatter</em></strong>!</p>
|
634 |
+
|
635 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
636 |
+
|
637 |
<h3>Tensor Parallelism in a Transformer Block</h3>
|
638 |
|
639 |
+
<p>To come up with a strategy to follow, let’s move from a toy example to a real model building block. A Transformer model is made of two main building blocks : Feedforward layers (MLP) and Multi-Head Attention (MHA). We can apply tensor parallelism to both.</p>
|
640 |
+
|
641 |
+
<p>The Feedforward part can be parallelized by having a “Column linear” followed by a “Row Linear” which amounts to a broadcast to copy the input and an all-reduce in forward. Note that the broadcast isn’t needed in actual training where we can make sure inputs are already synced across TP ranks.</p>
|
642 |
+
|
643 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
644 |
+
|
645 |
+
<p>Now that we’ve found the most efficient schema for the Feedforward part of the transformer, let’s take a look at the multi-head attention block (MHA).</p>
|
646 |
+
|
647 |
+
<p>We can generally follow a similar approach where Q, K, and V matrices are split in a column-parallel fashion, and the output projection is split along the row dimension. With multi-head attention, the column-parallel approach has a very natural interpretation: each worker computes the attention for an individual or a subset of heads. The same approach works as well for <a href="https://arxiv.org/abs/1911.02150"><strong><em>multi-query</em></strong> (MQA)</a> or <a href="https://arxiv.org/abs/2305.13245"><strong><em>grouped query attention</em></strong> (GQA)</a> where key and values are shared between queries. </p>
|
648 |
+
|
649 |
+
<p>It's also worth noting that the tensor parallelism degree should not exceed the number of Q/K/V heads because we need intact heads per TP rank. And in case we’re using GQA, TP degree should be below number of K/V heads, otherwise it requires additional comms to keep them in sync. For instance, LLaMA-3 8B has 8 Key/Value heads, so the tensor parallelism degree should be less than or equal to 8, otherwise if TP=16 for example, we need to duplicate each K/V head and make sure they stay in sync.</p>
|
650 |
+
|
651 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
652 |
+
|
653 |
+
<p>Finally note that there is a tradeoff in terms of communication as we’ve added several distributed communication primitive directly in the computation path of our model. At the difference of ZeRO where we could prefetch, it can be harder to make these communication fully overlap with computations. </p>
|
654 |
+
|
655 |
+
<p><img alt="Forward pass in Tensor Parallelism" src="/assets/images/placeholder.png" /></p>
|
656 |
+
|
657 |
+
<p>Looking at the timeline of operations in tensor-parallel MLP (same applies for Attention), we can better understand the tradeoffs involved. In the forward of each decoder layer, we hit a synchronization point with the AllReduce operation that cannot be overlapped with computation. This <em>exposed communication</em> overhead is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied. </p>
|
658 |
+
|
659 |
+
<p>Tensor parallelism does help reduce activation memory for the matrix multiplications since the intermediate activations are sharded across GPUs. However, we still need to gather the full activations for operations like LayerNorm, which means we're not getting the full memory benefits we could. Additionally, it introduces significant communication requirements that heavily depend on the network infrastructure. The inability to hide this particular AllReduce behind computation means it directly adds to the critical path of forward propagation.</p>
|
660 |
+
|
661 |
+
<p><img alt="Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training." src="/assets/images/placeholder.png" /></p>
|
662 |
+
|
663 |
+
<p>Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training.</p>
|
664 |
+
|
665 |
+
<p>In practice, the communication overhead of tensor parallelism becomes particularly noticeable as we scale beyond 8 GPUs. While tensor parallelism within a single node can leverage fast NVLink interconnects, going across nodes requires slower network connections. As shown in the throughput plot above, we observe significant drops when moving from TP=8 to TP=16, and an even steeper decline from TP=16 to TP=32. This illustrates how communication costs can dominate at higher degrees of parallelism.</p>
|
666 |
+
|
667 |
+
<p>However, tensor parallelism provides important benefits for memory usage by distributing model parameters, gradients, optimizer states and activations (to some extent) across GPUs. Let's examine this effect on a 70B parameter model:</p>
|
668 |
+
|
669 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
670 |
+
|
671 |
+
<p>As we can see, increasing tensor parallelism reduces the memory needed for model parameters, gradients and optimizer states on each GPU. While tensor parallelism does help reduce activation memory in attention and feedforward layers by sharding the matrix multiplications across GPUs, we don't get the full memory benefits we could. This is because operations like layer normalization and dropout still require gathering the full activations on each GPU, partially negating the memory savings. We can do better by finding ways to parallelize these remaining operations as well.</p>
|
672 |
+
|
673 |
+
<aside>One interesting note about layer normalization in tensor parallel training - since each TP rank sees the same activations after the all-gather, the layer norm weights don't actually need an all-reduce to sync their gradients after the backward pass. They naturally stay in sync across ranks. However, for dropout operations, we must make sure to sync the random seed across TP ranks to maintain deterministic behavior.
|
674 |
+
</aside>
|
675 |
+
|
676 |
+
<p>This raises an interesting question - could we extend tensor parallelism to these remaining operations as well? Indeed, it's possible to parallelize layer norm, dropout and other operations too, which we'll explore next.</p>
|
677 |
+
|
678 |
<h3>Sequence Parallelism</h3>
|
679 |
|
680 |
+
<p>In regions where we apply tensor parallelism (TP), like attention and feedforward layers, each GPU only needs to operate on a portion of the hidden dimension since the weights are sharded. However, operations like layer norm or dropout (which is not used a lot anymore in LLM) require access to the full hidden dimension to compute correctly.</p>
|
681 |
+
|
682 |
+
<p>Rather than gathering the full hidden dimension on each GPU (which would defeat the memory benefits of TP), we can instead shard these operations along the sequence length dimension. This approach is called <strong>sequence parallelism (SP)</strong>.</p>
|
683 |
+
|
684 |
+
<aside>Note that the term Sequence Parallelism is a bit overloaded: the Sequence Parallelism in this section is tightly coupled to Tensor Parallelism and applies to dropout and layer norm operation. However, when we will move to longer sequences the attention computation will become a bottleneck, which calls for techniques such as Ring-Attention, which are sometimes also called <em>Sequence Parallelism</em> but we’ll refer to them as <em>Context Parallelism</em> to differentiate the two approaches. So each time you see sequence parallelism, remember that it is used together with tensor parallelism (in contrast to context parallelism, which can be used independently).</aside>
|
685 |
+
|
686 |
+
<p>Sequence parallelism (SP) involves splitting the activations and computations for the parts of the model not handled by tensor parallelism (TP) such as Dropout and LayerNorm, but along the input sequence dimension rather than across hidden dimension. This is needed because these operations require access to the full hidden dimension to compute correctly. For example, LayerNorm needs the full hidden dimension to compute mean and variance:</p>
|
687 |
+
|
688 |
+
<d-math block>
|
689 |
+
\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
|
690 |
+
</d-math>
|
691 |
+
|
692 |
+
<p>where <d-math>\mu = \text{mean}(x)</d-math> and <d-math>\sigma^2 = \text{var}(x)</d-math> are computed across hidden dimension <d-math>h</d-math>.</p>
|
693 |
+
|
694 |
+
<p>So even though these operations are computationally cheap, they still require significant activation memory since they need the complete hidden dimension. SP allows us to shard this <strong>memory</strong> burden across GPUs by splitting along the sequence dimension instead.</p>
|
695 |
+
|
696 |
+
<p>In practice we’ll go from the left diagram to the right:</p>
|
697 |
+
|
698 |
+
<p><img alt=" in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter
|
699 |
+
in backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather
|
700 |
+
SP region needs full hidden_dim" src="/assets/images/placeholder.png" /></p>
|
701 |
+
|
702 |
+
<p>in forward: f = no-op ; f<em> = all-reduce ; g = all-gather ; g</em> = reduce-scatter in backward: f = all-reduce ; f<em> = no-op ; g = reduce-scatter ; g</em> = all-gather SP region needs full hidden_dim</p>
|
703 |
+
|
704 |
+
<p>The diagram shows how we transition between tensor-parallel and sequence-parallel regions using different collective operations (labeled "f" and "g"). The key challenge is managing these transitions efficiently while keeping memory usage low and maintaining correctness.</p>
|
705 |
+
|
706 |
+
<p>In the forward pass:</p>
|
707 |
+
<ul>
|
708 |
+
<li>"f" is a no-op (no operation) because activations are already duplicated across ranks</li>
|
709 |
+
<li>"f*" is an all-reduce to synchronize activations and ensure correctness</li>
|
710 |
+
</ul>
|
711 |
+
<p>In the backward pass:</p>
|
712 |
+
<ul>
|
713 |
+
<li>"f*" is a no-op because gradients are already duplicated across ranks</li>
|
714 |
+
<li>"f" is an all-reduce to synchronize gradients</li>
|
715 |
+
</ul>
|
716 |
+
|
717 |
+
<p>These operations "f" and "f<em>" are called </em><em>conjugate</em>* pairs because they complement each other - when one is a no-op in forward, the other is an all-reduce in backward, and vice versa.</p>
|
718 |
+
|
719 |
+
<p>For sequence parallelism (SP), we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.</p>
|
720 |
+
|
721 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
722 |
+
|
723 |
+
<p>So what is actually happening here? As a famous LLM would say, let’s take it step-by-step:</p>
|
724 |
+
|
725 |
+
<p><strong>Initial LayerNorm (SP Region)</strong></p>
|
726 |
+
<ul>
|
727 |
+
<li>Input tensors X1<em> and X2</em> (b,s/2,h) enter LayerNorm, already split across sequence dimension</li>
|
728 |
+
<li>Each GPU computes LayerNorm independently on its sequence chunk and give Y1<em> and Y2</em></li>
|
729 |
+
</ul>
|
730 |
+
<p><strong>First Transition (SP → TP)</strong></p>
|
731 |
+
<ul>
|
732 |
+
<li>"g" operation (all-gather) combines Y1<em> and Y2</em> back to full sequence length</li>
|
733 |
+
<li> Restores Y (b,s,h) since column linear layer needs full hidden dimension h</li>
|
734 |
+
</ul>
|
735 |
+
<p><strong>First Linear Layer (TP Region)</strong></p>
|
736 |
+
<ul>
|
737 |
+
<li>A1 is a column-linear layer, so it splits Y along the hidden dimension</li>
|
738 |
+
<li>GeLU is applied independently on each GPU</li>
|
739 |
+
<li>Z1* is (b,s,h/2)</li>
|
740 |
+
</ul>
|
741 |
+
<p><strong>Second Linear Layer (TP Region)</strong></p>
|
742 |
+
<ul>
|
743 |
+
<li>B1 is a row-linear layer, so it restores the hidden dimension</li>
|
744 |
+
<li>W1 is (b,s,h)</li>
|
745 |
+
</ul>
|
746 |
+
<p><strong>Final Transition (TP → SP)</strong></p>
|
747 |
+
<ul>
|
748 |
+
<li>"g*" operation (reduce-scatter) which reduces for previous row-linear correctness while scattering along sequence dimension</li>
|
749 |
+
<li>W1* is (b,s/2,h)</li>
|
750 |
+
</ul>
|
751 |
+
|
752 |
+
<p>A key advantage of sequence parallelism is that it reduces the maximum activation size we need to store. In tensor parallelism alone, we had to store activations of shape (b,s,h) at various points. However, with sequence parallelism, the maximum activation size is reduced to <d-math>\frac{b \cdot s \cdot h}{tp}</d-math> since we always either split along the sequence or hidden dimensions.</p>
|
753 |
+
|
754 |
+
<p>It’s a bit difficult to keep track of all the parts that are sharded differently in TP and TP/SP - believe us, we find it hard to map as well so we made this small table to summarize how the activations (aka <code>hidden_states</code> ) shape change across hidden dimension h and sequence dimension s during a forward pass:</p>
|
755 |
+
|
756 |
+
<table>
|
757 |
+
<thead>
|
758 |
+
<tr>
|
759 |
+
<th>Region</th>
|
760 |
+
<th>TP only</th>
|
761 |
+
<th>TP with SP</th>
|
762 |
+
</tr>
|
763 |
+
</thead>
|
764 |
+
<tbody>
|
765 |
+
<tr>
|
766 |
+
<td>Enter TP (Column Linear)</td>
|
767 |
+
<td>h: sharded (weight_out is sharded)<br>s: full</td>
|
768 |
+
<td>h: sharded (weight_out is sharded)<br>s: <strong>all-gather</strong> to full</td>
|
769 |
+
</tr>
|
770 |
+
<tr>
|
771 |
+
<td>TP Region</td>
|
772 |
+
<td>h: sharded<br>s: full</td>
|
773 |
+
<td>h: sharded<br>s: full</td>
|
774 |
+
</tr>
|
775 |
+
<tr>
|
776 |
+
<td>Exit TP (Row Linear)</td>
|
777 |
+
<td>h: full (weight_out is full + <strong>all-reduce</strong> for correctness)<br>s: full</td>
|
778 |
+
<td>h: full (weight_out is full + <strong>reduce-scatter</strong> for correctness)<br>s: <strong>reduce-scatter</strong> to sharded</td>
|
779 |
+
</tr>
|
780 |
+
<tr>
|
781 |
+
<td>SP Region</td>
|
782 |
+
<td>h: full<br>s: full</td>
|
783 |
+
<td>h: full<br>s: sharded</td>
|
784 |
+
</tr>
|
785 |
+
</tbody>
|
786 |
+
</table>
|
787 |
+
|
788 |
+
<p>And for the embedding layer:</p>
|
789 |
+
|
790 |
+
<table>
|
791 |
+
<thead>
|
792 |
+
<tr>
|
793 |
+
<th>Region</th>
|
794 |
+
<th>Vanilla TP</th>
|
795 |
+
<th>TP with SP</th>
|
796 |
+
</tr>
|
797 |
+
</thead>
|
798 |
+
<tbody>
|
799 |
+
<tr>
|
800 |
+
<td>Embedding Layer (Row Linear sharded on vocab)</td>
|
801 |
+
<td>h: full (weight_out is full + <strong>all-reduce</strong> for correctness)<br>s: unchanged</td>
|
802 |
+
<td>h: full (weight_out is full + <strong>reduce-scatter</strong> for correctness)<br>s: <strong>reduce-scatter</strong> to sharded</td>
|
803 |
+
</tr>
|
804 |
+
</tbody>
|
805 |
+
</table>
|
806 |
+
|
807 |
+
<p>You can find an example of implementation of both column and row linear TP in picotron:
|
808 |
+
|
809 |
+
<a href="https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py">https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py</a> </p>
|
810 |
+
|
811 |
+
<p>By using sequence parallelism, we can achieve even greater activation memory savings, allowing us to push our batch size and sequence length further than what would be possible with tensor parallelism alone. Let's see what that means for our previous 70B model example:</p>
|
812 |
+
|
813 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
814 |
+
|
815 |
+
<p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward of a vanilla TP we had two all-reduce per transformer block, and in SP we have two all-gather and two reduce-scatter per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into to an all-gather + reduce-scatter (see in [TODO: Appendix link]) they’re actually equivalent in terms of communication. Same reasoning for backward as we just use the conjugate of each operation (no-op ↔ allreduce and allgather ↔ reducescatter).</p>
|
816 |
+
|
817 |
+
<p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops in each layer (2 for Attention and 2 for MLP). This is how the MLP profiling looks like when using Tensor + Sequence Parallelism:</p>
|
818 |
+
|
819 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
820 |
+
|
821 |
+
<p>Besides the fact that TP requires communications in each layer, it also can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. This is why TP is usually done only within a node (TP≤8).</p>
|
822 |
+
|
823 |
+
|
824 |
+
<aside>Overlapping communication with computation for TP is an active area of research, with recent work like Domino <d-cite bibtex-key="wang2024domino"></d-cite> exploring novel techniques to maximize this overlap. For example, Megatron-LM/Nanotron implement a partial overlapping of all-gather with FC1 computation, and we expect to see more innovations in this space as the field continues to evolve.</aside>
|
825 |
+
|
826 |
+
<p>As you might expect, this communication overhead becomes increasingly problematic as we scale up tensor parallelism. To illustrate this, let’s check throughput as we scale TP with SP for a 3B model:</p>
|
827 |
+
|
828 |
+
<p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
|
829 |
+
<p>Impact of combined Tensor and Sequence Parallelism (TP/SP) on a 3B model’s performance and memory utilization with 4096 seqlen: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.</p>
|
830 |
+
|
831 |
+
<p>Let’s summarize our observations:</p>
|
832 |
+
|
833 |
+
<ul>
|
834 |
+
<li>for both methods we notice the biggest performance drop when we move from TP=8 to TP=16, because that’s when we move from only communicating within a single node (NVLink), to communicating inter-nodes (EFA)</li>
|
835 |
+
<li>the memory savings in activations when using TP with SP helps us fit far bigger batches than TP alone</li>
|
836 |
+
<li>the memory savings in activations when using TP with SP helps us fit far bigger batches than TP alone</li>
|
837 |
+
</ul>
|
838 |
+
|
839 |
+
<p><strong>We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.</strong></p>
|
840 |
+
|
841 |
+
<p>However, there are two limits to TP and SP: 1) if we scale the sequence length the activation memory will still blow up in the TP region and 2) if the model is too big to fit with TP=8 then we will see a massive slow-down due to the inter-node connectivity.</p>
|
842 |
+
|
843 |
+
<aside>Since LayerNorms in the SP region operate on different portions of the sequence, their gradients will differ across TP ranks. To ensure the weights stay synchronized, we need to allreduce their gradients during the backward pass, similar to how DP ensures weights stay in sync. This is a small communication overhead since LayerNorm has relatively few parameters.</aside>
|
844 |
+
|
845 |
+
<p>We can tackle problem 1) with Context parallelism and problem 2) with Pipeline parallelism. Let’s first have a look at Context parallelism!</p>
|
846 |
+
|
847 |
<h2>Context Parallelism</h2>
|
848 |
|
849 |
<h3>Introducing Context Parallelism</h3>
|