lvwerra (HF staff) committed on
Commit
ad6c00b
·
1 Parent(s): 89bae42
Files changed (5)
  1. blog-export.md +161 -135
  2. dist/bibliography.bib +9 -0
  3. dist/index.html +355 -5
  4. src/bibliography.bib +9 -0
  5. src/index.html +355 -5
blog-export.md CHANGED
@@ -28,7 +28,7 @@ An overview of the over 4000 experiments across all Llama architectures where ea
28
 
29
 As you can see, there’s a lot of ground to be covered. Before getting into the trenches of distributed training, let’s take a quick high-level look at what we’ll cover in this post.
30
 
31
- # TL;DR
32
 
33
 This book is very extensive, so we decided to start with a very general overview of how you can think about distributed training. At a high level, the key challenge in scaling LLM training is to make a training step (forward/backward/optimizer step) with a large batch size as fast as possible.
34
 
@@ -55,7 +55,7 @@ But let’s not get too much ahead of our self and scale progressively. To guide
55
 
56
 Now that we’ve nailed a few key concepts and terms, let’s get started by revisiting the basic training steps of an LLM!
57
 
58
- # First Steps: Training on one GPU
59
 
60
  Let’s start by quickly reviewing the very basics of model training before we start to scale to many GPUs. When a model is trained on a single GPU, the training typically consists of three steps:
61
 
@@ -101,7 +101,7 @@ A sweet spot for recent LLM training is typically on the order of 4-60 million t
101
 
102
  Let’s start by quickly understanding what led to our out-of-memory issue in the first place. This will help us gain some useful intuitions for later.
103
 
104
- ## Memory usage in Transformers
105
 
106
 When training a neural network model, one stores several items in memory:
107
 
@@ -143,13 +143,13 @@ For a simple transformer LLM the number of parameters is given by the [following
143
 
144
  $$
145
 
146
- N = h*v + L * (12 * h^2 + 13*h) + 2*h
147
  $$
148
 
149
  > Note: we excluded the positional embedding count as rotary embeddings are not learned.
150
  >
151
 
152
- In that equation, $h$ is the hidden dimension, $v$ the vocabulary size, and $L$ the number of layers in the model. Looking at the equation, we can see that the term that will dominate at large hidden dimensions is the $h^2$ term, since it’s the only one growing quadratically as we scale the parameters.
153
 
154
 Memory requirements for the parameters and gradients are simply the number of parameters multiplied by the number of bytes per parameter. In good old-fashioned full precision (FP32) training, both parameters and gradients require 4 bytes each, while the optimizer, if we use Adam, requires the momentum and variance to be stored, which adds another 4 bytes each (8 bytes total) per parameter. In summary:
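To make this concrete, here is a small sketch in plain Python that plugs hypothetical model dimensions (h=4096, v=128256, L=32, roughly an 8B-scale dense model) into the parameter-count and full-precision memory formulas above; the numbers are illustrative only:

```python
def transformer_param_count(h: int, v: int, L: int) -> int:
    # N = h*v + L*(12*h^2 + 13*h) + 2*h  (positional embeddings excluded, see note above)
    return h * v + L * (12 * h**2 + 13 * h) + 2 * h

def fp32_training_memory_bytes(N: int) -> dict:
    # Full precision: 4 bytes for params, 4 bytes for grads, 4+4 bytes for Adam momentum/variance.
    return {"params": 4 * N, "grads": 4 * N, "optimizer": (4 + 4) * N}

N = transformer_param_count(h=4096, v=128256, L=32)
memory = fp32_training_memory_bytes(N)
print(f"N ≈ {N / 1e9:.1f}B parameters")
print({name: f"{num_bytes / 1e9:.0f} GB" for name, num_bytes in memory.items()})
```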
155
 
@@ -173,7 +173,7 @@ m_{params\_fp32} = 4 * N \\
173
  m_{opt} = (4+4) * N
174
  $$
175
 
176
- > Some libraries store gradients in fp32, which would require an additional $m_{params\_fp32} = 4 * N$ of memory. This is done for example in nanotron, because `bf16` is lossy for smaller values and we always prioritize stability. See https://github.com/microsoft/DeepSpeed/issues/1773 for more information.
177
  >
178
 
179
 Interestingly, mixed precision itself doesn’t save overall memory as it just distributes the memory differently across the three components, and in fact adds another 4 bytes over full precision training if we accumulate gradients in FP32. It’s still advantageous, as doing the forward/backward passes in half precision allows us to (1) use optimized lower-precision operations on the GPU, which are faster, and (2) reduce the activation memory requirements during the forward pass.
@@ -203,7 +203,7 @@ m_{act} = L* seq * bs * h * (34 + \frac{5*n_{heads}*seq}{h})
203
 
204
  $$
205
 
206
- Here $L$ is the number of layers, $seq$ the sequence length, $bs$ the batch size in samples, $h$ the hidden dimension of the model, and $n_{heads}$ the number of heads.
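As a quick sanity check, the following sketch evaluates the activation formula above for a hypothetical configuration (32 layers, hidden size 4096, 32 heads, micro-batch size 1); it mainly illustrates how the $seq^2$ term takes over as the sequence length grows:

```python
def activation_memory_bytes(L: int, seq: int, bs: int, h: int, n_heads: int) -> float:
    # m_act = L * seq * bs * h * (34 + 5 * n_heads * seq / h), as in the formula above
    return L * seq * bs * h * (34 + 5 * n_heads * seq / h)

for seq in (2048, 4096, 8192, 16384):
    gb = activation_memory_bytes(L=32, seq=seq, bs=1, h=4096, n_heads=32) / 1e9
    print(f"seq={seq:6d} -> ~{gb:.0f} GB of activations")
```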
207
 
208
 For the exact derivation of the numbers, you can follow this [NVIDIA paper](https://arxiv.org/pdf/2205.05198) on recomputation; it essentially requires you to do some accounting of the sizes of all intermediate activations between each operation.
209
 
@@ -219,7 +219,7 @@ Is there a way to tame this “activation explosion”? Good question, reader!
219
 
220
 It’s time to explain our first technique – called ***activation recomputation*** – which will help us cap our activation memory footprint. An essential tool in today’s large model training toolbox.
221
 
222
- ## **Activation recomputation**
223
 
224
 The general idea behind ***activation recomputation*** – also called ***gradient checkpointing*** or ***rematerialization*** – is to discard some activations during the forward pass to save memory and spend some extra compute to recompute them on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g. FF, LayerNorm etc.), so that we can use them during the backward pass to compute gradients. When we use recomputation we typically only store activations at a few key points along the model architecture, discard the rest of the activations and recompute them on the fly during the backward pass from the nearest saved activations, basically performing a sub-part of the forward pass again to trade off memory for compute. It generally looks like this:
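In PyTorch, a minimal way to experiment with this idea is `torch.utils.checkpoint`, which saves only the inputs of a wrapped block and recomputes its inner activations during the backward pass. A rough sketch (the `Block` module below is a toy stand-in, not the actual model code):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    # Toy stand-in for a transformer block: a feedforward with a residual connection.
    def __init__(self, h: int):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(h, 4 * h), nn.GELU(), nn.Linear(4 * h, h))

    def forward(self, x):
        return x + self.ff(x)

blocks = nn.ModuleList(Block(1024) for _ in range(8))
x = torch.randn(4, 128, 1024, requires_grad=True)

h = x
for block in blocks:
    # Only the block inputs are kept; inner activations are recomputed during backward.
    h = checkpoint(block, h, use_reentrant=False)
h.sum().backward()
```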
225
 
@@ -252,7 +252,7 @@ Now that we’ve learned about recomputation, we can tame the activations memory
252
 
253
 However, activations still bear a linear dependence on the batch size, and all our profiles in the bar plots above were using `bs=1`, so as we move to larger batch sizes it might become an issue again. Do not despair, as we have a second tool in our box – ***gradient accumulation*** to the rescue!
254
 
255
- ## Gradient accumulation
256
 
257
  Now that we’ve used activation recomputation to fit our model with a small batch size on a single GPU, we still need to reach our target batch size, let’s say 1M tokens (see our earlier discussion on optimal batch size). Gradient accumulation is a very straightforward method to avoid memory explosion when doing this.
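A minimal sketch of what gradient accumulation looks like in a toy PyTorch training loop (the model, data and loss below are placeholders; the important parts are scaling the loss by the number of accumulation steps and stepping the optimizer only once per accumulation window):

```python
import torch
from torch import nn

model = nn.Linear(512, 512)                      # toy stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters())
grad_acc_steps = 4                               # micro-batches accumulated per optimizer step

def get_micro_batch():
    return torch.randn(2, 512), torch.randn(2, 512)  # (inputs, targets) with mbs=2

for step in range(10):                           # one iteration = one optimizer step
    for _ in range(grad_acc_steps):
        inputs, targets = get_micro_batch()
        loss = nn.functional.mse_loss(model(inputs), targets)
        (loss / grad_acc_steps).backward()       # grads sum up in param.grad across micro-batches
    optimizer.step()                             # effective batch = mbs * grad_acc_steps
    optimizer.zero_grad()
```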
258
 
@@ -281,9 +281,9 @@ But if you’ve carefully followed, you probably noticed that the forward/backwa
281
 
282
 Let’s get a larger workstation 🖥️ with a couple of GPUs and start investigating our first scaling technique called ***data parallelism***, which is just a parallel version of gradient accumulation.
283
 
284
- TODO: intro for this
285
 
286
- ## torch.profiler
287
 
288
  ![**Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done**](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%207.png)
289
 
@@ -297,7 +297,7 @@ In this naive approach we see a long AllReduce operation (stream 28) happening t
297
 
298
  **Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done**
299
 
300
- # Data Parallelism
301
 
302
 The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replicas “model instances”) and run forward and backward passes on different micro batches of data in parallel for each GPU, hence the name Data Parallelism.
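Conceptually, the naive version boils down to running an independent forward/backward on each rank's micro-batch and then averaging the gradients across ranks before the optimizer step. A simplified sketch with `torch.distributed` (assuming the process group was already initialized, e.g. by `torchrun`; this is not the picotron implementation linked below):

```python
import torch
import torch.distributed as dist
from torch import nn

# Assumes dist.init_process_group() has already been called (e.g. under torchrun).
model = nn.Linear(512, 512).cuda()
optimizer = torch.optim.AdamW(model.parameters())

def train_step(inputs, targets):
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    # Naive DP: synchronize every gradient with one all-reduce each, after the full backward.
    for param in model.parameters():
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        param.grad /= dist.get_world_size()
    optimizer.step()
    optimizer.zero_grad()
```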
303
 
@@ -310,11 +310,6 @@ This involves our first “distributed communication” primitive: [**All-Reduce
310
 > If you are not familiar with distributed communication patterns like broadcast, gather or all-reduce, we put together a small crash course in the Appendix [TODO Link].
311
  >
312
 
313
- TODO: bucket grads to avoid multiple comms
314
- TODO: show comms overlap
315
-
316
- TODO: any comms requires at least a contiguous buffer to do comms → TIP: make sure tensors that’ll be communicated are contiguous in memory to avoid redundant memory copies
317
-
318
  TODO: embed naive DP: [https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60](https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60)
319
 
320
  TODO: embed bucket DP: [https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171](https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171)
@@ -327,7 +322,7 @@ Instead we should try to overlap communication and computation whenever possible
327
 
328
  Let’s see three optimizations that are done in practice for this!
329
 
330
- ### **First optimization:** Overlap gradient synchronization with backward pass
331
 
332
  The main drawback of the naive DDP approach we’ve just described is that after the backward pass (*computation*), we have to wait for gradient synchronization (*communication*) before updating the parameters. Could we overlap this communication with our computation? The answer is yes!
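One way to sketch this overlap (in recent PyTorch versions) is to attach a post-accumulate-grad hook to each parameter that fires an asynchronous all-reduce as soon as that parameter's gradient is ready, so communication for the last layers runs while the backward pass is still computing the earlier ones. The sketch below reuses `model` and `optimizer` from the naive sketch above, assumes `loss` is the loss of the local micro-batch, and ignores bucketing, which we cover next:

```python
import torch.distributed as dist

pending = []

def sync_grad_when_ready(param):
    # Fire an async all-reduce the moment this parameter's gradient has been accumulated.
    handle = dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, async_op=True)
    pending.append((handle, param))

for p in model.parameters():
    p.register_post_accumulate_grad_hook(sync_grad_when_ready)

loss.backward()                      # all-reduces overlap with the remaining backward compute
for handle, param in pending:
    handle.wait()                    # make sure every gradient is synced before the update
    param.grad /= dist.get_world_size()
pending.clear()
optimizer.step()
optimizer.zero_grad()
```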
333
 
@@ -352,7 +347,7 @@ Overlapping computation and communication reduces the time spent waiting for gra
352
 
353
 This is our first example of “*overlapping computation and communication*”, which we will discuss several times in this blog post and which is an essential technique for maximizing scaling efficiency.
354
 
355
- ### **Second optimization:** Bucketing gradients
356
 
357
  But we can even go further. For a given number of parameters to synchronize, GPU operations like collective communications are often more efficient when performing few calls on large tensors rather than many calls on smaller tensors. Therefore, instead of performing independent all-reduce for each gradient, we can group gradients into buckets and launch a single all-reduce for all the gradients within the same bucket. Think of it like packing items into boxes before shipping—it's more efficient to send a few big boxes than many small ones. By performing a single all-reduce operation for each bucket, we can significantly reduce communication overhead and speed up the communication operation.
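The intuition can be sketched by packing a group of gradients into one contiguous buffer, all-reducing that buffer once, and copying the reduced values back (a toy illustration only; real implementations such as PyTorch DDP handle bucket sizing and overlap automatically):

```python
import torch
import torch.distributed as dist

def all_reduce_bucket(grads):
    # One big collective on a flat buffer instead of many small ones on individual tensors.
    flat = torch.cat([g.reshape(-1) for g in grads])
    dist.all_reduce(flat, op=dist.ReduceOp.SUM)
    flat /= dist.get_world_size()
    # Copy the reduced values back into the original gradient tensors.
    offset = 0
    for g in grads:
        n = g.numel()
        g.copy_(flat[offset:offset + n].view_as(g))
        offset += n

# e.g. after backward: bucket all gradients of the model into a single call
grads = [p.grad for p in model.parameters() if p.grad is not None]
all_reduce_bucket(grads)
```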
358
 
@@ -362,7 +357,7 @@ The selected bucket size will be a key factor in determining the efficiency of D
362
 
363
  [TODO: benchmark all reduce with different size / bucket size results ?]
364
 
365
- ### **Third optimization:** Interplay with gradient accumulation
366
 
367
  As we’ve seen before, gradient accumulation works by performing multiple forward and backward passes before updating the parameters with `optimizer.step()`. When combining gradient accumulation with data parallelism, we should be careful when we want to synchronize gradients.
368
 
@@ -370,7 +365,11 @@ In a naive version, an all-reduce operation is automatically triggered after eac
370
 
371
 In PyTorch, this is typically solved by wrapping the backward passes which don’t need reduction in a [`model.no_sync()`](https://github.com/pytorch/pytorch/blob/5ea67778619c31b13644914deef709199052ee55/torch/nn/parallel/distributed.py#L1408-L1435) context manager, which disables gradient synchronization.
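In practice this looks roughly like the following, where the all-reduce is skipped on all but the last micro-batch of each accumulation window (`model` is assumed to be wrapped in `DistributedDataParallel`; `micro_batches` and `loss_fn` are placeholders):

```python
from contextlib import nullcontext

for i, (inputs, targets) in enumerate(micro_batches):
    is_last = (i == len(micro_batches) - 1)
    # no_sync() disables the gradient all-reduce; only the last backward triggers it.
    ctx = nullcontext() if is_last else model.no_sync()
    with ctx:
        loss = loss_fn(model(inputs), targets) / len(micro_batches)
        loss.backward()

optimizer.step()      # gradients were synchronized exactly once, on the last backward pass
optimizer.zero_grad()
```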
372
 
373
- ## Revisit global batch size
 
 
 
 
374
 
375
  Let’s update our batch size equation with our newly learned Data Parallelism and Gradient Accumulation parameters:
376
 
@@ -387,7 +386,7 @@ Given a targeted global batch size, we can thus trade gradient accumulation step
387
 
388
  Being able to distribute the training over different samples gives us a first dimension of parallelization, thus making this 1D parallelism (we’ll progressively cover 3 more dimensions).
389
 
390
- ## Our journey up to now
391
 
392
 Let’s quickly summarize what we’ve seen up to now and how to set up our first 1D parallel training with a draft recipe for an optimal data-parallel setup:
393
 
@@ -406,26 +405,31 @@ If the gradient accumulation ratio is lower than one, i.e. we have too many GPUs
406
 
407
  Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k. This means our batch size will be 1024 samples (we pick powers of two). We observe that a single GPU can only fit MBS=2 in memory and we have 128 GPUs available for training. This means with 4 gradient accumulation steps we’ll achieve our goal of 1024 samples or 4M tokens per training step. Now what if we suddenly have 512 GPUs available? We can achieve the same GBS and thus identical training by keeping MBS=2 and setting gradient accumulation steps to 1 and achieve faster training!
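The arithmetic of this example, written out with the relation gbs = mbs × grad_acc × dp and a 4k sequence length:

```python
seq_len = 4096
mbs = 2  # micro-batch size that fits on one GPU

for dp, grad_acc in [(128, 4), (512, 1)]:
    gbs_samples = mbs * grad_acc * dp
    gbs_tokens = gbs_samples * seq_len
    print(f"dp={dp:3d}, grad_acc={grad_acc} -> {gbs_samples} samples ≈ {gbs_tokens / 1e6:.1f}M tokens per step")
# Both setups give 1024 samples (≈4M tokens); the 512-GPU one just needs no accumulation and runs faster.
```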
408
 
409
- > Bear in mind that at the 512-GPU scale, depending on the network used, the communication operations will start to be bound by ring latency, which means we can no longer fully overlap the DP communications. This will decrease our compute efficiency and hurt our throughput. In this case we should start exploring other dimensions to parallelize on.
410
  >
411
 
412
  TODO: We’re gaining overall throughput but losing efficiency as we scale DP too much
413
 
414
  ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2013.png)
415
 
416
- **We’ve explored data parallelism, our first (simple) strategy to scale training across more GPUs. It works like gradient accumulation but parallelizes the forward and backward passes on micro batches, thus increasing throughput!**
 
 
417
 
418
- The keen reader has probably already noted, however, that this assumes that we can fit the forward pass of at least one input sample (*mbs=1*) into our GPU memory (with activation recomputation if needed).
419
 
420
- This is not always the case! As we’ve seen earlier, larger models often don’t fit into a single GPU, even with activation recomputation activated.
421
 
422
  ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2014.png)
423
 
 
 
 
424
 Do we have other options for these larger models? Thankfully, we do have some solutions. They will involve either moving some of these tensors to the CPU or splitting the weights/gradients/optimizer-state tensors across GPU devices!
425
 
426
 There are two main approaches to splitting: parallelism (tensor, context, or pipeline parallelism) and sharding (DeepSpeed ZeRO or PyTorch FSDP). Both approaches are somewhat orthogonal and can actually be combined! The sharding paradigm is closely related to DP, so we’ll have a look at it first by investigating the ZeRO method!
427
 
428
- ## ZeRO (**Ze**ro **R**edundancy **O**ptimizer)
429
 
430
  In this section we will introduce DeepSpeed ZeRO (**Ze**ro **R**edundancy **O**ptimizer), a memory optimization technology designed to reduce memory redundancies in LLM training.
431
 
@@ -447,7 +451,7 @@ ZeRO-3 (also called FSDP for “Fully-Sharded Data Parallelism”): optimizer st
447
 
448
 Let’s have a closer look at how much we can save with the partitioning of each ZeRO stage!
449
 
450
- ### Memory usage revisited
451
 
452
  Let’s first recap the memory usage of optimizer states, gradients, and parameters during a standard training. Let’s define the number of our model's parameters as $\Psi$ (previously N but here we use the original ZeRO notation). In mixed-precision training with the Adam optimizer, the memory usage for each item we need to store is:
453
 
@@ -466,7 +470,7 @@ Memory consumption of DP and three stages of Zero-DP. $\Psi$ denotes number of p
466
 
467
 Let’s explain this graph and its values by exploring how each ZeRO stage works. We’ll start with ZeRO-1.
468
 
469
- ### ZeRO-1: Partitioning Optimizer States
470
 
471
  In vanilla DP, all ranks gather the same gradients after the backward pass and simultaneously perform identical optimizer steps. This seems like a lot of duplicated work. Can we avoid it and reduce memory usage at the same time?
472
 
@@ -478,24 +482,26 @@ This explains the memory formula of $2\Psi + 2\Psi + \frac{k\Psi}{N_d}$ that we
478
 
479
  - Forward pass with all bf16 parameters (but different microbatches across DP ranks)
480
  - Backward pass with all gradients (but different microbatches across DP ranks)
481
- - Perform a reduce-scatter on the gradients (reduce-scatter is 2 times faster than all-reduce!)
482
 - Each replica performs an optimizer step (it holds only $\frac{1}{N_d}$ of the optimizer states), updating only $\frac{1}{N_d}$ of the fp32 parameters, and then $\frac{1}{N_d}$ of the bf16 parameters
483
  - [New operation in ZeRO, not in vanilla DP] Perform an all-gather of bf16 parameters to send missing slices back to each replica
484
 
485
  ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2016.png)
486
 
 
 
487
  If you've been following along, you'll recall from vanilla DP that we can overlap the all-reduce gradient communication with the backward pass computation. In ZeRO-1, we can also investigate how to efficiently overlap the newly added all-gather of bf16 parameters. There are two main strategies for this:
488
 
489
- 1. Right after optimizer step: We can initiate the all-gather immediately after the optimizer updates the parameters. This allows the communication to potentially overlap with other post-optimization operations.
490
- 2. Right before forward: We can delay the all-gather until just before we need the parameters for the next forward pass. This approach gives us more flexibility to overlap with any computation happening between training steps.
491
 
492
 But unfortunately these techniques are not as straightforward to implement as they seem and require sophisticated use of hooks/bucketing. In practice we can just use the ZeRO-3/FSDP implementation where the FSDP unit is the entire model – more details about this later.
493
 
494
- ### ZeRO-2: Adding **Gradient Partitioning**
495
 
496
  In ZeRO-1 the optimizer states have been partitioned, which means that each replica only updates $\frac{1}{N_d}$ of the optimizer states. The keen reader must have noticed that there is no real need to have all gradients on all DP ranks in the first place since only a subset is needed for the optimization step.
497
 
498
- → During the backward pass, instead of performing an all-reduce over the gradients, we can therefore perform a ***reduce-scatter [TODO: add link]*** operation! *(yay, a third communication primitive!)* Each rank then only keeps the $\frac{1}{N_d}$ of the gradients it needs in memory, thus saving more memory compared to ZeRO-1.
499
 
500
  > In case of FP32 gradient accumulation, we only need to keep $\frac{1}{N_d}$ fp32_grads where we accumulate the bf16 grads coming from the reduce-scatter. And in the optimizer step we use the $\frac{1}{N_d}$ fp32_grads.
501
  >
@@ -504,14 +510,14 @@ In ZeRO-1 the optimizer states have been partitioned, which means that each repl
504
 
505
 It’s easy to see now that sharding the gradients leads to $2\Psi + \frac{2\Psi+k\Psi}{N_d}$, and as $N_d$ is increased we can save up to 8x memory over the baseline. In terms of communication the same process applies as for ZeRO-1, with the only difference that we communicate and release on the fly. In total, ZeRO-2 is thus also equivalent to vanilla DP training w.r.t. communication.
506
 
507
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2017.png)
508
 
509
 > Note: You might notice that there is no real overhead of using ZeRO-2 over ZeRO-1, and indeed ZeRO-2 is usually the best option. The reason some distributed training frameworks don’t support it is that gradient sharding may interfere with, and add complexity to, other parallel strategies we discuss later.
510
  >
511
 
512
 Now that we’ve sharded gradients as well, are we done? Or can we keep getting away with this? Well, sort of. We would like to reduce the memory of the parameters as well, and we’ve seen that we don’t need to wait for the entire all-gather to start the forward pass – we can already start it once we have the first layer... here comes ZeRO-3!
513
 
514
- ### ZeRO-3: Adding Parameter **Partitioning**
515
 
516
  For Stage 3 we extend the above approach of sharding tensors over DP replicas up to sharding the model’s parameters.
517
 
@@ -520,22 +526,27 @@ For Stage 3 we extend the above approach of sharding tensors over DP replicas up
520
 
521
  So how do we do a forward or backward pass in practice if all parts of the model are distributed? Quite simply we gather them on-demand when we need them. In the forward pass this looks as follows:
522
 
523
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2018.png)
524
 
525
  So as we perform the forward pass and sequentially go through the layers we retrieve the necessary parameters on demand and immediately flush them from memory when we don’t need them anymore. The backward pass works the same way just inverted in flow and we produce the gradient shards:
526
 
527
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2019.png)
528
 
529
 During the forward pass we do all-gather operations for the parameters when we need them, so a $\Psi$ communication tax. Since we discard the parameters immediately after using them in the forward pass, we need one more all-gather during the backward pass as well, incurring another $\Psi$ in communication tax. Finally we need the same ***reduce-scatter*** as in ZeRO-2 for the gradients, which also costs $\Psi$ in communication, and we arrive at a total communication cost of $3\Psi$, compared to $2\Psi$ for ZeRO-2.
530
 
531
 The other issue is that we need to do these all-gathers continuously throughout the forward and backward step, which amounts to `2 * num_layers - 1` additional all-gathers in a training step compared to ZeRO-2, as we can see in the following figure:
532
 
533
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2020.png)
534
 
535
  $$
536
  \frac{t_{comm}}{t_{compute}} = \frac{(DP-1) \cdot peak_{flops}}{2 \cdot seq \cdot mbs \cdot peak_{bw}}
537
  $$
538
 
 
 
 
 
 
539
  Overall it may sound like we significantly increase communication overhead, but thanks to **prefetching** we can start all-gathering weights for Layer n+1 while we do the current forward for Layer n which usually overlaps communication and computation as long as we don’t scale DP too much (as a rule of thumb: DP<512).
540
 
541
 In terms of memory we can see that our equation now reached its final form of $\frac{2\Psi +2\Psi+k\Psi}{N_d}$, which means we can drive memory usage down indefinitely if we can increase the DP rank, at least for the model-related parameters. Notice how it doesn’t specifically help with the intermediate activations that we discussed in the previous chapter. ZeRO is an orthogonal technique to the activation checkpointing and gradient accumulation we discussed in other chapters.
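Putting the per-stage formulas side by side in a small sketch (using the ZeRO convention of k = 12 bytes per parameter for the fp32 copy plus Adam's momentum and variance, and a hypothetical 70B-parameter model sharded over 64 DP ranks):

```python
def zero_memory_gb(psi: float, n_d: int, k: int = 12) -> dict:
    # Per-GPU memory for bf16 params (2Ψ), bf16 grads (2Ψ) and optimizer states (kΨ), in GB.
    return {
        "baseline DP": (2 * psi + 2 * psi + k * psi) / 1e9,
        "ZeRO-1":      (2 * psi + 2 * psi + k * psi / n_d) / 1e9,
        "ZeRO-2":      (2 * psi + (2 * psi + k * psi) / n_d) / 1e9,
        "ZeRO-3":      ((2 * psi + 2 * psi + k * psi) / n_d) / 1e9,
    }

print(zero_memory_gb(psi=70e9, n_d=64))
# -> roughly 1120 GB, 293 GB, 155 GB and 17.5 GB per GPU respectively
```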
@@ -547,7 +558,7 @@ In terms of memory we can see that our equation now reached it’s final form of
547
 
548
 However, there is a limit here: DP only works if a layer of the model fits in a single GPU, and ZeRO can only reduce the parameters, gradients, and optimizer states, but not the activation memory! Recall from the activation memory discussion that it scales with sequence length and batch size. Naturally we could just limit those, but in practice we don’t want to be limited by hardware to train with, e.g., only short sequence lengths.
549
 
550
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2021.png)
551
 
552
 As models grow bigger, even a single layer may not fit on a GPU, and we need more tools in our distributed training toolbox to scale further.
553
 
@@ -582,19 +593,19 @@ This means that we can compute matrix product by either 1) multiplying each colu
582
 
583
  In practice a small example of the operation looks like this:
584
 
585
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2022.png)
586
 
587
 Let’s see how we can parallelise this operation! In tensor parallelism, tensors will be split into N shards along a particular dimension and distributed across N GPUs. Matrices can be split either along the column or the row dimension, leading to column and row parallelism. One thing we’ll see in the following is that choosing row or column sharding will require different communication primitives.
588
 
589
 Our first option is to use column-wise sharding (also called ***column-linear***): We'll copy the complete input matrices to each worker, requiring an operation called [***broadcast***](https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21), and split the weight matrix into columns. The inputs are then multiplied with the partial weight matrices, and the results are finally combined using an [***all-gather***](https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21) operation.
590
 
591
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2023.png)
592
 
593
  The second option is called row-wise sharding (also called ***row-linear***): As the attentive reader might guess, row-linear means that we split the weight matrix into chunks of rows. However, this also requires us to split the inputs, which needs a ***scatter*** operation rather than a broadcast as used in column-linear sharding. The results on each worker are already in the right shape but need to be summed for the final result, thus requiring an all-reduce operation in this scenario.
594
 
595
 We see here our fourth distributed primitive: ***[scatter](https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21)***!
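To convince ourselves that both shardings reconstruct the full result, here is a small single-process sketch in which plain tensor chunks stand in for what would live on different TP ranks, and `cat`/`sum` stand in for the all-gather/all-reduce:

```python
import torch

torch.manual_seed(0)
X = torch.randn(4, 8)   # input  (batch x in_features)
W = torch.randn(8, 6)   # weight (in_features x out_features)
tp = 2                  # pretend we have 2 tensor-parallel ranks

# Column-linear: broadcast X, shard W along columns, "all-gather" the partial outputs.
col_shards = torch.chunk(W, tp, dim=1)
Y_col = torch.cat([X @ W_shard for W_shard in col_shards], dim=1)

# Row-linear: scatter X along its inner dimension, shard W along rows, "all-reduce" (sum) the partials.
X_shards = torch.chunk(X, tp, dim=1)
row_shards = torch.chunk(W, tp, dim=0)
Y_row = sum(X_shard @ W_shard for X_shard, W_shard in zip(X_shards, row_shards))

assert torch.allclose(Y_col, X @ W, atol=1e-5)
assert torch.allclose(Y_row, X @ W, atol=1e-5)
```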
596
 
597
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2024.png)
598
 
599
  ## Tensor Parallelism in a Transformer Block
600
 
@@ -602,7 +613,7 @@ To come up with a strategy to follow, let’s move from a toy example to a real
602
 
603
  The Feedforward part can be parallelized by having a “Column linear” followed by a “Row Linear” which amounts to a broadcast to copy the input and an all-reduce in forward. Note that the broadcast isn’t needed in actual training where we can make sure inputs are already synced across TP ranks.
604
 
605
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2025.png)
606
 
607
  Now that we’ve found the most efficient schema for the Feedforward part of the transformer, let’s take a look at the multi-head attention block (MHA).
608
 
@@ -610,25 +621,29 @@ We can generally follow a similar approach where Q, K, and V matrices are split
610
 
611
 It's also worth noting that the tensor parallelism degree should not exceed the number of Q/K/V heads because we need intact heads per TP rank. And in case we’re using GQA, the TP degree should be at most the number of K/V heads, otherwise it requires additional comms to keep them in sync. For instance, LLaMA-3 8B has 8 Key/Value heads, so the tensor parallelism degree should be less than or equal to 8; otherwise, if TP=16 for example, we need to duplicate each K/V head and make sure they stay in sync.
612
 
613
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2026.png)
614
 
615
 Finally, note that there is a tradeoff in terms of communication, as we’ve added several distributed communication primitives directly in the computation path of our model. Unlike ZeRO, where we could prefetch, it can be harder to make these communications fully overlap with computation.
616
 
617
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2027.png)
 
 
618
 
619
- Looking at the timeline of operations in tensor-parallel MLP (same applies for Attention), we can better understand the tradeoffs involved. In the forward of each decoder layer, we hit a synchronization point with the AllReduce operation that cannot be overlapped with computation. This *exposed communication* overhead is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied. This illustrates one of the key challenges with tensor parallelism - while it helps distribute large matrix multiplications, it does not actually reduce the total memory pressure since activations still need to be gathered for operations like LayerNorm. Additionally, it introduces communication requirements that heavily depend on the network infrastructure. The inability to hide this particular AllReduce behind computation means it directly adds to the critical path of the forward pass.
620
 
621
- ![Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2028.png)
 
 
622
 
623
  Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training.
624
 
625
  In practice, the communication overhead of tensor parallelism becomes particularly noticeable as we scale beyond 8 GPUs. While tensor parallelism within a single node can leverage fast NVLink interconnects, going across nodes requires slower network connections. As shown in the throughput plot above, we observe significant drops when moving from TP=8 to TP=16, and an even steeper decline from TP=16 to TP=32. This illustrates how communication costs can dominate at higher degrees of parallelism.
626
 
627
- However, tensor parallelism provides important benefits for memory usage by distributing model parameters, gradients and optimizer states across GPUs. Let's examine this effect on a 70B parameter model:
628
 
629
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2029.png)
630
 
631
- As we can see, increasing tensor parallelism reduces the memory needed for model parameters, gradients and optimizer states on each GPU. However, the activation memory remains constant across TP configurations. This is because operations like layer normalization and dropout require gathering the full activations on each GPU, effectively negating the memory savings we gained by sharding activations in the attention and feedforward layers.
632
 
633
  > One interesting note about layer normalization in tensor parallel training - since each TP rank sees the same activations after the all-gather, the layer norm weights don't actually need an all-reduce to sync their gradients after the backward pass. They naturally stay in sync across ranks. However, for dropout operations, we must make sure to sync the random seed across TP ranks to maintain deterministic behavior.
634
  >
@@ -659,7 +674,7 @@ In practice we’ll go from the left diagram to the right:
659
 
660
  ![ in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter
661
  in backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather
662
- SP region needs full hidden_dim](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2030.png)
663
 
664
  in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter
665
  in backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather
@@ -677,7 +692,7 @@ In the backward pass:
677
  - "f*" is a no-op because gradients are already duplicated across ranks
678
  - "f" is an all-reduce to synchronize gradients
679
 
680
- These operations "f" and "f*" are called conjugate pairs because they complement each other - when one is a no-op in forward, the other is an all-reduce in backward, and vice versa.
681
 
682
  For sequence parallelism (SP), we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.
683
 
@@ -733,41 +748,35 @@ And for the embedding layer
733
  s: unchanged | h: full (weight_out is full + **reduce-scatter** for correctness)
734
  s: **reduce-scatter** to sharded |
735
 
736
- Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward pass of vanilla TP we had two all-reduces per transformer block, and in SP we have two all-gathers and two reduce-scatters per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into an all-gather + reduce-scatter (see [TODO: Appendix link]), they’re actually equivalent in terms of communication. The same reasoning applies for the backward pass, as we just use the conjugate of each operation (no-op ↔ all-reduce and all-gather ↔ reduce-scatter).
737
-
738
  You can find an example of implementation of both column and row linear TP in picotron:
739
  [https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py](https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py)
740
 
741
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2031.png)
742
-
743
- If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops **IN EACH LAYER** (2 for Attention and 2 for MLP), as shown here for the MLP region:
744
 
745
  ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2032.png)
746
 
747
- Besides the fact that TP requires communication in each layer, it also can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. This is why TP is usually done only within a node (TP≤8).
748
 
749
- > Notice how the all-gather is overlapped with “Y A1” – that’s thanks to this trick:
750
- [https://github.com/huggingface/nanotron/blob/9055c664c28a3b430b4e53bfcb5a074068c90f2a/src/nanotron/parallel/tensor_parallel/functional.py#L169-L262](https://github.com/huggingface/nanotron/blob/9055c664c28a3b430b4e53bfcb5a074068c90f2a/src/nanotron/parallel/tensor_parallel/functional.py#L169-L262)
751
- and you can find more tricks [here](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487).
752
- >
753
 
754
- TODO: remove, Profiling:
755
 
756
- - TP
757
 
758
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2033.png)
759
 
760
- - Seq Parall
761
 
762
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2034.png)
763
 
764
- Allreduce takes almost double the duration (900us) of reducescatter and allgather (500us)
 
765
 
766
- Let’s compare throughput as we scale TP and TP/SP for a 3B model:
767
 
768
- ![Impact of combined Tensor and Sequence Parallelism (TP/SP) on model performance and memory utilization: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2035.png)
769
 
770
- Impact of combined Tensor and Sequence Parallelism (TP/SP) on model performance and memory utilization: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.
771
 
772
  Let’s summarize our observations:
773
 
@@ -775,8 +784,6 @@ Let’s summarize our observations:
775
  - the memory savings in activations when using TP with SP helps us fit far bigger batches than TP alone
776
  - the Torch memory fragmentation makes it hard for us to predict the exact peak reserved memory. For more details check memory_viz tool section. [TODO: add link]
777
 
778
- TODO (outro): TP can help sharding activs (sometimes on hidden_dim, sometimes on seq_dim) by sharding the big linears across ranks, but what if we want to scale sequence_length, our activs will still blow up in TP region. → Context parallelism
779
-
780
  **We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.**
781
 
782
  However, there are two limits to TP and SP: 1) if we scale the sequence length the activation memory will still blow up in the TP region and 2) if the model is too big to fit with TP=8 then we will see a massive slow-down due to the inter-node connectivity.
@@ -791,7 +798,7 @@ With Tensor Parallelism and Sequence Parallelism, we can reduce the memory requi
791
 
792
  Even if we use full recomputation of the activations, which comes at a heavy compute overhead (30%), we still need to hold in memory some activations at the layer boundaries which scale linearly with sequence length:
793
 
794
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2036.png)
795
 
796
 Can we apply similar ideas to our sequence parallelism approach, but inside the modules where we already apply Tensor Parallelism, thereby also reducing the effect of sequence length? Yes, it’s time to talk about Context Parallelism, which you will find quite intuitive after all we’ve already covered.
797
 
@@ -799,7 +806,7 @@ Can we apply similar ideas to our sequence parallelism approach but inside in th
799
 
800
  The idea of Context Parallelism is quite simple; just like Sequence Parallelism, we’ll split the input along the sequence dimension but we now apply this splitting along the full model, instead of only the sequence parallel regions of the model. Our focus here will be to reduce the activation memory footprint by splitting the long sequences, complementing parallelism strategies like TP which target the hidden dimension of the model.
801
 
802
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2037.png)
803
 
804
  Splitting the sequence doesn't affect most modules like MLP and LayerNorm, where each token is processed independently. It also doesn’t require expensive communication like TP, as only the inputs are split and not the weight matrices. Just as in data parallelism, after computing the gradients, an all-reduce operation is initiated to synchronize the gradients across the context parallelism group.
805
 
@@ -832,7 +839,7 @@ With this animation, it’s also immediately clear why the authors chose to call
832
 
833
 There is one big problem though, which is that a naive implementation of Ring Attention leads to a strong imbalance between GPUs, stemming from the shape of the causal attention matrix. Let’s take a closer look at what is happening in the SoftMax computation by considering the attention score matrix with the causal attention mask:
834
 
835
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2038.png)
836
 
837
  The SoftMax is computed row-wise, which means whenever a GPU has received all the tokens of a row it can be computed. We see that GPU1 can immediately compute it as it starts with tokens 1-4 and GPU1 actually doesn’t need to receive any information from any other GPUs. However, GPU2 will need to wait for the second round to also receive 1-4 and thus have all values for tokens 1-8. Also, GPU1 seems to perform much less work than all the other GPUs.
838
 
@@ -842,17 +849,17 @@ Let’s see if we can balance our computations better:
842
 
843
 We need a better way to distribute the input sequences. This can be achieved by assigning the tokens to the GPUs not purely sequentially but by mixing the ordering a bit, such that we have a good mix of early and late tokens on each GPU. This approach is called [Zig-Zag attention](https://arxiv.org/pdf/2311.09431), and in this new arrangement the attention mask will show an even distribution of computation: if you count the number of colored squares, you’ll see that the computation is now balanced across all GPUs.
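A common way to realize this assignment (the exact scheme varies between implementations, so treat this as an illustrative sketch) is to cut the sequence into `2 * cp_size` chunks and give rank `i` chunks `i` and `2 * cp_size - 1 - i`, pairing an early chunk with a late one:

```python
def zigzag_split(seq_len: int, cp_size: int):
    # Pair an early chunk with a late chunk on each rank so every GPU gets a similar
    # mix of "cheap" and "expensive" rows of the causal attention mask.
    chunk = seq_len // (2 * cp_size)
    assignment = []
    for rank in range(cp_size):
        early = list(range(rank * chunk, (rank + 1) * chunk))
        late_start = (2 * cp_size - 1 - rank) * chunk
        late = list(range(late_start, late_start + chunk))
        assignment.append(early + late)
    return assignment

# 16 tokens over 4 context-parallel ranks:
for rank, tokens in enumerate(zigzag_split(16, 4)):
    print(f"GPU{rank + 1}: tokens {tokens}")
# GPU1 gets the cheapest and the most expensive chunks, GPU4 gets two medium ones, etc.
```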
844
 
845
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2039.png)
846
 
847
  At the same time we’ll also see that in order to complete all rows, each GPU will need information from all the other GPUs.
848
 
849
 We have two general ways to overlap computation and communication: either by performing a general all-gather, regrouping all the KV on each GPU at the same time (in a ZeRO-3 type of way), or by gathering them one-by-one from each GPU as needed:
850
 
851
- ![Context Parallelism using AllGather implementation](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2040.png)
852
 
853
  Context Parallelism using AllGather implementation
854
 
855
- ![Context Parallelism using All-to-All implementation](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2041.png)
856
 
857
  Context Parallelism using All-to-All implementation
858
 
@@ -864,13 +871,13 @@ TODO: add links to megatronlm(AllGather) and deepspeed(All2All) implementations
864
 
865
 In the TP section we saw that if we try to scale tensor parallelism past the number of GPUs per single node (typically 4 or 8) we hit a lower-bandwidth network called the “inter-node connection”, which can quite strongly impair our performance. We can see this clearly on e.g. the all-reduce operation when we perform it across several nodes:
866
 
867
- ![Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2042.png)
868
 
869
  Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.
870
 
871
 Sequence and context parallelism can help for long sequences but don’t help much if the root cause of our memory issues is not the sequence length but rather the size of the model itself. For large models (70B+), the size of the weights alone can already push past the limits of the 4-8 GPUs on a single node. We can solve this issue by summoning the fourth (and last) parallelism dimension: “pipeline parallelism”.
872
 
873
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2043.png)
874
 
875
 Pipeline Parallelism is conceptually very simple – we’ll simply spread the layers of our model across GPUs – but the devil lies in implementing it efficiently. Let’s dive in!
876
 
@@ -884,7 +891,7 @@ But maybe you start feeling a glimpse of the troubles to come: “sequentially
884
 
885
 Indeed reader! The main challenge in pipeline parallelism will be how to efficiently circumvent the sequential nature of PP to keep our GPUs busy at all times and avoid having one GPU computing while the others are waiting. Here is how our GPU utilization looks when doing a naive and simple forward and backward pass through the model, where the numbers indicate the model layers:
886
 
887
- ![An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2044.png)
888
 
889
  An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.
890
 
@@ -904,7 +911,7 @@ Thankfully, various pipeline parallelism schemes have been designed to reduce th
904
 
905
 Let’s take a first tool out of our toolbox and think about splitting our batch into smaller bite-sized portions which can be processed in parallel (or almost), like we did before in data parallelism for instance. Now when the second GPU is busy processing micro-batch 1, the first GPU can already start processing micro-batch 2. Here is a schedule using 8 micro-batches:
906
 
907
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2045.png)
908
 
909
 > Note: previously the numbers indicated the layers, but in all pipeline-parallel plots from now on (including this one) they indicate a micro-batch. You can think of each square here as containing several layers, as seen in the previous figure.
910
  >
@@ -963,11 +970,16 @@ Since the memory explosion is triggered by the activation we store for the backw
963
 
964
 This schedule is called **one-forward-one-backward (1F1B)** as the middle/steady state involves alternately performing one forward and one backward pass. The general idea is to start performing the backward pass as soon as possible. The schedule looks like this:
965
 
966
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2046.png)
967
 
968
 The bubble still has the same size, so our training efficiency is not significantly improved. However, we only need to store activations for $p$ micro-batches instead of $m$, which significantly reduces the activation memory explosion we had in the AFAB schedule. As a consequence we can add more micro-batches, which then will actually reduce the bubble.
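As a reminder of why adding micro-batches helps, the bubble fraction for both AFAB and 1F1B is roughly $(p-1)/m$ for $p$ pipeline stages and $m$ micro-batches, which this tiny sketch evaluates:

```python
def bubble_fraction(p: int, m: int) -> float:
    # Idle time relative to useful compute: (p - 1) / m for both AFAB and 1F1B.
    return (p - 1) / m

for m in (4, 8, 32, 128):
    print(f"p=8, m={m:3d} -> bubble ≈ {bubble_fraction(8, m):.1%} of the compute time")
```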
969
 
970
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2047.png)
 
 
 
 
 
971
 
972
  A major complexity of this setup, visible on the above graph is how forward and backward passes are not cleanly consecutive anymore but performed in parallel across devices. This means we will have to schedule the switch from forward to backward passes independently on each device instead of in a simple and common central training loop as usual.
973
 
@@ -1056,7 +1068,7 @@ Up to now we’ve sliced our model naively along the model depth dimensions, loc
1056
 
1057
  This can be seen in general as a kind of “looping pipeline” where a micro-batch will move in circles from one GPU to the next as it goes through the forward pass through the model.
1058
 
1059
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2048.png)
1060
 
1061
 As a consequence, we see additional communication happening as the model goes several times through each GPU for the same computation that previously took just one pass. However, each forward and backward pass is divided by a factor of $v$, where $v$ is the number of stages or model chunks per GPU, as we are able to better interleave forward and backward passes.
1062
 
@@ -1067,13 +1079,13 @@ $$
1067
 
1068
  So we can now decrease the bubble by adding microbatches and interleaved stages, but note that quantitatively, the amount of communication also increases by 𝑣 so it’s a trade off. In the following plot you can see several configurations for a PP setup with $p=8$, where the special case of $m=1, v=1$ corresponds to naive pipeline parallelism and the configurations with $v=1$ are AFAB or 1F1B setups and $v \neq 1$ are interleaved configurations.
1069
 
1070
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2049.png)
1071
 
1072
 Scheduling also becomes more complex here, as we need to decide on a given GPU whether we prioritize, at a given moment, earlier micro-batches, meaning that we close the forward and backward loops as fast as possible (so-called “depth-first”, i.e. prioritizing getting batches out of the model as fast as possible), or whether we prioritize first completing the forward passes of all micro-batches in the queue before going over to backward passes (so-called “breadth-first”, i.e. prioritizing filling in the pipeline as much as possible). This is explained in detail in [https://arxiv.org/abs/2211.05953](https://arxiv.org/pdf/2211.05953).
1073
 
1074
 You now have all the elements to understand the pipeline parallelism approach in Llama 3.1, which uses a one-forward-one-backward setup with interleaved stages and a priority setting tuneable between depth-first and breadth-first.
1075
 
1076
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2050.png)
1077
 
1078
 However, we haven’t reached the end of possible pipeline schedules, and recently some methods have been proposed to reduce the bubble to virtually zero! Piqued your curiosity? Let’s have a look!
1079
 
@@ -1083,17 +1095,17 @@ There are even more sophisticated ways to reduce the bubble more and reached clo
1083
 
1084
 Let’s very quickly see how this can work by briefly detailing the [ZeroBubble](https://arxiv.org/abs/2401.10241) work, which is a precursor to DualPipe. The base observation of ZeroBubble is that a backward pass through a matrix multiplication actually involves two separate operations: the backward for the inputs (B) and the backward for the weights (W):
1085
 
1086
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2051.png)
1087
 
1088
 
1089
 
1090
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2052.png)
1091
 
1092
 While the output of B, the backward pass for the input, is necessary for performing the backward pass of the lower layers, the backward pass of the weights, W, is not necessary for the rest of the backward pass and generally only needs to be performed before the optimizer step. This means W can be flexibly scheduled anywhere after the corresponding B of the same stage. This allows for strategic placement of W to fill the pipeline bubbles. The ZB-H2 schedule on the top right is an example of a (theoretical) schedule with zero bubble, taking advantage of this fine-grained decomposition.
1093
 
1094
 DeepSeek’s DualPipe proposes an extension of this decomposed approach to the case of two streams propagating from both sides of the PP ranks, interleaved to minimize idle time on the GPUs even further, as displayed in the following scheduling graph:
1095
 
1096
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2053.png)
1097
 
1098
 The ZeroBubble and DualPipe schedules are a bit too complex for us to give code snippets here, but you should start to have a general idea of the concepts involved. In practice, optimizing these schedules requires careful measurements of the time for each operation, followed by a scheduling algorithm able to find the most optimal allocation of time given the constraints. See for instance the [ZeroBubble](https://arxiv.org/abs/2401.10241) paper for a discussion of the heuristics and algorithms used to perform such scheduling.
1099
 
@@ -1105,7 +1117,7 @@ Mixture-of-expert models have gained some traction with models such as Mixtral o
1105
 
1106
  So whereas Context parallelism
1107
 
1108
- ![[https://arxiv.org/pdf/2407.06204](https://arxiv.org/pdf/2407.06204)](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2054.png)
1109
 
1110
  [https://arxiv.org/pdf/2407.06204](https://arxiv.org/pdf/2407.06204)
1111
 
@@ -1145,7 +1157,7 @@ Combining ZeRO-3 and TP doesn’t raise any specific issues except how to organi
1145
 
1146
  # How to Find the Best Training Configuration
1147
 
1148
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2055.png)
1149
 
1150
 We’ve now covered all the parallelism techniques that are actually used to distribute and train larger models. There remains a general question: which ones should we choose and which ones are best combined? We touched a little bit on this at the end of the last section, but in this section we will walk through the decision process step by step.
1151
 
@@ -1167,7 +1179,7 @@ Overall, most training schedule past a certain size of the models wil tend to co
1167
 
1168
 Let’s try to synthesize the decision process into a relatively simple tree structure:
1169
 
1170
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2056.png)
1171
 
1172
 To explain briefly, data parallelism is the most efficient method, and you should always prioritize it when memory is not a concern. If communication is not a concern and you can keep the BS/GPU at a big enough value to make good use of the GPU MatMuls, ZeRO is an easy method to remove memory bottlenecks and stay close to a simple DP implementation. However, on larger clusters you’ll probably be able to make efficient use of more 4D parallelism. In this case, starting with tensor parallelism is the most direct way to reduce memory usage and is generally faster than pipeline parallelism within a single node (8 GPUs). However, in scenarios with long contexts, the primary memory usage will tend to shift from model weights, gradients, and optimizer states to activation values. In such cases, context parallelism becomes more beneficial than pipeline parallelism. Note that this is not an exact recipe and you should think of it more as a starting point of hyperparameters to run your own benchmarks. For instance, sometimes TP mixed with PP can be more efficient, even if TP<8, and ZeRO-1/2 can make sense to mix in with 4D parallelism as well.
1173
 
@@ -1191,13 +1203,13 @@ Generally, GPUs have a very hierarchical organization. In this primer we’ll ke
1191
 
1192
  On the compute side, GPUs consist of an array of compute units called **Streaming Multiprocessors** (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see [https://resources.nvidia.com/en-us-tensor-core](https://resources.nvidia.com/en-us-tensor-core) for details), each capable of handling multiple threads simultaneously.
1193
 
1194
- ![Original figure from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing).](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2057.png)
1195
 
1196
  Original figure from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing).
1197
 
1198
 The memory side is also highly hierarchical, with several layers of cache and memory: **Registers** are the smallest units and are private to the threads during execution; **Shared Memory** and **L1 cache** are shared between the threads running on a single SM; higher up is the **L2 cache** shared by all SMs; finally there is the **Global Memory**, which is the largest memory on the GPU (the advertised 80 GB for an H100, for instance) but also the slowest to access and query.
1199
 
1200
- ![Original figure from [https://www.youtube.com/watch?v=ZQKMZIP3Fzg](https://www.youtube.com/watch?v=ZQKMZIP3Fzg)](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2058.png)
1201
 
1202
  Original figure from [https://www.youtube.com/watch?v=ZQKMZIP3Fzg](https://www.youtube.com/watch?v=ZQKMZIP3Fzg)
1203
 
@@ -1207,11 +1219,11 @@ A piece of code running on a core of the GPU is called a **kernel**. It can be w
1207
 
1208
  To run the kernel, you will also need a specific code part (called **host code**) which is executed on the **CPU**/host and will take care of preparing data allocations and loading data and code.
1209
 
1210
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2059.png)
1211
 
1212
  Figure 5: Host code for a CUDA kernel for adding two vectors from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing)
1213
 
1214
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2060.png)
1215
 
1216
  Figure 6: Device code containing the definition of the vector addition kernel from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing)
1217
 
@@ -1250,7 +1262,7 @@ def elu(x, alpha=1.0):
1250
 
1251
  The distinction between the compiled and non-compiled versions is striking, especially given that we only added a single decorator. This remarkable difference is illustrated in the graph below (N is the number of columns) :
1252
 
1253
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2061.png)
1254
 
1255
  However, if this performance increase is insufficient, you can consider implementing Triton kernels. As a starting point, you can take a look at the Triton kernel generated by `@torch.compile`. To do so, you simply need to set the environment variable `TORCH_LOGS` to `"output_code"`:
1256
 
@@ -1309,7 +1321,7 @@ Here, `tl.program_id(0)` provides a unique block ID, that we use to determine wh
1309
 
1310
  When we benchmark the generated kernel using `triton.testing.Benchmark` we have the following performance :
1311
 
1312
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2062.png)
1313
 
1314
  This standalone kernel demonstrates superior performance with smaller sizes compared to `@torch.compile`, but this is likely just an artifact of the compilation overhead of `torch.compile`. In any case, instead of starting from scratch, we can focus on optimizing this generated kernel, saving us time in the process.
1315
 
@@ -1349,17 +1361,17 @@ __global__ void matmul_naive(int M, int N, int K, const float *A, const float *B
1349
 
1350
  Here’s an excellent visualization of the kernel from this fantastic [blogpost](https://siboehm.com/articles/22/CUDA-MMM) :
1351
 
1352
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2063.png)
1353
 
1354
  However, when profiling this kernel with a tool like `ncu`, we can see issues, including low memory throughput and uncoalesced memory accesses.
1355
 
1356
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2064.png)
1357
 
1358
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2065.png)
1359
 
1360
  The reason for this is that in this kernel, two threads in the same block with Thread IDs `(0, 0)` and `(1, 0)` (which will end up in the same warp) will both load from the same column of matrix `B` but different rows of matrix `A`. Since matrix elements are stored in row-major order (meaning each row's elements are in consecutive memory addresses, as shown in the figure below), in the first iteration with `i = 0`, thread `(0, 0)` will load $A_{0,0}$, and thread `(1, 0)` will load $A_{1,0}$. These elements are not stored close to each other in memory, and this misalignment repeats across all iterations along the shared dimension, preventing memory accesses from being coalesced.
1361
 
1362
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2066.png)
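You can convince yourself of this layout directly from Python by looking at tensor strides; this is just an illustration of row-major storage, not GPU code:

```python
import torch

A = torch.arange(6 * 4).reshape(6, 4)    # a small row-major matrix
print(A.stride())                        # (4, 1): stepping down one row skips 4 elements

flat = A.flatten()                       # the underlying 1D memory layout
print(flat[0].item(), flat[4].item())    # A[0,0] and A[1,0] sit 4 positions apart
print(flat[0].item(), flat[1].item())    # A[0,0] and A[0,1] are adjacent -> coalesced reads
```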
1363
 
1364
  To improve our kernel we can change the way the coordinates x and y are calculated like the following :
1365
 
@@ -1380,7 +1392,7 @@ Instead of using a 2D block, we switch to a 1D block and redefine how we determi
1380
 
1381
  When we profile our new kernel, we notice that the warning about uncoalesced memory accesses has disappeared, and **the GPU's memory throughput has increased by approximately 10 times**.
1382
 
1383
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2067.png)
1384
 
1385
  We also notice that the execution time of the kernel **decreases by 10x** !
1386
 
@@ -1394,7 +1406,7 @@ In matrix multiplication for example, each thread in a block may need elements f
1394
 
1395
  In the tiling approach, each iteration involves all threads within a block cooperatively loading two tiles—one from matrix A and another from matrix B —into shared memory. Specifically, threads load a tile of matrix A (of size `BLOCK_SIZE_M` by `BLOCK_SIZE_K`) and a tile of matrix B (of size `BLOCK_SIZE_K` by `BLOCK_SIZE_N`). Once the tiles are in shared memory, the threads perform matrix multiplication on these tiles, enabling efficient computation since all necessary data is quickly accessible. The results of the tile multiplication are stored in an accumulation matrix that holds intermediate results. After each iteration, the results from the current tile multiplication are added to this accumulation matrix, continuing until all tiles from both matrices have been processed.
1396
 
1397
- ![From [https://cnugteren.github.io/tutorial/pages/page4.html](https://cnugteren.github.io/tutorial/pages/page4.html)](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2068.png)
1398
 
1399
  From [https://cnugteren.github.io/tutorial/pages/page4.html](https://cnugteren.github.io/tutorial/pages/page4.html)
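To make the data-movement pattern concrete before looking at the CUDA version, here is a plain PyTorch illustration of the same tiling idea (CPU loops standing in for thread blocks, and the `acc` tile standing in for what would live in shared memory and registers):

```python
import torch

def tiled_matmul(A: torch.Tensor, B: torch.Tensor, block: int = 32) -> torch.Tensor:
    M, K = A.shape
    _, N = B.shape
    assert M % block == 0 and N % block == 0 and K % block == 0  # keep the sketch simple
    C = torch.zeros(M, N)
    for i in range(0, M, block):          # each (i, j) pair plays the role of a thread block
        for j in range(0, N, block):
            acc = torch.zeros(block, block)               # accumulator kept "on chip"
            for k in range(0, K, block):
                a_tile = A[i:i + block, k:k + block]      # tile loaded into "shared memory"
                b_tile = B[k:k + block, j:j + block]
                acc += a_tile @ b_tile                    # both tiles are reused block times
            C[i:i + block, j:j + block] = acc
    return C

A, B = torch.randn(128, 64), torch.randn(64, 128)
assert torch.allclose(tiled_matmul(A, B), A @ B, atol=1e-4)
```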
1400
 
@@ -1438,7 +1450,7 @@ When benchmarking this kernel using ncu, we noticed that the memory throughput i
1438
 
1439
  The tiling technique has significantly improved the performance of our kernel. However, when analyzing the warp states which quantify how many cycles were spent in each state, we observe the following:
1440
 
1441
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2069.png)
1442
 
1443
  The meaning of the states can be found in the [Profiling Guide](https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference), specifically in the **Warp Stall Reasons** section. There we can read that :
1444
 
@@ -1463,13 +1475,13 @@ Flash attention is a technique pioneered by [Tri Dao](https://tridao.me) that op
1463
 
1464
  A basic implementation of the attention mechanism involves a lot of transfers between memory and workers. It requires materializing the S and P matrices in HBM, which means that the results need to be sent to HBM and then back to SRAM for the next computations:
1465
 
1466
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2070.png)
1467
 
1468
  Since bandwidth is much lower in HBM this introduces a severe bottleneck in the attention computation. Can we do better? Tri Dao says yes!
1469
 
1470
  The key element is to compute the S matrix in small pieces which can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large S matrix altogether, in favor of keeping only the necessary statistics for computing the normalization factor of the softmax. So we can compute part of $O$ directly in one computation in SRAM rather than moving intermediate results back and forth. In this case, not only do we make use of the shared memory, but we also release the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length), the attention matrix.
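To make the statistics trick concrete, here is a minimal sketch of the online-softmax idea for a single query vector (plain PyTorch, ignoring batching, heads and the tiling of Q; the real kernel applies the same update rule tile by tile in SRAM):

```python
import torch

def online_attention(q, K, V, tile=128):
    d = q.shape[-1]
    m = torch.tensor(float("-inf"))      # running max of the scores seen so far
    l = torch.tensor(0.0)                # running softmax normalizer
    o = torch.zeros(V.shape[-1])         # running (unnormalized) output
    for start in range(0, K.shape[0], tile):
        s = (K[start:start + tile] @ q) / d ** 0.5        # scores for this tile only
        m_new = torch.maximum(m, s.max())
        scale = torch.exp(m - m_new)                      # rescale old stats to the new max
        l = l * scale + torch.exp(s - m_new).sum()
        o = o * scale + torch.exp(s - m_new) @ V[start:start + tile]
        m = m_new
    return o / l                                          # normalize once at the very end

q, K, V = torch.randn(64), torch.randn(1024, 64), torch.randn(1024, 64)
reference = torch.softmax((K @ q) / 64 ** 0.5, dim=0) @ V
assert torch.allclose(online_attention(q, K, V), reference, atol=1e-4)
```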
1471
 
1472
- ![From the FLASH-ATTENTION paper ([https://arxiv.org/pdf/2205.14135](https://arxiv.org/pdf/2205.14135))](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2071.png)
1473
 
1474
  From the FLASH-ATTENTION paper ([https://arxiv.org/pdf/2205.14135](https://arxiv.org/pdf/2205.14135))
1475
 
@@ -1527,13 +1539,13 @@ The principle of floating point numbers can be easily illustrated by recalling t
1527
 
1528
  Reducing the total number of bits comes at a price (no free lunch here either), but we have some control over how to pay. We can sacrifice more bits on either the mantissa or the exponent. For this reason there exist two float8 formats, named according to their exponent and mantissa bits, so we can flexibly choose the most appropriate format. We can look at the possible range of numbers for each format:
1529
 
1530
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2072.png)
1531
 
1532
  We can see that float32 spans 80 orders of magnitude and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further: e5m2 can maintain the float16 range while e4m3 has an even smaller range.
1533
 
1534
  How come some formats are able to maintain the range and others are not? Let’s investigate their resolution by plotting 10,000 points between 1 and 2. Each point will be rounded to the nearest representable number in each format:
1535
 
1536
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2073.png)
1537
 
1538
  We can see here that bfloat16 maintains the range of float32, unlike float16, but does this at the cost of sacrificing more precision. In the case of float8 the situation is even more dire, as e4m3 can represent 7 and e5m2 only 3 numbers on the interval 1-2.
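You can reproduce this resolution experiment in a few lines, assuming a recent PyTorch (2.1+) where the `torch.float8_e4m3fn` and `torch.float8_e5m2` dtypes are available:

```python
import torch

x = torch.linspace(1, 2, 10_000)   # 10,000 points between 1 and 2
for dtype in [torch.float16, torch.bfloat16, torch.float8_e5m2, torch.float8_e4m3fn]:
    rounded = x.to(dtype).to(torch.float32)   # round to the format, cast back to inspect
    # note: the counts include the endpoints 1.0 and 2.0 themselves
    print(f"{str(dtype):24} -> {rounded.unique().numel()} representable values")
```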
1539
 
@@ -1563,7 +1575,7 @@ Even if we perfectly overlap communication with computation, we always eventuall
1563
 
1564
  Recent research - including [FP8-LM](https://arxiv.org/abs/2310.18313), [torchao](https://github.com/pytorch/ao/tree/main/torchao/float8#torchaofloat8), and [DeepSeek-V3](https://arxiv.org/abs/2412.19437) - has demonstrated the potential of FP8 training for large-scale models. Still, FP8 pretraining introduces a significant challenge: stability. At lower precision, numerical instability often leads to loss divergence, making it difficult to match the accuracy of higher-precision training.
1565
 
1566
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2074.png)
1567
 
1568
  As [[Wortsman et al.]](https://arxiv.org/abs/2309.14322) observed, instability increases as learning rates rise for a fixed model size, making FP8 pretraining particularly tricky.
1569
 
@@ -1701,7 +1713,7 @@ Throughout this blogpost we’ll scale LLM training from one to hundreds of GPUs
1701
 
1702
  The general setup is that we have a number of independent nodes which could be CPU cores, GPUs, or compute nodes. Each performs some computation and then we want to communicate the result or parts of it to the other nodes for the next computation step (t+1).
1703
 
1704
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2075.png)
1705
 
1706
  Maybe we need to send the result from one node to all other nodes, or we need to sum all the intermediate results from each node to report the overall result. Usually, there is one node with an elevated status that plays a central role, here denoted with `root` that is the target or source of some operations. Let’s start with one of the simplest primitives: a broadcast operation.
1707
 
@@ -1709,7 +1721,7 @@ Maybe we need to send the result from one node to all other nodes, or we need to
1709
 
1710
  A very common pattern is that you have some data on Node 1 and you want to share it with all the other nodes so they can do some computation with the data. The broadcast operation does just that:
1711
 
1712
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2076.png)
1713
 
1714
  Collective operations are natively provided by PyTorch so we can easily write a small example that demonstrates how broadcasting works. We first need to initialize a process group with `dist.init_process_group` which sets up the communication backend (we’ll talk about NCCL later), determines how many workers (aka nodes) exist and assigns a rank to each one (which we can get with `dist.get_rank`). Finally, it establishes a connection between the workers.
1715
 
@@ -1754,7 +1766,7 @@ Great, seems like it works as expected. Note that the rank messages can be print
1754
 
1755
  Reduce patterns are among the most fundamental patterns in distributed data processing. The idea is that you want to combine the data present on each node through a function `f()` which can be for instance summation or averaging. In the Reduce paradigm the result is sent to the root node only, whereas in the AllReduce case the result is broadcasted to all nodes:
1756
 
1757
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2077.png)
1758
 
1759
  Of course there is no magic “free-flying” node that can perform this operation; generally each node does a partial computation in a ring or tree structure over the nodes. Here is a simple example: let’s say we need to compute a sum of numbers on each node and our nodes are connected in a ring pattern. The first node sends its number to a neighbour, which adds its own number to the received number before forwarding it to the next neighbour. At the end of a round along the ring of nodes, the first node will receive the total sum.
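Here is a tiny pure-Python simulation of that ring, just to spell out the data flow (no real communication involved):

```python
values = [4, 7, 1, 3]        # one value per node, arranged in a ring
partial = 0
for rank, value in enumerate(values):
    partial += value         # each node adds its value to the partial sum it received...
    print(f"node {rank} forwards partial sum {partial} to node {(rank + 1) % len(values)}")
print("node 0 ends up with the total:", partial)   # 15
```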
1760
 
@@ -1849,7 +1861,7 @@ Now let’s turn to our next distributed communication operation. In many real c
1849
 
1850
  Gather and AllGather are quite similar to the Broadcast in that they allow distributing data among nodes without modification. The main difference to Broadcast is that there is not one value we need to share from one node to all other nodes; instead, each node has an individual chunk of data, and we want to either gather all the data on one node (in the case of Gather) or gather all the data on all nodes (in the case of AllGather). A picture being worth a thousand words, let’s take a look:
1851
 
1852
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2078.png)
1853
 
1854
  Note that the dashed lines indicate that some data actually doesn’t move at all (since it’s already present on the node).
1855
 
@@ -1923,7 +1935,7 @@ As the name subtly suggests, the goal of the Scatter operation is to take data o
1923
 
1924
  The ReduceScatter pattern is slightly more complex: imagine you apply an operation like in the Reduce case but instead of moving the result to just one node we also distribute it evenly to all nodes:
1925
 
1926
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2079.png)
1927
 
1928
  The Scatter operation is written in code as the opposite of the Gather: instead of preparing a list of tensors as target we prepare the source data as a list of tensors we want to distribute. We also need to specify the `src`:
1929
 
@@ -1998,7 +2010,7 @@ We now have seen the main building block of distributed operations but before we
1998
 
1999
  The Barrier is a simple operation to synchronize all nodes. A barrier is not lifted until all nodes have reached it. Then only are they allowed to continue with further computations:
2000
 
2001
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2080.png)
2002
 
2003
  We can easily simulate delayed nodes by setting up a different sleep time on each node and see how long it takes for all of them to pass the barrier:
2004
 
@@ -2111,7 +2123,7 @@ print(p.key_averages().table(sort_by="cuda_time_total", row_limit=8))
2111
 
2112
  This would print aggregated profiling results sorted by the total CUDA time, and the output would be:
2113
 
2114
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2081.png)
2115
 
2116
  You can also try to inspect the trace as we previously mentioned on `chrome://tracing/`
2117
 
@@ -2120,7 +2132,7 @@ You can also try to inspect the trace as we previously mentioned on `chrome://t
2120
 
2121
  After zooming in, you can observe the flow of operations when calling `layer_norm` in this trace:
2122
 
2123
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2082.png)
2124
 
2125
  The sequence begins in the CPU (the upper section) with `aten::layer_norm`, progressing to `aten::native_layer_norm`, and then transitioning to `cudaLaunchKernel`. From there, we move on to the GPU, where the `vectorized_layer_norm_kernel` kernel is called.
2126
 
@@ -2141,7 +2153,7 @@ ncu --set full -o output python layer_norm.py
2141
 
2142
  and open the file `output.ncu-rep` with Nsight Compute, you will have a view that looks like this :
2143
 
2144
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2083.png)
2145
 
2146
  It gives clear warnings about compute and memory utilization, along with suggestions on how to make the kernel better balance compute and memory and achieve maximal occupancy.
2147
 
@@ -2223,7 +2235,7 @@ $$
2223
 
2224
  The chain rule applies here since the loss (L) depends directly on the output (Y). This equation is telling us that to get the gradient of the loss with respect to our input (dL/dX), we multiply the gradient of the loss with respect to the output (dL/dY) by our weight matrix (W).
2225
 
2226
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2084.png)
2227
 
2228
  Likewise, we can use chain rule to compute the gradient w.r.t to the weight:
2229
 
@@ -2231,7 +2243,7 @@ $$
2231
  \frac{dL}{dW} = \frac{dL}{dY} \frac{dY}{dW} = \frac{dL}{dY} X
2232
  $$
2233
 
2234
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2085.png)
2235
 
2236
  Here is a snippet of code to clarify all the concepts above:
2237
 
@@ -2310,13 +2322,13 @@ if __name__ == "__main__":
2310
  example_column_row_linear()
2311
  ```
2312
 
2313
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2086.png)
2314
 
2315
  **TODO** add these illustrations somewhere? I found them helpful:
2316
 
2317
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2087.png)
2318
 
2319
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2088.png)
2320
 
2321
  ## A3: ZeRO-R
2322
 
@@ -2461,7 +2473,7 @@ def example_gelu():
2461
 
2462
  ### Interconnect
2463
 
2464
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2089.png)
2465
 
2466
  ## How to profile your code
2467
 
@@ -2497,7 +2509,7 @@ with profiler: # step 2. Wrap the training with profiler
2497
 
2498
  After running this code, you will find `*.trace.json` files under the `profiler_out_dir`. To visualize the results, the easiest way is to open Google Chrome, go to `chrome://tracing/`, and drag the file into it. This will allow you to view the profiling results. To get more details, we invite you to check out the amazing [tutorial](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html) created by PyTorch.
2499
 
2500
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2090.png)
2501
 
2502
  ## Formulas for the compute / communication balance
2503
 
@@ -2610,7 +2622,7 @@ for a single microbatch:
2610
 
2611
  ```
2612
 
2613
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2091.png)
2614
 
2615
  ## Integrating Context Parallelism with TP/SP
2616
 
@@ -2623,7 +2635,7 @@ In order to integrate CP with TP/SP we just have to:
2623
  3. **Replace standard attention with ring attention:** During the forward and backward passes, each TP rank relies on ring attention to compute the correct attention results. So all CP ranks within TP=0, for example, need to all-gather the full KV sequence and calculate attention, but we store only the KV of a sequence chunk to reduce activation memory.
2624
 
2625
  ![TP=0 has GPU0 and GPU2 whereas CP=0 has GPU0 and GPU1
2626
- TP/SP shards the Q/K/V heads across TP ranks (in this example GPU0 and GPU2 get QKV_green, and GPU2 and GPU3 get QKV_blue), since each head can operate independently from others, we can apply ring attention within each TP rank](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2092.png)
2627
 
2628
  TP=0 has GPU0 and GPU2 whereas CP=0 has GPU0 and GPU1
2629
  TP/SP shards the Q/K/V heads across TP ranks (in this example GPU0 and GPU2 get QKV_green, and GPU2 and GPU3 get QKV_blue), since each head can operate independently from others, we can apply ring attention within each TP rank
@@ -2636,7 +2648,7 @@ In fact, given an activation value of shape$[ \text{batch\_size}, \text{sequence
2636
 
2637
  However, through extensive experimentation, we identified two effective training recipes that allowed us to **fully pretrain a 1B LLaMA model in FP8**, covering both the forward and backward passes, while using an FP8 optimizer. More importantly, our approach successfully matched LLaMA-2’s pretraining learning rate. The result?
2638
 
2639
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2093.png)
2640
 
2641
  A loss curve that perfectly matches mixed-precision bfloat16 (bfloat16 with FP32 master weights as the baseline). We successfully tested this to train a 1B LLaMA up to 100B tokens and a 7B LLaMA up to 25B tokens.
2642
 
@@ -2674,14 +2686,28 @@ Let’s take a moment to look better at this fundamental tool for distributed tr
2674
 
2675
  **Non-overlapping:** If we don't overlap the communication and computation, each computation (represented by the purple block) can only begin after the communication (green block) is complete, and the total time is the sum of communication and computation.
2676
 
2677
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2094.png)
2678
 
2679
  **Overlapping:** However, if we manage to launch communication and computation in parallel, we eliminate the waiting time! Now we can see that the computation blocks are launched immediately, one after the other. In this case the total time is *only* the sum of computations.
2680
 
2681
- ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2095.png)
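In PyTorch this overlap is typically expressed with asynchronous collectives: the call returns a handle immediately, we run some independent computation, and only call `wait()` when the result is actually needed. A minimal sketch, assuming the process group is already initialized (e.g. via `torchrun`) and a GPU backend like NCCL:

```python
import torch
import torch.distributed as dist

grads = torch.randn(1024, 1024, device="cuda")   # tensor to synchronize
x = torch.randn(1024, 1024, device="cuda")
w = torch.randn(1024, 1024, device="cuda")

handle = dist.all_reduce(grads, op=dist.ReduceOp.SUM, async_op=True)  # launch communication
y = x @ w                                  # independent computation overlaps with the comm
handle.wait()                              # block only when the reduced tensor is needed
```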
2682
 
2683
  Context parallelism has helped us go past the intra-node interconnect bottleneck, which blocked us from scaling TP across nodes. However, as you probably noted, it only helps reduce the memory constraints if the activation memory dominates the memory budget due to long sequences. What if we are not working on super long sequences and the model weights alone are too big for a single node?
2684
 
2685
  Well, it turns out we have another, quite different, option called pipeline parallelism (PP), which it is now time to explore.
2686
 
2687
- [TODO: comment from Nouamane on comms overlapping with DP 512]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  As you can see, there’s a lot of ground to be covered. Before getting into the trenches of distributed training let’s take a quick high level look on we’ll cover in the post.
30
 
31
+ # ✅ TL;DR
32
 
33
  This book is very extensive so we decide to start with a very general overview of how you can think about distributed training. At a high level, the key challenge in scaling LLM training is to make a training step (forward/backward/optimizer step) with a large batch size the fastest possible.
34
 
 
55
 
56
  Now that we nailed a few key concept and terms let’s get started by revisiting the basic training steps of an LLM!
57
 
58
+ # ✅ First Steps: Training on one GPU
59
 
60
  Let’s start by quickly reviewing the very basics of model training before we start to scale to many GPUs. When a model is trained on a single GPU, the training typically consists of three steps:
61
 
 
101
 
102
  Let’s start by quickly understanding what led to our out-of-memory issue in the first place. This will help us gain some useful intuitions for later.
103
 
104
+ ## ✅  Memory usage in Transformers
105
 
106
  When training a neural network model, one store several items in memory:
107
 
 
143
 
144
  $$
145
 
146
+ N = 2*h*v + L * (12 * h^2 + 13*h) + 2*h
147
  $$
148
 
149
  > Note: we excluded the positional embedding count as rotary embeddings are not learned.
150
  >
151
 
152
+ In that equation, $h$ is the hidden dimension, $v$ the vocabulary size, and $L$ the number of layers in the model. The first term is the parameter count for the word embedding and LM head. When they are tied (meaning we use the same parameters for both), the factor 2 becomes 1. This is beneficial for small models, as the vocabulary size is generally much larger than the hidden dimension, and we don’t want the number of parameters in an LLM to be dominated by the embedding layer.
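As a quick sanity check, here is the formula as a small helper; the example configuration is illustrative (roughly a 7B-class model), not a specific published architecture:

```python
def n_params(h: int, v: int, L: int, tied_embeddings: bool = False) -> int:
    """Parameter count from the formula above (rotary embeddings add no parameters)."""
    embed_factor = 1 if tied_embeddings else 2   # tied: embedding and LM head share weights
    return embed_factor * h * v + L * (12 * h**2 + 13 * h) + 2 * h

h, v, L = 4096, 32000, 32
print(f"untied embeddings: {n_params(h, v, L) / 1e9:.2f}B parameters")
print(f"tied embeddings:   {n_params(h, v, L, tied_embeddings=True) / 1e9:.2f}B parameters")
```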
153
 
154
  Memory requirements for the parameters and gradients are simply the number of parameters multiplied by the number of bytes per parameter. In good old-fashioned full precision (FP32) training both parameters and gradients require 4 bytes while the optimizer, if we use Adam, requires the momentum and variance to be stored, which adds another two 4 bytes per parameter. In summary:
155
 
 
173
  m_{opt} = (4+4) * N
174
  $$
175
 
176
+ > Some libraries store grads in fp32 which would require an additional $m_{params\_fp32} = 4 * N$ memory. This is done for example in nanotron, because `bf16` is lossy for smaller values and we always prioritize stability. See https://github.com/microsoft/DeepSpeed/issues/1773 for more information.
177
  >
178
 
179
  Interestingly, mixed precision itself doesn’t save overall memory as it just distributes the memory differently across the three components, and in fact adds another 4 bytes over full precision training if we accumulate gradients in FP32. It’s still advantageous as having the model which does the forward/backward in half precision it allows us to (1) use optimized lower precision operations on the GPU which are faster and (2) reduces the activation memory requirements during the forward pass.
 
203
 
204
  $$
205
 
206
+ Here $L$ is the number of layers, $seq$ the sequence length, $bs$ the batch size in samples, $h$ the hidden dimension of the model and $n_{heads}$ the number of heads.
207
 
208
  For the exact derivation of the numbers, you can follow this [NVIDIA paper](https://arxiv.org/pdf/2205.05198) on recomputation; it essentially requires you to do some accounting of all the sizes of intermediate activations between each operation.
209
 
 
219
 
220
  It’s time to explain our first technique – called ***activation recomputation*** – which will help us cap the activation memory footprint. An essential tool in today’s large model training toolbox.
221
 
222
+ ## ✅ **Activation recomputation**
223
 
224
  The general idea behind ***activation recomputation*** – also called ***gradient checkpointing*** or ***rematerialization*** – is to discard some activations during the forward pass to save memory, and spend some extra compute to recompute them on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g. FF, LayerNorm etc.), such that we can use them during the backward pass to compute gradients. When we use recomputation we typically only store activations at a few key points along the model architecture, discard the rest, and recompute them on the fly during the backward pass from the nearest saved activations, basically performing a sub-part of the forward pass again to trade off memory for compute. It generally looks like this:
225
 
 
252
 
253
  However, activations still bear a linear dependence on the batch size, and all our profiles in the barplots above were using `bs=1`, so as we move to larger batch sizes it might become an issue again. Do not despair as we have a second tool in our box - ***gradient accumulation*** to the rescue!
254
 
255
+ ## ✅ Gradient accumulation
256
 
257
  Now that we’ve used activation recomputation to fit our model with a small batch size on a single GPU, we still need to reach our target batch size, let’s say 1M tokens (see our earlier discussion on optimal batch size). Gradient accumulation is a very straightforward method to avoid memory explosion when doing this.
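Here is a minimal sketch of what gradient accumulation looks like in a PyTorch training loop (the model, data and loss are stand-ins):

```python
import torch

model = torch.nn.Linear(512, 512)
optimizer = torch.optim.AdamW(model.parameters())
grad_acc_steps = 4                                    # effective batch = 4 micro-batches

optimizer.zero_grad()
for _ in range(grad_acc_steps):
    micro_batch = torch.randn(8, 512)                 # stand-in for a real dataloader
    loss = model(micro_batch).pow(2).mean()           # stand-in for a real loss
    (loss / grad_acc_steps).backward()                # scale so the gradients are averaged
optimizer.step()                                      # single update for the full batch
optimizer.zero_grad()
```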
258
 
 
281
 
282
  Let’s get a larger workstation 🖥️ with a couple of GPUs and start investigating our first scaling technique called ***data parallelism***, which is just a parallel version of gradient accumulation.
283
 
284
+ TODO: intro for torch.profiler section
285
 
286
+ ## ✅ torch.profiler
287
 
288
  ![**Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done**](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%207.png)
289
 
 
297
 
298
  **Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done**
299
 
300
+ # 🚧 Data Parallelism
301
 
302
  The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replicas “model instances”) and run forward and backward passes on different micro-batches of data in parallel on each GPU, hence the name data parallelism.
303
 
 
310
  > If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in the Appendix [TODO Link].
311
  >
312
 
 
 
 
 
 
313
  TODO: embed naive DP: [https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60](https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60)
314
 
315
  TODO: embed bucket DP: [https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171](https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171)
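In the meantime, here is a minimal sketch of the naive idea (an illustration only, not the picotron implementation linked above; it assumes the process group has already been initialized):

```python
import torch
import torch.distributed as dist

class NaiveDP(torch.nn.Module):
    """Each rank holds a full replica; gradients are averaged after the backward pass."""

    def __init__(self, module: torch.nn.Module):
        super().__init__()
        self.module = module
        for p in self.module.parameters():        # start from identical weights everywhere
            dist.broadcast(p.data, src=0)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def sync_gradients(self):
        """Call after loss.backward() and before optimizer.step()."""
        for p in self.module.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                p.grad /= dist.get_world_size()   # average so all replicas stay in sync
```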
 
322
 
323
  Let’s see three optimizations that are done in practice for this!
324
 
325
+ ### 🚧  **First optimization:** Overlap gradient synchronization with backward pass
326
 
327
  The main drawback of the naive DDP approach we’ve just described is that after the backward pass (*computation*), we have to wait for gradient synchronization (*communication*) before updating the parameters. Could we overlap this communication with our computation? The answer is yes!
328
 
 
347
 
348
  This is our first example of “*overlapping computation and communication*” which we will discuss several times in this blog post and is an essential technique to maximal scaling efficiency.
349
 
350
+ ### 🚧 **Second optimization:** Bucketing gradients
351
 
352
  But we can even go further. For a given number of parameters to synchronize, GPU operations like collective communications are often more efficient when performing few calls on large tensors rather than many calls on smaller tensors. Therefore, instead of performing independent all-reduce for each gradient, we can group gradients into buckets and launch a single all-reduce for all the gradients within the same bucket. Think of it like packing items into boxes before shipping—it's more efficient to send a few big boxes than many small ones. By performing a single all-reduce operation for each bucket, we can significantly reduce communication overhead and speed up the communication operation.
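A deliberately simplified sketch of the bucketing idea is shown below (PyTorch's DDP does this with a more sophisticated reducer and ~25MB buckets by default); it assumes an initialized process group and gradients that share a dtype:

```python
import torch
import torch.distributed as dist

def all_reduce_bucketed(grads: list[torch.Tensor], bucket_bytes: int = 25 * 1024 * 1024):
    bucket, size = [], 0

    def flush():
        if not bucket:
            return
        flat = torch.cat([g.flatten() for g in bucket])   # pack many grads into one tensor
        dist.all_reduce(flat, op=dist.ReduceOp.SUM)       # a single collective call
        flat /= dist.get_world_size()
        offset = 0
        for g in bucket:                                  # unpack the result back in place
            g.copy_(flat[offset:offset + g.numel()].view_as(g))
            offset += g.numel()

    for g in grads:
        bucket.append(g)
        size += g.numel() * g.element_size()
        if size >= bucket_bytes:                          # bucket full: ship it
            flush()
            bucket, size = [], 0
    flush()                                               # don't forget the last partial bucket
```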
353
 
 
357
 
358
  [TODO: benchmark all reduce with different size / bucket size results ?]
359
 
360
+ ### 🚧 **Third optimization:** Interplay with gradient accumulation
361
 
362
  As we’ve seen before, gradient accumulation works by performing multiple forward and backward passes before updating the parameters with `optimizer.step()`. When combining gradient accumulation with data parallelism, we should be careful when we want to synchronize gradients.
363
 
 
365
 
366
  In PyTorch, this is typically solved by using the [`model.no_sync()`](https://github.com/pytorch/pytorch/blob/5ea67778619c31b13644914deef709199052ee55/torch/nn/parallel/distributed.py#L1408-L1435) context manager, which disables gradient synchronization, on the backward passes that don’t need reduction.
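A minimal sketch of how this looks in a training loop, assuming `ddp_model`, `dataloader`, `loss_fn`, `optimizer` and `grad_acc_steps` already exist:

```python
import contextlib

for step, (inputs, targets) in enumerate(dataloader):
    is_sync_step = (step + 1) % grad_acc_steps == 0       # only the last micro-batch syncs
    maybe_no_sync = contextlib.nullcontext() if is_sync_step else ddp_model.no_sync()
    with maybe_no_sync:
        loss = loss_fn(ddp_model(inputs), targets) / grad_acc_steps
        loss.backward()                                   # all-reduce only fires on sync steps
    if is_sync_step:
        optimizer.step()
        optimizer.zero_grad()
```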
367
 
368
+ > When performing communication operations, tensors must be contiguous in memory. To avoid redundant memory copies during communication, ensure that tensors that will be communicated are stored contiguously in memory.
369
+ Sometimes we need to allocate additional contiguous buffers of the size of the activations or model parameters specifically for communication, which contributes to the peak memory usage during training.
370
+ >
371
+
372
+ ## 🚧 Revisit global batch size
373
 
374
  Let’s update our batch size equation with our newly learned Data Parallelism and Gradient Accumulation parameters:
375
 
 
386
 
387
  Being able to distribute the training over different samples gives us a first dimension of parallelization, thus making this 1D parallelism (we’ll progressively cover 3 more dimensions).
388
 
389
+ ## 🚧 Our journey up to now
390
 
391
  Let’s quickly summarize what we’ve seen up to now and how to setup our first 1D parallel training with a draft recipe for an optimal data-parallel setup:
392
 
 
405
 
406
  Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k. This means our batch size will be 1024 samples (we pick powers of two). We observe that a single GPU can only fit MBS=2 in memory and we have 128 GPUs available for training. This means with 4 gradient accumulation steps we’ll achieve our goal of 1024 samples or 4M tokens per training step. Now what if we suddenly have 512 GPUs available? We can achieve the same GBS and thus identical training by keeping MBS=2 and setting gradient accumulation steps to 1 and achieve faster training!
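A quick check of this arithmetic:

```python
seq_len, mbs, gbs_samples = 4096, 2, 1024      # targets from the example above (~4M tokens)

for num_gpus in (128, 512):
    grad_acc = gbs_samples // (num_gpus * mbs)
    tokens = num_gpus * mbs * grad_acc * seq_len
    print(f"{num_gpus} GPUs -> grad_acc = {grad_acc}, GBS = {tokens / 1e6:.1f}M tokens")
# 128 GPUs -> grad_acc = 4; 512 GPUs -> grad_acc = 1; both ~4.2M tokens per step
```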
407
 
408
+ > Bear in mind that at the 512-GPU scale, depending on the network used, the communication operations will start to be bound by *ring latency* (the time required for a signal to propagate once around the ring), which means we can no longer fully overlap the DP communications. This will decrease our compute efficiency and hit our throughput. In this case we should start exploring other dimensions to parallelize on.
409
  >
410
 
411
  TODO: We’re gaining overall throughput but losing efficiency as we scale DP too much
412
 
413
  ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2013.png)
414
 
415
+ $$
416
+ t_{comm}/t_{compute} = \frac{\text{num\_params}}{\text{2 num\_tokens}} \cdot \left(\frac{DP-1}{DP}\right) \cdot \frac{\text{peak\_flops}}{\text{peak\_bw}} \leq 1
417
+ $$
418
 
419
+ **We’ve explored data parallelism, our first (simple) strategy to scale training across more GPUs. It works like gradient accumulation but parallelizes the forward and backward passes on micro batches, thus increasing throughput!**
420
 
421
+ The keen reader has probably already noticed, however, that this assumes we can fit at least one input sample forward pass (*mbs*=1) into our GPU memory. This is not always the case! As we can see, larger models don’t fit into a single GPU, even with activation recomputation activated.
422
 
423
  ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2014.png)
424
 
425
+ > Tip: you can quickly eyeball the minimal memory required for your model’s parameters by multiplying the number of parameters by 2 bytes (bf16), e.g. 70B → 140GB (=133GiB)
426
+ >
427
+
428
  Do we have other options for these larger models? We do have some solutions thankfully. They will involve either moving some of these tensors to the CPU or splitting the weights/gradients/optimizer-state tensors across GPU devices!
429
 
430
  There are two main approaches to splitting: parallelism (tensor, context, or pipeline parallelism) and sharding (DeepSpeed ZeRO or PyTorch FSDP). Both approaches are somewhat orthogonal and can actually be combined! The sharding paradigm is closely related to DP so we’ll have a look at it first by investigating the ZeRO method!
431
 
432
+ ## 🚧 ZeRO (**Ze**ro **R**edundancy **O**ptimizer)
433
 
434
  In this section we will introduce DeepSpeed ZeRO (**Ze**ro **R**edundancy **O**ptimizer), a memory optimization technology designed to reduce memory redundancies in LLM training.
435
 
 
451
 
452
  Let’s have a closer look how much we can save with the partitioning of each ZeRO stage!
453
 
454
+ ### 🚧 Memory usage revisited
455
 
456
  Let’s first recap the memory usage of optimizer states, gradients, and parameters during a standard training. Let’s define the number of our model's parameters as $\Psi$ (previously N but here we use the original ZeRO notation). In mixed-precision training with the Adam optimizer, the memory usage for each item we need to store is:
457
 
 
470
 
471
  Let’s explain this graph and its values by exploring how each ZeRO stage works. We’ll start with ZeRO-1.
472
 
473
+ ### 🚧 ZeRO-1: Partitioning Optimizer States
474
 
475
  In vanilla DP, all ranks gather the same gradients after the backward pass and simultaneously perform identical optimizer steps. This seems like a lot of duplicated work. Can we avoid it and reduce memory usage at the same time?
476
 
 
482
 
483
  - Forward pass with all bf16 parameters (but different microbatches across DP ranks)
484
  - Backward pass with all gradients (but different microbatches across DP ranks)
485
+ - Perform a reduce-scatter **[ADD link!]** on the gradients (reduce-scatter is 2 times faster than all-reduce! *yay, a third communication primitive!*)
486
  - Each replica performs an optimizer step (having only 1/$N_d$ of the optimizer states), which updates only 1/$N_d$ of the fp32 parameters, and then 1/$N_d$ of the bf16 parameters
487
  - [New operation in ZeRO, not in vanilla DP] Perform an all-gather of bf16 parameters to send missing slices back to each replica
488
 
489
  ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2016.png)
490
 
491
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2017.png)
492
+
493
  If you've been following along, you'll recall from vanilla DP that we can overlap the all-reduce gradient communication with the backward pass computation. In ZeRO-1, we can also investigate how to efficiently overlap the newly added all-gather of bf16 parameters. There are two main strategies for this:
494
 
495
+ 1. During the optimizer step: We can initiate the all-gather immediately after the optimizer updates part of the parameters. This allows the communication to potentially overlap with the update of the other parameters.
496
+ 2. During forward: We can overlap the all-gather of each layer’s parameters with the forward pass.
497
 
498
  But unfortunately these techniques are not as straightforward to implement as they seem and require sophisticated use of hooks / bucketing. In practice we can just use the ZeRO-3 / FSDP implementation where the FSDPUnit is the entire model; more details about this later.
499
 
500
+ ### 🚧 ZeRO-2: Adding **Gradient Partitioning**
501
 
502
  In ZeRO-1 the optimizer states have been partitioned, which means that each replica only updates $\frac{1}{N_d}$ of the optimizer states. The keen reader must have noticed that there is no real need to have all gradients on all DP ranks in the first place since only a subset is needed for the optimization step.
503
 
504
+ → During the backward pass, instead of performing an all-reduce over the gradients, we only perform a ***reduce-scatter*** operation, where we only keep the $\frac{1}{N_d}$ gradients needed in memory, thus saving more memory compared to ZeRO-1.
505
 
506
  > In case of FP32 gradient accumulation, we only need to keep $\frac{1}{N_d}$ fp32_grads where we accumulate the bf16 grads coming from the reduce-scatter. And in the optimizer step we use the $\frac{1}{N_d}$ fp32_grads.
507
  >
 
510
 
511
  It’s easy to see now that sharding the gradients leads to $2\Psi + \frac{2\Psi+k\Psi}{N_d}$ and as $N_d$ is increased we can save up to 8x memory over the baseline. In terms of communication the same process applies as for ZeRO-1, with the only difference that we communicate and release on the fly. In total, ZeRO-2 is thus also equivalent to vanilla DP training w.r.t. communication.
512
 
513
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2018.png)
514
 
515
  > Note: You might notice that there is no real overhead of using ZeRO-2 over ZeRO-1 and indeed ZeRO-2 is usually the best option. The reason some distributed training frameworks don’t support it is that gradient sharding may interfere with, and add complexity to, the other parallel strategies we discuss later.
516
  >
517
 
518
  Now that we’ve sharded gradients as well, are we done? Or can we keep getting away with this? Well, sort of. We would like to reduce the memory of the parameters as well, and we’ve seen that we don’t need to wait for the entire all-gather to start the forward pass; we can already start it once we get the first layer... here comes ZeRO-3!
519
 
520
+ ### 🚧 ZeRO-3: Adding Parameter **Partitioning**
521
 
522
  For Stage 3 we extend the above approach of sharding tensors over DP replicas up to sharding the model’s parameters.
523
 
 
526
 
527
  So how do we do a forward or backward pass in practice if all parts of the model are distributed? Quite simply we gather them on-demand when we need them. In the forward pass this looks as follows:
528
 
529
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2019.png)
530
 
531
  So as we perform the forward pass and sequentially go through the layers we retrieve the necessary parameters on demand and immediately flush them from memory when we don’t need them anymore. The backward pass works the same way just inverted in flow and we produce the gradient shards:
532
 
533
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2020.png)
534
 
535
  During the forward pass we do all-gather operations for the parameters when we need them, so a $\Psi$ communication tax. Since we discard the parameters immediately after using them in the forward pass, we need one more all-gather during the backward pass as well, incurring another $\Psi$ in communication tax. Finally we need the same ***reduce-scatter*** as in ZeRO-2 for the gradients, which also costs $\Psi$ in communication, and we arrive at a total communication cost of $3\Psi$, compared to $2\Psi$ for ZeRO-2.
536
 
537
  The other issue is that we need to do these all-gathers continuously throughout the forward and backward step, which amounts to `2 * num_layers - 1` additional all-gathers in a training step compared to Zero-2 as we can see in the following figure:
538
 
539
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2021.png)
540
 
541
  $$
542
  \frac{t_{comm}}{t_{compute}} = \frac{(DP-1) \cdot peak_{flops}}{2 \cdot seq \cdot mbs \cdot peak_{bw}}
543
  $$
544
 
545
+ $$
546
+
547
+ t_{comm}^{FSDP}/t_{compute} = \frac{DP-1}{2 \cdot \text{seq} \cdot \text{mbs}} \cdot \frac{\text{peak\_flops}}{\text{peak\_bw}} \leq 1
548
+ $$
549
+
550
  Overall it may sound like we significantly increase communication overhead, but thanks to **prefetching** we can start all-gathering weights for Layer n+1 while we do the current forward for Layer n which usually overlaps communication and computation as long as we don’t scale DP too much (as a rule of thumb: DP<512).
551
 
552
  In terms of memory we can see that our equation now reached its final form of $\frac{2\Psi +2\Psi+k\Psi}{N_d}$ which means we can drive memory usage down indefinitely if we can increase the DP rank, at least for the model-related parameters. Notice how it doesn’t specifically help with the intermediate activations that we discussed in the previous chapter. ZeRO is an orthogonal technique to the activation checkpointing and gradient accumulation we discussed in other chapters.
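Putting the three stages side by side, here is a small helper that evaluates the memory equations above (model states only, activations excluded; k is the optimizer multiplier, 12 for Adam with fp32 master weights, momentum and variance):

```python
def zero_model_state_gb(num_params: float, dp: int, stage: int, k: int = 12) -> float:
    psi = num_params
    if stage == 0:                                   # plain DP: 2Ψ + 2Ψ + kΨ
        total = 2 * psi + 2 * psi + k * psi
    elif stage == 1:                                 # shard optimizer states
        total = 2 * psi + 2 * psi + k * psi / dp
    elif stage == 2:                                 # shard optimizer states + gradients
        total = 2 * psi + (2 * psi + k * psi) / dp
    else:                                            # ZeRO-3: shard everything
        total = (2 * psi + 2 * psi + k * psi) / dp
    return total / 1e9

for stage in range(4):
    label = "vanilla DP" if stage == 0 else f"ZeRO-{stage}"
    print(f"{label}: {zero_model_state_gb(8e9, dp=64, stage=stage):.2f} GB per GPU")
```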
 
558
 
559
  However, there is a limit here: DP only works if a layer of the model fits on a single GPU, and ZeRO can only reduce the parameters, gradients, and optimizer states, but not the activation memory! Recall from the activation memory discussion that it scales with sequence length and batch size. Naturally we could just limit those, but in practice we don’t want to be limited by hardware to train with e.g. a short sequence length.
560
 
561
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2022.png)
562
 
563
  As models grow bigger, even a single layer may not fit on a single GPU, so we need more tools in our distributed training toolbox to keep scaling.
564
 
 
593
 
594
  In practice a small example of the operation looks like this:
595
 
596
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2023.png)
597
 
598
  Let’s see how we can parallelise this operation! In tensor parallelism, tensors will be split into N shards along a particular dimension and distributed across N GPUs. Matrices can be split either on the column part or row part leading to row and column parallelism. One thing we’ll see in the following is that choosing row or column sharding will require different communications primitives.
599
 
600
  Our first option is to use column-wise sharding (also called ***column-linear***): We'll copy the complete input matrices to each worker, requiring an operation called [***broadcast***](https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21), and split the weight matrix into columns. The inputs are then multiplied with the partial weight matrices, and the results are finally combined using an [***all-gather***](https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21) operation.
601
 
602
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2024.png)
603
 
604
  The second option is called row-wise sharding (also called ***row-linear***): As the attentive reader might guess, row-linear means that we split the weight matrix into chunks of rows. However, this also requires us to split the inputs, which needs a ***scatter*** operation rather than a broadcast as used in column-linear sharding. The results on each worker are already in the right shape but need to be summed for the final result, thus requiring an all-reduce operation in this scenario.
605
 
606
  We see here our fourth distributed primitive: ***[scatter](https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21)***!
607
 
608
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2025.png)
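Before moving on, here is a small numerical sanity check of both schemes, simulated on a single device with two "ranks" (concatenation plays the role of the all-gather and summation plays the role of the all-reduce):

```python
import torch

torch.manual_seed(0)
X = torch.randn(4, 8)      # input  (batch, hidden_in)
W = torch.randn(8, 16)     # weight (hidden_in, hidden_out)
Y_ref = X @ W

# Column-linear: full input everywhere, weight split along columns, outputs concatenated.
Y_col = torch.cat([X @ w_shard for w_shard in W.chunk(2, dim=1)], dim=1)

# Row-linear: input split along hidden_in, weight split along rows, partial outputs summed.
Y_row = sum(x_part @ w_shard for x_part, w_shard in zip(X.chunk(2, dim=1), W.chunk(2, dim=0)))

assert torch.allclose(Y_col, Y_ref, atol=1e-5) and torch.allclose(Y_row, Y_ref, atol=1e-5)
```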
609
 
610
  ## Tensor Parallelism in a Transformer Block
611
 
 
613
 
614
  The Feedforward part can be parallelized by having a “Column linear” followed by a “Row Linear” which amounts to a broadcast to copy the input and an all-reduce in forward. Note that the broadcast isn’t needed in actual training where we can make sure inputs are already synced across TP ranks.
615
 
616
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2026.png)
617
 
618
  Now that we’ve found the most efficient schema for the Feedforward part of the transformer, let’s take a look at the multi-head attention block (MHA).
619
 
 
621
 
622
  It's also worth noting that the tensor parallelism degree should not exceed the number of Q/K/V heads because we need intact heads per TP rank. And in case we’re using GQA, TP degree should be below number of K/V heads, otherwise it requires additional comms to keep them in sync. For instance, LLaMA-3 8B has 8 Key/Value heads, so the tensor parallelism degree should be less than or equal to 8, otherwise if TP=16 for example, we need to duplicate each K/V head and make sure they stay in sync.
623
 
624
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2027.png)
625
 
626
  Finally note that there is a tradeoff in terms of communication, as we’ve added several distributed communication primitives directly in the computation path of our model. Unlike ZeRO, where we could prefetch, it can be harder to make these communications fully overlap with computation.
627
 
628
+ ![Forward pass in Tensor Parallelism](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2028.png)
629
+
630
+ Forward pass in Tensor Parallelism
631
 
632
+ Looking at the timeline of operations in tensor-parallel MLP (same applies for Attention), we can better understand the tradeoffs involved. In the forward of each decoder layer, we hit a synchronization point with the AllReduce operation that cannot be overlapped with computation. This *exposed communication* overhead is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied.
633
 
634
+ Tensor parallelism does help reduce activation memory for the matrix multiplications since the intermediate activations are sharded across GPUs. However, we still need to gather the full activations for operations like LayerNorm, which means we're not getting the full memory benefits we could. Additionally, it introduces significant communication requirements that heavily depend on the network infrastructure. The inability to hide this particular AllReduce behind computation means it directly adds to the critical path of forward propagation.
635
+
636
+ ![Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2029.png)
637
 
638
  Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training.
639
 
640
  In practice, the communication overhead of tensor parallelism becomes particularly noticeable as we scale beyond 8 GPUs. While tensor parallelism within a single node can leverage fast NVLink interconnects, going across nodes requires slower network connections. As shown in the throughput plot above, we observe significant drops when moving from TP=8 to TP=16, and an even steeper decline from TP=16 to TP=32. This illustrates how communication costs can dominate at higher degrees of parallelism.
641
 
642
+ However, tensor parallelism provides important benefits for memory usage by distributing model parameters, gradients, optimizer states and activations (to some extent) across GPUs. Let's examine this effect on a 70B parameter model:
643
 
644
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2030.png)
645
 
646
+ As we can see, increasing tensor parallelism reduces the memory needed for model parameters, gradients and optimizer states on each GPU. While tensor parallelism does help reduce activation memory in attention and feedforward layers by sharding the matrix multiplications across GPUs, we don't get the full memory benefits we could. This is because operations like layer normalization and dropout still require gathering the full activations on each GPU, partially negating the memory savings. We can do better by finding ways to parallelize these remaining operations as well.
647
 
648
  > One interesting note about layer normalization in tensor parallel training - since each TP rank sees the same activations after the all-gather, the layer norm weights don't actually need an all-reduce to sync their gradients after the backward pass. They naturally stay in sync across ranks. However, for dropout operations, we must make sure to sync the random seed across TP ranks to maintain deterministic behavior.
649
  >
 
674
 
675
  ![ in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter
676
  in backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather
677
+ SP region needs full hidden_dim](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2031.png)
678
 
679
  in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter
680
  in backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather
 
692
  - "f*" is a no-op because gradients are already duplicated across ranks
693
  - "f" is an all-reduce to synchronize gradients
694
 
695
+ These operations "f" and "f*" are called **conjugate** pairs because they complement each other - when one is a no-op in forward, the other is an all-reduce in backward, and vice versa.
696
 
697
  For sequence parallelism (SP), we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.
698
 
 
748
  s: unchanged | h: full (weight_out is full + **reduce-scatter** for correctness)
749
  s: **reduce-scatter** to sharded |
750
 
 
 
751
  You can find an example of implementation of both column and row linear TP in picotron:
752
  [https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py](https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py)
753
 
754
+ By using sequence parallelism, we can achieve even greater activation memory savings, allowing us to push our batch size and sequence length further than what would be possible with tensor parallelism alone. Let's see what that means for our previous 70B model example:
 
 
755
 
756
  ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2032.png)
757
 
758
+ Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward pass of vanilla TP we had two all-reduces per transformer block, and in SP we have two all-gathers and two reduce-scatters per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into an all-gather + reduce-scatter (see in [TODO: Appendix link]) they’re actually equivalent in terms of communication. The same reasoning applies for the backward pass, as we just use the conjugate of each operation (no-op ↔ all-reduce and all-gather ↔ reduce-scatter).
759
 
760
+ If you’ve been paying close attention, you’ll notice that we’re talking about **4 comms ops in each layer** (2 for Attention and 2 for MLP). This is what the MLP profiling looks like when using Tensor + Sequence Parallelism:
 
 
 
761
 
762
+ ![Forward pass in Tensor + Sequence Parallelism](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2033.png)
763
 
764
+ Forward pass in Tensor + Sequence Parallelism
765
 
766
+ Besides the fact that TP requires communications in each layer, it also can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. This is why TP is usually done only within a node (TP≤8).
767
 
768
+ > Note: Overlapping communication with computation for TP is an [active area of research](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487), with recent work like Domino [TODO: cite domino paper] exploring novel techniques to maximize this overlap. For example, Megatron-LM/Nanotron implement a partial overlapping of all-gather with FC1 computation, and we expect to see more innovations in this space as the field continues to evolve.
769
 
770
+ $$
771
 
772
+ t_{comm}/t_{compute} = \frac{1}{24h} \cdot (TP-1) \cdot \frac{\text{peak\_flops}}{\text{peak\_bw}} \leq 1
773
+ $$
774
 
775
+ As you might expect, this communication overhead becomes increasingly problematic as we scale up tensor parallelism. To illustrate this, let’s check throughput as we scale TP with SP for a 3B model:
776
 
777
+ ![Impact of combined Tensor and Sequence Parallelism (TP/SP) on a 3B model’s performance and memory utilization with 4096 seqlen: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2034.png)
778
 
779
+ Impact of combined Tensor and Sequence Parallelism (TP/SP) on a 3B model’s performance and memory utilization with 4096 seqlen: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.
780
 
781
  Let’s summarize our observations:
782
 
 
784
  - the memory savings in activations when using TP with SP help us fit far bigger batches than TP alone
785
  - the Torch memory fragmentation makes it hard for us to predict the exact peak reserved memory. For more details check the memory_viz tool section. [TODO: add link]
786
 
 
 
787
  **We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.**
788
 
789
  However, there are two limits to TP and SP: 1) if we scale the sequence length the activation memory will still blow up in the TP region and 2) if the model is too big to fit with TP=8 then we will see a massive slow-down due to the inter-node connectivity.
 
798
 
799
  Even if we use full recomputation of the activations, which comes at a heavy compute overhead (30%), we still need to hold in memory some activations at the layer boundaries which scale linearly with sequence length:
800
 
801
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2035.png)
802
 
803
  Can we apply similar ideas to our sequence parallelism approach but inside the modules where we already apply Tensor Parallelism, thereby also reducing the effect of sequence length? Yes, it’s time to talk about Context Parallelism, which you will find quite intuitive after all we’ve already covered.
804
 
 
806
 
807
  The idea of Context Parallelism is quite simple; just like Sequence Parallelism, we’ll split the input along the sequence dimension but we now apply this splitting along the full model, instead of only the sequence parallel regions of the model. Our focus here will be to reduce the activation memory footprint by splitting the long sequences, complementing parallelism strategies like TP which target the hidden dimension of the model.
808
 
809
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2036.png)
810
 
811
  Splitting the sequence doesn't affect most modules like MLP and LayerNorm, where each token is processed independently. It also doesn’t require expensive communication like TP, as only the inputs are split and not the weight matrices. Just as in data parallelism, after computing the gradients, an all-reduce operation is initiated to synchronize the gradients across the context parallelism group.
812
 
 
839
 
840
  There is one big problem though, which is that a naive implementation of Ring Attention leads to a strong imbalance between GPUs stemming from the shape of the causal attention matrix. Let’s take a real look at what is happening in the SoftMax computation by considering the attention score matrix with the causal attention mask:
841
 
842
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2037.png)
843
 
844
  The SoftMax is computed row-wise, which means a GPU can compute a row as soon as it has received all of that row’s tokens. We see that GPU1 can compute its rows immediately since it starts with tokens 1-4 and doesn’t need to receive any information from the other GPUs. However, GPU2 will need to wait for the second round to also receive tokens 1-4 and thus have all values for tokens 1-8. We also see that GPU1 seems to perform much less work than all the other GPUs.
845
 
 
849
 
850
  We need a better way to distribute the input sequences. This can be achieved by not assigning the tokens to the GPUs in a purely sequential manner, but instead mixing the ordering a bit such that we have a good mix of early and late tokens on each GPU. This approach is called [Zig-Zag attention](https://arxiv.org/pdf/2311.09431) and in this new arrangement, if you count the number of colored squares in the attention mask, you’ll see that the computation is now evenly balanced across all GPUs.
851
 
852
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2038.png)
853
 
854
  At the same time we’ll also see that in order to complete all rows, each GPU will need information from all the other GPUs.
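One common way to build such an assignment (a sketch working only on token indices) is to give each context-parallel rank one chunk from the beginning of the sequence and one from the end:

```python
def zigzag_shard(seq_len: int, cp_rank: int, cp_size: int) -> list[int]:
    """Token indices owned by `cp_rank` under a zig-zag split: one early chunk and
    one late chunk, so that causal-attention work is balanced across ranks."""
    chunk = seq_len // (2 * cp_size)
    early = list(range(cp_rank * chunk, (cp_rank + 1) * chunk))
    late_start = (2 * cp_size - 1 - cp_rank) * chunk
    late = list(range(late_start, late_start + chunk))
    return early + late

# seq_len=16, cp_size=4: rank 0 -> [0, 1, 14, 15], rank 1 -> [2, 3, 12, 13], ...
```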
855
 
856
  We have two general ways to overlap computation and communication: either by performing a general all-gather, regrouping all the KV on each GPU at the same time (in a ZeRO-3 type of way), or by gathering the KV chunks one-by-one from each GPU as needed:
857
 
858
+ ![Context Parallelism using AllGather implementation](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2039.png)
859
 
860
  Context Parallelism using AllGather implementation
861
 
862
+ ![Context Parallelism using All-to-All implementation](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2040.png)
863
 
864
  Context Parallelism using All-to-All implementation
865
 
 
871
 
872
  In the TP section we saw that if we try to scale Tensor Parallelism past the number of GPUs per single node (typically 4 or 8) we hit a lower-bandwidth network called the “inter-node connection”, which can quite strongly impair our performance. We can see this clearly in e.g. the all-reduce operation when we perform it across several nodes:
873
 
874
+ ![Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2041.png)
875
 
876
  Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.
877
 
878
  Sequence and context parallelism can help for long sequences but don’t help much if sequence length is not the root cause of our memory issues but rather the size of the model itself. For large models (70B+), the size of the weights alone can already push past the limits of the 4-8 GPUs on a single node. We can solve this issue by summoning the fourth (and last) parallelism dimension: “pipeline parallelism”.
879
 
880
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2042.png)
881
 
882
  Pipeline Parallelism is conceptually very simple – we’ll simply spread the layers of our model across GPUs – but the devil lies in implementing it efficiently. Let’s dive into it!
883
 
 
891
 
892
  Indeed reader! The main challenge in pipeline parallelism will be how to efficiently circumvent the sequential nature of PP to keep our GPUs busy at all times and avoid having one GPU computing while the others are waiting. Here is what our GPU utilization looks like when doing a naive and simple forward and backward pass through the model (the numbers indicate the model layers):
893
 
894
+ ![An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2043.png)
895
 
896
  An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.
897
 
 
911
 
912
  Let’s take a first tool out of our toolbox and think about splitting our batch into smaller bite-sized portions which can be processed in parallel (or almost), like we did before in data parallelism for instance. Now when the second GPU is busy processing micro-batch 1, the first GPU can already start processing micro-batch 2. Here is a schedule using 8 micro-batches:
913
 
914
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2044.png)
915
 
916
  > Note: previously the numbers indicated layers, but in all pipeline parallel plots from now on (including this one) they indicate microbatches. You can think of each square here as containing several layers, as seen in the previous figure.
917
  >
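To make the schedule above more concrete, here is a rough sketch (not the actual picotron/nanotron code) of the forward phase of this all-forward-all-backward (AFAB) schedule for a single pipeline rank; `recv_from_prev`/`send_to_next` are hypothetical point-to-point helpers built on `dist.recv`/`dist.send`:

```python
def afab_forward_phase(model_chunk, microbatches, rank, pp_size, recv_from_prev, send_to_next):
    """Rough sketch of the AFAB forward phase for one pipeline rank.
    Error handling and the backward phase are omitted."""
    inputs, outputs = [], []
    for mb in microbatches:
        x = mb if rank == 0 else recv_from_prev()
        y = model_chunk(x)
        # activations for *all* m microbatches stay alive here, since in AFAB the
        # backward phase only starts once every forward pass has completed
        inputs.append(x)
        outputs.append(y)
        if rank < pp_size - 1:
            send_to_next(y.detach())
    return inputs, outputs
```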
 
970
 
971
  This schedule is called **one-forward-one-backward (1F1B)** as the middle/steady state involves alternately performing one forward and one backward pass. The general idea is to start performing the backward pass as soon as possible. The schedule looks like this:
972
 
973
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2045.png)
974
 
975
  The bubble still has the same size so our training efficiency is not significantly improved. However we only need to store activations for $p$ micro-batches instead of $m$, which considerably reduces the activation memory explosion we had in the AFAB schedule. As a consequence we can add more microbatches, which then will actually reduce the bubble.
976
 
977
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2046.png)
978
+
979
+ $$
980
+
981
+ t_{comm}^{PP}/t_{compute} = \frac{1}{32h \cdot \text{num\_layers\_in\_next\_pp}} \cdot \frac{\text{peak\_flops}}{\text{peak\_bw}} \leq 1
982
+ $$
983
 
984
  A major complexity of this setup, visible on the above graph, is that forward and backward passes are not cleanly consecutive anymore but performed in parallel across devices. This means we will have to schedule the switch from forward to backward passes independently on each device instead of in a simple and common central training loop as usual.
985
 
 
1068
 
1069
  This can be seen in general as a kind of “looping pipeline” where a micro-batch will move in circles from one GPU to the next as it goes through the forward pass of the model.
1070
 
1071
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2047.png)
1072
 
1073
  As a consequence we see additional communications happening as the model goes several times through each GPU for the same computation that previously just took one pass. However, each forward and backward pass is divided by a factor of $v$, where $v$ is the number of stages or model chunks per GPU, as we are able to better interleave forward and backward passes.
1074
 
 
1079
 
1080
  So we can now decrease the bubble by adding microbatches and interleaved stages, but note that quantitatively the amount of communication also increases by $v$, so it’s a trade-off. In the following plot you can see several configurations for a PP setup with $p=8$, where the special case of $m=1, v=1$ corresponds to naive pipeline parallelism, the configurations with $v=1$ are AFAB or 1F1B setups, and $v \neq 1$ are interleaved configurations.
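To get a feel for these trade-offs, here is the usual back-of-the-envelope bubble estimate as a tiny helper (a sketch; the exact numbers depend on the schedule and on communication costs):

```python
def pp_bubble_fraction(pp_size: int, num_microbatches: int, v: int = 1) -> float:
    """Idle time relative to useful compute for AFAB/1F1B (v=1) or interleaved (v>1) schedules."""
    return (pp_size - 1) / (v * num_microbatches)

# pp_bubble_fraction(8, 1)       -> 7.0   (naive PP: mostly idle)
# pp_bubble_fraction(8, 32)      -> ~0.22
# pp_bubble_fraction(8, 32, v=2) -> ~0.11 (interleaving halves it, at twice the comms)
```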
1081
 
1082
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2048.png)
1083
 
1084
  Scheduling also becomes more complex here, as we need to decide on each GPU whether to prioritize earlier micro-batches at a given moment, closing their forward and backward loops as fast as possible (so-called “depth-first”, i.e. prioritizing getting batches out of the model as fast as possible), or to first complete the forward passes of all microbatches in the queue before moving on to backward passes (so-called “breadth-first”, i.e. prioritizing filling in the pipeline as much as possible). This is explained in detail in [https://arxiv.org/abs/2211.05953](https://arxiv.org/pdf/2211.05953).
1085
 
1086
  You now have all the elements to understand the pipeline parallelism approach in Llama 3.1, which uses a one-forward-one-backward setup with interleaved stages and a priority setting tuneable between depth-first and breadth-first.
1087
 
1088
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2049.png)
1089
 
1090
  However, we haven’t reached the end of possible pipeline schedules and recently some methods have been proposed to reduce the bubble to virtually zero! Piqued your curiosity? Let’s have a look!
1091
 
 
1095
 
1096
  Let’s very quickly see how this can work by briefly detailing the [ZeroBubble](https://arxiv.org/abs/2401.10241) work, which is a precursor to DualPipe. The base observation of ZeroBubble is that the backward pass through a matrix multiplication actually involves two separate operations: the backward for the inputs (B) and the backward for the weights (W):
1097
 
1098
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2050.png)
1099
 
1100
 
1101
 
1102
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2051.png)
1103
 
1104
  While the output of B, the backward pass for the input, is necessary for performing the backward pass of the lower layers, the backward pass of the weights, W, is not necessary for the rest of the backward pass and generally only needs to be performed before the optimizer step. This means W can be flexibly scheduled anywhere after the corresponding B of the same stage. This allows for strategic placement of W to fill the pipeline bubbles. The ZB-H2 schedule on the top right is an example of a (theoretical) schedule with zero bubble taking advantage of this fine-grained decomposition.
1105
 
1106
  DeepSeek’s DualPipe proposes an extension of this decomposed approach to the case of two streams propagating from both ends of the PP ranks, interleaved to minimize idle time in the GPUs even further, as displayed in the following scheduling graph:
1107
 
1108
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2052.png)
1109
 
1110
  The ZeroBubble and DualPipe schedules are a bit too complex for us to give code snippets here but you should start to have a general idea of the concepts involved. In practice, optimizing these schedules requires careful measurement of the time for each operation, followed by a scheduling algorithm able to find the most optimal allocation of time given the constraints. See for instance the [ZeroBubble](https://arxiv.org/abs/2401.10241) paper for a discussion of the heuristics and algorithms to perform such a scheduling.
1111
 
 
1117
 
1118
  So whereas Context parallelism
1119
 
1120
+ ![[https://arxiv.org/pdf/2407.06204](https://arxiv.org/pdf/2407.06204)](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2053.png)
1121
 
1122
  [https://arxiv.org/pdf/2407.06204](https://arxiv.org/pdf/2407.06204)
1123
 
 
1157
 
1158
  # How to Find the Best Training Configuration
1159
 
1160
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2054.png)
1161
 
1162
  We’ve now covered all the parallelism techniques that are actually used to distribute and train larger models. There remains a general question: which ones should we choose and which ones are best combined? We touched on this a little bit at the end of the last section, but in this section we will walk through the decision process step by step.
1163
 
 
1179
 
1180
  Let’s try to synthesize the decision process into a relatively simple tree structure:
1181
 
1182
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2055.png)
1183
 
1184
  To explain briefly, data parallelism is the most efficient method, and you should always prioritize it when memory is not a concern. If communication is not a concern and you can keep the BS/GPU at a big enough value to make good use of the GPU MatMul, ZeRO is an easy method to remove memory bottlenecks and stay close to a simple DP implementation. However, on larger clusters you’ll probably be able to make efficient use of more of the 4D parallelism dimensions. In this case, starting with tensor parallelism is the most direct way to reduce memory usage and is generally faster than pipeline parallelism within a single node (8 GPUs). However, in scenarios with long contexts, the primary memory usage will tend to shift from model weights, gradients, and optimizer states to activation values. In such cases, context parallelism becomes more beneficial than pipeline parallelism. Note that this is not an exact recipe and you should think of this more as a starting point of hyperparameters to run your own benchmarks. For instance, sometimes TP mixed with PP can be more efficient, even if TP<8, and ZeRO-1/2 can make sense to mix in with 4D parallelism as well.
1185
 
 
1203
 
1204
  On the compute side, GPUs consist of an array of compute units called **Streaming Multiprocessors** (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see [https://resources.nvidia.com/en-us-tensor-core](https://resources.nvidia.com/en-us-tensor-core) for details), each capable of handling multiple threads simultaneously.
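You can query some of these numbers directly from PyTorch (the SM count and global memory size are reported by the driver; cores-per-SM numbers come from the architecture whitepapers and are not exposed here):

```python
import torch

props = torch.cuda.get_device_properties(0)
print(f"{props.name}: {props.multi_processor_count} SMs, "
      f"{props.total_memory / 1e9:.0f} GB of global memory")
# e.g. an H100 reports 132 SMs
```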
1205
 
1206
+ ![Original figure from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing).](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2056.png)
1207
 
1208
  Original figure from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing).
1209
 
1210
  The memory side is also highly hierarchical with several layers of cache and memory: **Registers** are the smallest units and are private to the threads during execution, **Shared Memory** and the **L1 cache** are shared between the threads running on a single SM, higher up is the **L2 cache** shared by all SMs, and finally there is the **Global Memory** which is the largest memory on the GPU (the advertised 80 GB for an H100 for instance) but also the slowest to access and query.
1211
 
1212
+ ![Original figure from [https://www.youtube.com/watch?v=ZQKMZIP3Fzg](https://www.youtube.com/watch?v=ZQKMZIP3Fzg)](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2057.png)
1213
 
1214
  Original figure from [https://www.youtube.com/watch?v=ZQKMZIP3Fzg](https://www.youtube.com/watch?v=ZQKMZIP3Fzg)
1215
 
 
1219
 
1220
  To run the kernel, you will also need a specific code part (called **host code**) which is executed on the **CPU**/host and will take care of preparing data allocations and loading data and code.
1221
 
1222
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2058.png)
1223
 
1224
  Figure 5: Host code for a CUDA kernel for adding two vectors from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing)
1225
 
1226
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2059.png)
1227
 
1228
  Figure 6: Device code containing the definition of the vector addition kernel from [https://blog.codingconfessions.com/p/gpu-computing](https://blog.codingconfessions.com/p/gpu-computing)
1229
 
 
1262
 
1263
  The distinction between the compiled and non-compiled versions is striking, especially given that we only added a single decorator. This remarkable difference is illustrated in the graph below (N is the number of columns):
1264
 
1265
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2060.png)
1266
 
1267
  However, if this performance increase is insufficient, you can consider implementing Triton kernels. As a starting point, you can take a look at the Triton kernel generated by `@torch.compile`. To do so, you simply need to set the environment variable `TORCH_LOGS` to “output_code”:
1268
 
 
1321
 
1322
  When we benchmark the generated kernel using `triton.testing.Benchmark` we have the following performance:
1323
 
1324
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2061.png)
1325
 
1326
  This standalone kernel demonstrates superior performance with smaller sizes compared to `@torch.compile`, but this is likely just an artifact of the compilation time of `torch.compile`. In any case, instead of starting from scratch, we can focus on optimizing this generated kernel, saving us time in the process.
1327
 
 
1361
 
1362
  Here’s an excellent visualization of the kernel from this fantastic [blogpost](https://siboehm.com/articles/22/CUDA-MMM):
1363
 
1364
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2062.png)
1365
 
1366
  However, when profiling this kernel with a tool like `ncu`, we can see issues, including low memory throughput and uncoalesced memory accesses.
1367
 
1368
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2063.png)
1369
 
1370
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2064.png)
1371
 
1372
  The reason for this is that in this kernel, two threads in the same block with Thread IDs `(0, 0)` and `(1, 0)` (which will end up in the same warp) will both load from the same column of matrix `B` but different rows of matrix `A`. Since matrix elements are stored in row-major order (meaning each row's elements are in consecutive memory addresses, as shown in the figure below), in the first iteration with `i = 0`, thread `(0, 0)` will load $A_{0,0}$, and thread `(1, 0)` will load $A_{1,0}$. These elements are not stored close to each other in memory, and this misalignment repeats across all iterations along the shared dimension, preventing memory accesses from being coalesced.
1373
 
1374
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2065.png)
1375
 
1376
  To improve our kernel we can change the way the coordinates `x` and `y` are calculated, as follows:
1377
 
 
1392
 
1393
  When we profile our new kernel, we notice that the warning about uncoalesced memory accesses has disappeared, and **the GPU's memory throughput has increased by approximately 10 times**.
1394
 
1395
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2066.png)
1396
 
1397
  We also notice that the execution time of the kernel **decreases by 10x**!
1398
 
 
1406
 
1407
  In the tiling approach, each iteration involves all threads within a block cooperatively loading two tiles—one from matrix A and another from matrix B —into shared memory. Specifically, threads load a tile of matrix A (of size `BLOCK_SIZE_M` by `BLOCK_SIZE_K`) and a tile of matrix B (of size `BLOCK_SIZE_K` by `BLOCK_SIZE_N`). Once the tiles are in shared memory, the threads perform matrix multiplication on these tiles, enabling efficient computation since all necessary data is quickly accessible. The results of the tile multiplication are stored in an accumulation matrix that holds intermediate results. After each iteration, the results from the current tile multiplication are added to this accumulation matrix, continuing until all tiles from both matrices have been processed.
1408
 
1409
+ ![From [https://cnugteren.github.io/tutorial/pages/page4.html](https://cnugteren.github.io/tutorial/pages/page4.html)](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2067.png)
1410
 
1411
  From [https://cnugteren.github.io/tutorial/pages/page4.html](https://cnugteren.github.io/tutorial/pages/page4.html)
1412
 
 
1450
 
1451
  The tiling technique has significantly improved the performance of our kernel. However, when analyzing the warp states which quantify how many cycles were spent in each state, we observe the following:
1452
 
1453
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2068.png)
1454
 
1455
  The meaning of the states can be found in the [Profiling Guide](https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference), specifically in the **Warp Stall Reasons** section. There we can read that:
1456
 
 
1475
 
1476
  A basic implementation of the attention mechanism involves a lot of transfers between memory and workers. It requires materializing the S and P matrices in HBM which means that the results need to be sent to HBM and then back to SRAM for the next computations:
1477
 
1478
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2069.png)
1479
 
1480
  Since bandwidth is much lower in HBM this introduces a severe bottleneck in the attention computation. Can we do better? Tri Dao says yes!
1481
 
1482
  The key element is to compute the S matrix in small pieces which can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large S matrix altogether, in favor of keeping only the necessary statistics for computing the normalization factor of the softmax. So we can compute part of $O$ directly in one computation in SRAM rather than moving intermediate results back and forth. In this case, not only do we make use of the shared memory but we also release the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length), the attention matrix.
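Here is a small sketch of the running-statistics (“online softmax”) trick at the heart of this idea, for a single head and without masking; it is written for clarity, not speed:

```python
import torch

def blockwise_attention(q, k, v, block_size=128):
    """Process K/V in blocks while keeping only a running max and running normalizer,
    so the full attention score matrix is never materialized."""
    d = q.shape[-1]
    out = torch.zeros(q.shape[0], v.shape[-1], dtype=q.dtype, device=q.device)
    running_max = torch.full((q.shape[0], 1), float("-inf"), dtype=q.dtype, device=q.device)
    running_sum = torch.zeros(q.shape[0], 1, dtype=q.dtype, device=q.device)
    for start in range(0, k.shape[0], block_size):
        kb, vb = k[start:start + block_size], v[start:start + block_size]
        scores = q @ kb.T / d**0.5                     # (s_q, block), small enough for SRAM
        block_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(running_max, block_max)
        correction = torch.exp(running_max - new_max)  # rescale what we accumulated so far
        p = torch.exp(scores - new_max)
        out = out * correction + p @ vb
        running_sum = running_sum * correction + p.sum(dim=-1, keepdim=True)
        running_max = new_max
    return out / running_sum

# matches torch.softmax(q @ k.T / d**0.5, dim=-1) @ v up to floating point error
```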
1483
 
1484
+ ![From the FLASH-ATTENTION paper ([https://arxiv.org/pdf/2205.14135](https://arxiv.org/pdf/2205.14135))](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2070.png)
1485
 
1486
  From the FLASH-ATTENTION paper ([https://arxiv.org/pdf/2205.14135](https://arxiv.org/pdf/2205.14135))
1487
 
 
1539
 
1540
  Reducing the total number of bits comes at a price (no free lunch here either), but we have some control over how to pay: we can sacrifice bits from either the mantissa or the exponent. For this reason there also exist two float8 formats, named according to their exponent and mantissa, so we can flexibly choose the most appropriate format. We can look at the possible range of numbers for each format:
1541
 
1542
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2071.png)
1543
 
1544
  We can see that float32 spans 80 orders of magnitude and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further: e5m2 can maintain the float16 range while e4m3 has an even smaller range.
1545
 
1546
  How come some formats are able to maintain the range and others not? Let’s investigate the resolution by plotting 10,000 points between 1 and 2. Each point will be rounded to the nearest representable number in each format:
1547
 
1548
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2072.png)
1549
 
1550
  We can see here that bfloat16 maintains the range of float32, unlike float16, but does this at the cost of sacrificing more precision. In the case of float8 the situation is even more dire, as e4m3 can represent only 7 and e5m2 only 3 numbers in the interval 1-2.
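You can check these resolution numbers yourself with a few lines of PyTorch (requires a version that ships the float8 dtypes):

```python
import torch

# round 10,000 points in [1, 2] to each format and count the distinct values
# that land strictly inside the interval
x = torch.linspace(1, 2, 10_000)
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2, torch.bfloat16, torch.float16):
    u = torch.unique(x.to(dtype).to(torch.float32))
    inside = ((u > 1) & (u < 2)).sum().item()
    print(f"{dtype}: {inside} representable values strictly between 1 and 2")
```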
1551
 
 
1575
 
1576
  Recent research - including [FP8-LM](https://arxiv.org/abs/2310.18313), [torchao](https://github.com/pytorch/ao/tree/main/torchao/float8#torchaofloat8), and [DeepSeek-V3](https://arxiv.org/abs/2412.19437) - has demonstrated the potential of FP8 training for large-scale models. Still, FP8 pretraining introduces a significant challenge: stability. At lower precision, numerical instability often leads to loss divergence, making it difficult to match the accuracy of higher-precision training.
1577
 
1578
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2073.png)
1579
 
1580
  As [[Wortsman et al.]](https://arxiv.org/abs/2309.14322) observed, instability increases as learning rates rise for a fixed model size, making FP8 pretraining particularly tricky.
1581
 
 
1713
 
1714
  The general setup is that we have a number of independent nodes which could be CPU cores, GPUs, or compute nodes. Each performs some computation and then we want to communicate the result or parts of it to the other nodes for the next computation step (t+1).
1715
 
1716
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2074.png)
1717
 
1718
  Maybe we need to send the result from one node to all other nodes, or we need to sum all the intermediate results from each node to report the overall result. Usually, there is one node with an elevated status that plays a central role, here denoted with `root` that is the target or source of some operations. Let’s start with one of the simplest primitives: a broadcast operation.
1719
 
 
1721
 
1722
  A very common pattern is that you have some data on Node 1 and you want to share it with all the other nodes so they can do some computation with the data. The broadcast operation does just that:
1723
 
1724
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2075.png)
1725
 
1726
  Collective operations are natively provided by PyTorch so we can easily write a small example that demonstrates how broadcasting works. We first need to initialize a process group with `dist.init_process_group`, which sets up the communication backend (we’ll talk about NCCL later), determines how many workers (aka nodes) exist and assigns a rank to each one (which we can get with `dist.get_rank`). Finally, it establishes a connection between the workers.
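A minimal sketch of how this can look in practice (the file name `dist_op.py` in the launch command is just an example):

```python
import os
import torch
import torch.distributed as dist

def init_process():
    # torchrun sets the RANK, LOCAL_RANK and WORLD_SIZE environment variables for us
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def example_broadcast():
    if dist.get_rank() == 0:
        tensor = torch.arange(5, dtype=torch.float32).cuda()
    else:
        tensor = torch.zeros(5, dtype=torch.float32).cuda()
    print(f"Before broadcast on rank {dist.get_rank()}: {tensor}")
    dist.broadcast(tensor, src=0)
    print(f"After broadcast on rank {dist.get_rank()}: {tensor}")

init_process()
example_broadcast()
# launch with e.g.: torchrun --nproc_per_node=3 dist_op.py
```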
1727
 
 
1766
 
1767
  Reduce patterns are among the most fundamental patterns in distributed data processing. The idea is that you want to combine the data present on each node through a function `f()` which can be for instance summation or averaging. In the Reduce paradigm the result is sent to the root node only, whereas in the AllReduce case the result is broadcasted to all nodes:
1768
 
1769
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2076.png)
1770
 
1771
  Of course there is no magic “free-flying” node that can perform this operation; generally each node does a partial computation, with the nodes organized in a ring or tree structure. Here is a simple example: let’s say we need to compute a sum of numbers on each node and our nodes are connected in a ring pattern. The first node sends its number to a neighbour, which adds its own number to the received one before forwarding it to the next neighbour. At the end of a round along the ring of nodes, the first node will receive the total sum.
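As a minimal sketch (reusing the process-group setup from the broadcast example above):

```python
import torch
import torch.distributed as dist

def example_all_reduce():
    # each rank contributes its own values; afterwards every rank holds the sum
    tensor = torch.full((5,), float(dist.get_rank() + 1)).cuda()
    print(f"Before all_reduce on rank {dist.get_rank()}: {tensor}")
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    print(f"After all_reduce on rank {dist.get_rank()}: {tensor}")

example_all_reduce()
```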
1772
 
 
1861
 
1862
  Gather and AllGather are quite similar to the Broadcast in that they allow distributing data among nodes without modification. The main difference to Broadcast is that there is not one value we need to share from one node to all other nodes; instead, each node has an individual chunk of data, and we want to either gather all the data on one node (in the case of Gather) or gather all the data on all nodes (in the case of AllGather). A picture being worth a thousand words, let’s take a look:
1863
 
1864
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2077.png)
1865
 
1866
  Note that the dashed lines indicate that some data actually doesn’t move at all (since it’s already present on the node).
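A minimal AllGather sketch along the same lines (assuming an initialized process group, as above):

```python
import torch
import torch.distributed as dist

def example_all_gather():
    # every rank starts with its own chunk and ends up with all of them
    chunk = torch.full((5,), float(dist.get_rank() + 1)).cuda()
    gathered = [torch.zeros(5).cuda() for _ in range(dist.get_world_size())]
    print(f"Before all_gather on rank {dist.get_rank()}: {chunk}")
    dist.all_gather(gathered, chunk)
    print(f"After all_gather on rank {dist.get_rank()}: {gathered}")

example_all_gather()
```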
1867
 
 
1935
 
1936
  The ReduceScatter pattern is slightly more complex: imagine you apply an operation like in the Reduce case but instead of moving the result to just one node we also distribute it evenly to all nodes:
1937
 
1938
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2078.png)
1939
 
1940
  The Scatter operation is written in code as the opposite of the Gather: instead of preparing a list of tensors as target we prepare the source data as a list of tensors we want to distribute. We also need to specify the `src`:
1941
 
 
2010
 
2011
  The Barrier is a simple operation to synchronize all nodes. A barrier is not lifted until all nodes have reached it. Only then are they allowed to continue with further computations:
2012
 
2013
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2079.png)
2014
 
2015
  We can easily simulate delayed nodes by setting up a different sleep time on each node and see how long it takes for all of them to pass the barrier:
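A minimal sketch of such a measurement (assuming an initialized process group, as above):

```python
import time
import torch.distributed as dist

def example_barrier():
    rank = dist.get_rank()
    t_start = time.time()
    time.sleep(rank)  # rank 0 sleeps 0s, rank 1 sleeps 1s, ...
    dist.barrier()    # nobody continues before the slowest rank has arrived
    print(f"Rank {rank} passed the barrier after {time.time() - t_start:.1f} seconds")

example_barrier()
```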
2016
 
 
2123
 
2124
  This would print aggregated profiling results sorted by the total CUDA time, and the output would be:
2125
 
2126
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2080.png)
2127
 
2128
  You can also try to inspect the trace as we previously mentioned on `chrome://tracing/`
2129
 
 
2132
 
2133
  After zooming in, you can observe the flow of operations when calling `layer_norm` in this trace:
2134
 
2135
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2081.png)
2136
 
2137
  The sequence begins in the CPU (the upper section) with `aten::layer_norm`, progressing to `aten::native_layer_norm`, and then transitioning to `cudaLaunchKernel`. From there, we move on to the GPU, where the `vectorized_layer_norm_kernel` kernel is called.
2138
 
 
2153
 
2154
  and open the file `output.ncu-rep` with Nsight Compute, you will have a view that looks like this :
2155
 
2156
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2082.png)
2157
 
2158
  This gives us clear warnings about compute and memory utilization, along with suggestions for how to better balance compute and memory in the kernel and achieve maximal occupancy.
2159
 
 
2235
 
2236
  The chain rule applies here since the loss (L) depends directly on the output (Y). This equation is telling us that to get the gradient of the loss with respect to our input (dL/dX), we multiply the gradient of the loss with respect to the output (dL/dY) by our weight matrix (W).
2237
 
2238
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2083.png)
2239
 
2240
  Likewise, we can use the chain rule to compute the gradient w.r.t. the weight:
2241
 
 
2243
  \frac{dL}{dW} = \frac{dL}{dY} \frac{dY}{dW} = \frac{dL}{dY} X
2244
  $$
2245
 
2246
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2084.png)
2247
 
2248
  Here is a snippet of code to clarify all the concepts above:
2249
 
 
2322
  example_column_row_linear()
2323
  ```
2324
 
2325
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2085.png)
2326
 
2327
  **TODO** add these illustrations somewhere? I found them helpful:
2328
 
2329
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2086.png)
2330
 
2331
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2087.png)
2332
 
2333
  ## A3: ZeRO-R
2334
 
 
2473
 
2474
  ### Interconnect
2475
 
2476
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2088.png)
2477
 
2478
  ## How to profile your code
2479
 
 
2509
 
2510
  After running this code, you will find `*.trace.json` files under the `profiler_out_dir`. To visualize the results, the easiest way is to open Google Chrome, go to `chrome://tracing/`, and drag the file into it. This will allow you to view the profiling results. To get more details, we invite you to check out the amazing [**tutorial**](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html) created by PyTorch.
2511
 
2512
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2089.png)
2513
 
2514
  ## Formulas for the compute / comms balance
2515
 
 
2622
 
2623
  ```
2624
 
2625
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2090.png)
2626
 
2627
  ## Integrating Context Parallelism with TP/SP
2628
 
 
2635
  3. **Replace standard attention with ring attention:** During the forward pass, each TP rank relies on the ring attention to compute the correct attention results during both the forward and backward passes. So all CP ranks within TP=0 for example need to all-gather the full KV sequence and calculate attention, but we store only the KV of a sequence chunk to reduce memory activations by CP.
2636
 
2637
  ![TP=0 has GPU0 and GPU2 whereas CP=0 has GPU0 and GPU1
2638
+ TP/SP shards the Q/K/V heads across TP ranks (in this example GPU0 and GPU2 get QKV_green, and GPU2 and GPU3 get QKV_blue), since each head can operate independently from others, we can apply ring attention within each TP rank](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2091.png)
2639
 
2640
  TP=0 has GPU0 and GPU2 whereas CP=0 has GPU0 and GPU1
2641
  TP/SP shards the Q/K/V heads across TP ranks (in this example GPU0 and GPU2 get QKV_green, and GPU2 and GPU3 get QKV_blue), since each head can operate independently from others, we can apply ring attention within each TP rank
 
2648
 
2649
  However, through extensive experimentation, we identified two effective training recipes that allowed us to **fully pretrain a 1B LLaMA model in FP8**, covering both the forward and backward passes, while using an FP8 optimizer. More importantly, our approach successfully matched LLaMA-2’s pretraining learning rate. The result?
2650
 
2651
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2092.png)
2652
 
2653
  A loss curve that perfectly matches mixed-precision bfloat16 (bfloat16 with FP32 master weights as the baseline). We successfully tested this to train a 1B LLaMA up to 100B tokens and a 7B LLaMA up to 25B tokens.
2654
 
 
2686
 
2687
  **Non-overlapping:** If we don't overlap the communication and computation, each computation (represented by the purple block) can only begin after the communication (green block) is complete and total time is the sum of communication and computation.
2688
 
2689
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2093.png)
2690
 
2691
  **Overlapping:** However, if we manage to launch communication and computation in parallel, we eliminate the waiting time! Now we can see that the computations (purple blocks) are launched immediately, one after the other. In this case the total time is *only* the sum of computations.
2692
 
2693
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2094.png)
2694
 
2695
  Context parallelism has helped us go past the intra-node interconnect bottleneck, which blocked us from scaling TP across nodes. However, as you probably noted, it only helps reduce the memory constraints if the activation memory dominates the memory budget due to long sequences. What if we are not working on super long sequences and the model weights alone are too big for a single node?
2696
 
2697
  Well it turns out we have another – quite different – option called pipeline parallelism (PP), which it is now time to explore.
2698
 
2699
+ [TODO: comment from Nouamane on comms overlapping with DP 512]
2700
+
2701
+ ## seq parallel profiling
2702
+
2703
+ TODO: remove, Profiling:
2704
+
2705
+ - TP
2706
+
2707
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2095.png)
2708
+
2709
+ - Seq Parall
2710
+
2711
+ ![image.png](The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2096.png)
2712
+
2713
+ Allreduce takes almost double the duration (900us) of reducescatter and allgather (500us)
dist/bibliography.bib CHANGED
@@ -403,4 +403,13 @@ url = {https://github.com/meta-llama/llama3/blob/main/MODEL_CARD.md}
403
  archivePrefix={arXiv},
404
  primaryClass={cs.LG},
405
  url={https://arxiv.org/abs/2205.05198},
 
 
 
 
 
 
 
 
 
406
  }
 
403
  archivePrefix={arXiv},
404
  primaryClass={cs.LG},
405
  url={https://arxiv.org/abs/2205.05198},
406
+ }
407
+ @misc{wang2024domino,
408
+ title={Domino: Eliminating Communication in LLM Training via Generic Tensor Slicing and Overlapping},
409
+ author={Guanhua Wang and Chengming Zhang and Zheyu Shen and Ang Li and Olatunji Ruwase},
410
+ year={2024},
411
+ eprint={2409.15241},
412
+ archivePrefix={arXiv},
413
+ primaryClass={cs.DC},
414
+ url={https://arxiv.org/abs/2409.15241},
415
  }
dist/index.html CHANGED
@@ -468,16 +468,125 @@
468
 
469
  <h2>Data Parallelism</h2>
470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  <h4><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h4>
472
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
  <h4><strong>Second optimization:</strong> Bucketing gradients</h4>
474
 
475
- <h4><strong>Third optimization: I</strong>nterplay with gradient accumulation</h4>
476
-
 
 
 
 
 
 
 
 
 
 
 
 
477
  <h3>Revisit global batch size</h3>
478
-
 
 
 
 
 
 
 
 
 
 
 
 
 
479
  <h3>Our journey up to now</h3>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
480
 
 
 
 
 
 
 
 
 
 
 
 
481
  <h3>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h3>
482
 
483
  <h4>Memory usage revisited</h4>
@@ -488,12 +597,253 @@
488
 
489
  <h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
490
 
 
491
  <h2>Tensor Parallelism</h2>
492
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
  <h3>Tensor Parallelism in a Transformer Block</h3>
494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
  <h3>Sequence Parallelism</h3>
496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
  <h2>Context Parallelism</h2>
498
 
499
  <h3>Introducing Context Parallelism</h3>
 
468
 
469
  <h2>Data Parallelism</h2>
470
 
471
+ <p>The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replicas “model instances”) and run forward and backward passes on different micro batches of data in parallel for each GPU, hence the name Data Parallelism. </p>
472
+
473
+ <p>Using a different micro batch for each GPU means we’ll have different gradients in each GPU, so to keep the model instances in sync across different GPUs, the gradients from the model instances are averaged using an operation called “all-reduce”, which happens during the backward pass, before the optimizer step.</p>
474
+
475
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
476
+
477
+ <p>This involves our first “distributed communication” primitive: <em><strong>all-reduce</strong></em> which handles the synchronization and communication between GPU instances and nodes.</p>
478
+
479
+ <aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in the Appendix [TODO Link].</aside>
480
+
481
+ <p>TODO: embed naive DP: <a href="https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60">https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60</a></p>
482
+
483
+ <p>TODO: embed bucket DP: <a href="https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171">https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171</a></p>
484
+
485
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
486
+
487
+ <p>A naive DP implementation would just wait for the backward pass to finish so that we have all gradients, and then trigger an all-reduce over all DP ranks to sync these gradients. But such a sequential order of computation followed by communication is <strong>A BIG NO!</strong> because we don’t want our GPUs to stay idle while communication is happening.</p>
488
+
489
+ <p>Instead we should try to overlap communication and computation whenever possible so that they happen at the same time as much as possible.</p>
490
+
491
+ <p>Let’s see three optimizations that are done in practice for this! </p>
492
+
493
  <h4><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h4>
494
+
495
+ <p>The main drawback of the naive DDP approach we’ve just described is that after the backward pass (<em>computation</em>), we have to wait for gradient synchronization (<em>communication</em>) before updating the parameters. Could we overlap this communication with our computation? The answer is yes!</p>
496
+
497
+ <p>As shown in the figure above, the gradients (red boxes) for a layer can be gathered and summed even before the gradients from earlier layers (red boxes to the left) have been computed. For example, as soon as the backward pass of the last layer is complete (last box on the right), those gradients can already be gathered and summed while the backward computations continue for earlier layers, moving toward the left.</p>
498
+
499
+ <p>This can be achieved in pytorch by attaching an <em>all-reduce hook function</em> to each parameter. An all-reduce operation is triggered as soon as the gradient for that parameter is ready, while gradients for other parameters are still being computed. This approach overlaps most of the all-reduce operations with gradient calculations, thereby improving efficiency. Here's a simple function to attach a hook:</p>
500
+
501
+ <d-code block language="python">
502
+ def register_backward_hook(self, hook):
503
+ """
504
+ Registers a backward hook for all parameters of the model that
505
+ require gradients.
506
+ """
507
+ for p in self.module.parameters():
508
+ if p.requires_grad is True:
509
+ p.register_post_accumulate_grad_hook(hook)</d-code>
510
+
511
+ <p><img alt="image.png" src="/assets/images/placeholder.png"/></p>
512
+
513
+ <p>Overlapping computation and communication reduces the time spent waiting for gradient synchronization across the entire model. Gradient synchronization can occur (at least partially) in parallel with backward pass, significantly speeding up data parallelism. </p>
514
+
515
+ <p>This is our first example of “<em>overlapping computation and communication</em>” which we will discuss several times in this blog post and is an essential technique for maximal scaling efficiency. Let's have a look at how we can further improve the DP efficiency!</p>
516
+
517
+
518
  <h4><strong>Second optimization:</strong> Bucketing gradients</h4>
519
 
520
+ <p>We can even go further with optimizing DP. For a given number of parameters to synchronize, GPU operations like collective communications are often more efficient when performing few calls on large tensors rather than many calls on smaller tensors. Therefore, instead of performing independent all-reduce for each gradient, we can group gradients into buckets and launch a single all-reduce for all the gradients within the same bucket. Think of it like packing items into boxes before shipping—it's more efficient to send a few big boxes than many small ones. By performing a single all-reduce operation for each bucket, we can significantly reduce communication overhead and speed up the communication operation.</p>
521
+
522
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
523
+
524
+ <h4><strong>Third optimization: </strong>Interplay with gradient accumulation</h4>
525
+
526
+ <p>As we’ve seen before, gradient accumulation works by performing multiple forward and backward passes before updating the parameters with <code>optimizer.step()</code>. When combining gradient accumulation with data parallelism, we should be careful when we want to synchronize gradients.</p>
527
+
528
+ <p>In a naive version, an all-reduce operation is automatically triggered after each backward pass during the accumulation, which is sub-optimal as a single reduce after the final step would have the same effect while reducing overhead.</p>
529
+
530
+ <p>In PyTorch, this is typically solved by adding a <a href="https://github.com/pytorch/pytorch/blob/5ea67778619c31b13644914deef709199052ee55/torch/nn/parallel/distributed.py#L1408-L1435"><code>model.no_sync()</code></a> decorator, which disables gradient synchronization, on the backward passes which don’t need reduction.</p>
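+ <p>A minimal sketch of this pattern (assuming <code>model</code> is wrapped in <code>DistributedDataParallel</code> and that <code>grad_acc_steps</code>, <code>dataloader</code>, <code>optimizer</code> and a <code>compute_loss</code> helper are defined elsewhere):</p>
+
+ <d-code block language="python">
+ import contextlib
+
+ for step, batch in enumerate(dataloader):
+     is_last_micro_step = (step + 1) % grad_acc_steps == 0
+     # skip the gradient all-reduce on all but the last accumulation step
+     ctx = contextlib.nullcontext() if is_last_micro_step else model.no_sync()
+     with ctx:
+         loss = compute_loss(model, batch) / grad_acc_steps
+         loss.backward()
+     if is_last_micro_step:
+         optimizer.step()
+         optimizer.zero_grad()</d-code>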
531
+
532
+ <aside>When performing communication operations, tensors must be contiguous in memory. To avoid redundant memory copies during communication, ensure that tensors that will be communicated are stored contiguously in memory. Sometimes we need to allocate additional continuous buffers of the size of activations or model parameters specifically for communication, which contributes to the peak memory usage during training.</aside>
533
+
534
  <h3>Revisit global batch size</h3>
535
+ <p>Let’s update our batch size equation with our newly learned Data Parallelism and Gradient Accumulation parameters:</p>
536
+
537
+ <d-math block>
538
+ bs = gbs = mbs \times grad\_acc \times dp
539
+ </d-math>
540
+ <p>Where <d-math>grad\_acc</d-math> is the number of gradient accumulation steps and DP is the number of parallel instances used for data parallelism.</p>
541
+
542
+ <p>Given a targeted global batch size, we can thus trade gradient accumulation steps for data-parallel processes to speed up training. In practice, people tend to maximize the number of data-parallel nodes (DP) over gradient accumulation as much as possible since it's inherently parallel, unlike the sequential nature of gradient accumulation. Gradient accumulation is then added on top of data parallelism to achieve the target global batch size when scaling data parallelism alone is not sufficient before you run out of GPUs.</p>
543
+
544
+ <aside>A good resource for further reading on Data Parallelism is <a href="https://siboehm.com/articles/22/data-parallel-training">https://siboehm.com/articles/22/data-parallel-training</a>.
545
+ </aside>
546
+
547
+ <p>Being able to distribute the training over different samples gives us a first dimension of parallelization, thus making this 1D parallelism (we’ll progressively cover 4 more dimensions).</p>
548
+
549
  <h3>Our journey up to now</h3>
550
+ <p>Let’s quickly summarize what we’ve seen up to now and how to setup our first 1D parallel training with a draft recipe for an optimal data-parallel setup:</p>
551
+
552
+ <ol>
553
+ <li>We should first determine the best (global) batch size in tokens (<code>GBST</code>) either by consulting literature or running experiments measuring model convergence.</li>
554
+ <li>We then select a sequence length for training, again by either consulting literature or running experiments. Generally, 2-8k tokens work reliably well for the evaluations we have today (we won’t dive into training recipes here, but teams usually increase the sequence length at the end of training, adding some longer-context data samples into the mix to reach the longer context sizes of today).</li>
555
+ <li>We now know the batch size (gbs). We can find the maximum local batch size (mbs) on a single GPU by increasing the local batch size until we run out of memory.</li>
556
+ <li>Finally, we determine the number of available GPUs for our target DP. Dividing GBS by DP times MBS gives us the remaining number of gradient accumulation steps needed for the desired GBS. </li>
557
+ </ol>
558
+
559
+ <aside>For instance DeepSeek and Llama models are trained with a 4k tokens sequence length during the main pretraining phase.</aside>
560
+
561
+ <aside>The reason 2-8k work well for pretraining is that documents that are longer are very rare on the web. See this <a href="https://www.harmdevries.com/post/context-length/">Harm’s blogpost</a> for a detailed analysis.
562
+ </aside>
563
+
564
+ <p>If the gradient accumulation ratio is lower than one, i.e. we have too many GPUs a.k.a GPU-rich 🤑 (!), we can either choose to not use all our GPUs, explore a larger global batch size or test if a lower MBS will speed up training. In the latter case we’ll end up prioritizing throughput over individual GPU compute efficiency, using a smaller MBS than possible in order to speed up training.</p>
565
+
566
+ <p>Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k. This means our batch size will be 1024 samples (we pick powers of two). We observe that a single GPU can only fit MBS=2 in memory and we have 128 GPUs available for training. This means with 4 gradient accumulation steps we’ll achieve our goal of 1024 samples or 4M tokens per training step. Now what if we suddenly have 512 GPUs available? We can achieve the same GBS and thus identical training by keeping MBS=2 and setting gradient accumulation steps to 1 and achieve faster training!</p>
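+ <p>As a quick sanity check of the arithmetic in this example (a toy calculation, not part of any training script):</p>
+
+ <d-code block language="python">
+ SEQ_LEN     = 4096
+ GBS_SAMPLES = 1024                          # ~4M tokens (1024 * 4096), a power of two
+ MBS         = 2                             # max micro batch size that fits on one GPU
+ NUM_GPUS    = 128                           # all used for data parallelism here
+
+ grad_acc = GBS_SAMPLES // (NUM_GPUS * MBS)  # = 4 accumulation steps
+ print(grad_acc)                             # with 512 GPUs instead, grad_acc drops to 1</d-code>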
567
+
568
+ <aside>Bear in mind that at the 512-GPU scale, depending on the network used, the communication operations will start to be bound by <em>ring latency</em> (the time required for a signal to propagate once around the ring), which means we can no longer fully overlap the DP communications. This will decrease our compute efficiency and hit our throughput. In this case we should start exploring other dimensions to parallelize on.
569
+ </aside>
570
+
571
+ <p>While data parallelism cleverly overlaps the all-reduce gradient synchronization with backward computation to save time, this benefit starts to break down at large scales. As we add more and more GPUs (hundreds or thousands), the overhead of coordinating between them grows significantly. The end result? We get less and less efficient returns from each additional GPU we add to the system:</p>
572
+
573
+ <p><img alt="image.png" src="/assets/images/placeholder.png"/></p>
574
+
575
+ <p>As expected, we can also see that the memory usage per GPU is not affected by adding more DP ranks for training.</p>
576
+
577
+ <p><strong>We’ve explored data parallelism, our first (simple) strategy to scale training across more GPUs. It works like gradient accumulation but parallelizes the forward and backward passes on micro batches, thus increasing throughput!</strong></p>
578
 
579
+ <p>The keen reader has probably already noted, however, that this assumes we can fit the forward pass of at least one input sample (mbs=1) into our GPU memory. This is not always the case! As we can see, larger models don’t fit into a single GPU, even with activation recomputation activated: </p>
580
+
581
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
582
+
583
+ <aside>Tip: you can quickly eyeball the minimal memory required for your model’s parameters by multiplying by 2 e.g. 70B → 140GB (=133GiB)</aside>
584
+
585
+ <p>Do we have other options for these larger models? Thankfully, we do have some solutions. They involve either moving some of these tensors to the CPU, or splitting the weights/gradients/optimizer-states tensors across GPU devices!</p>
586
+
587
+ <p>There are two main approaches to splitting: parallelism (tensor, context, or pipeline parallelism) and sharding (DeepSpeed ZeRO or PyTorch FSDP). Both approaches are somewhat orthogonal and can actually be combined! The sharding paradigm is closely related to DP so we’ll have a look at it first by investigating the ZeRO method!</p>
588
+
589
+
590
  <h3>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h3>
591
 
592
  <h4>Memory usage revisited</h4>
 
597
 
598
  <h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
599
 
600
+
601
  <h2>Tensor Parallelism</h2>
602
+
603
+ <p>So we have sharded the model’s parameters, gradients and optimizer states with ZeRO, but we hit a limit once activation memory overtakes our memory budget. Welcome Tensor Parallelism (TP), a method which shards weights, gradients, and optimizer states as well as activations, without the need to gather them all prior to the computation. Seems like a dream! Let’s first have a look at how Tensor Parallelism works with simple matrix multiplications.</p>
604
+
605
+ <p>Tensor Parallelism leverages the mathematical properties of matrix multiplication <d-math>A \times B</d-math>. To understand how it works, let's examine two fundamental equations that make this parallelization possible:</p>
606
+
607
+ <d-math block>
608
+ \begin{aligned}
609
+ &\text{1.} \quad A\cdot B = A \cdot \begin{bmatrix} B_1 & B_2 & \cdots \end{bmatrix} = \begin{bmatrix} AB_1 & AB_2 & \cdots \end{bmatrix} \\
610
+ &\text{2.} \quad A\cdot B =\begin{bmatrix} A_1 & A_2 & \cdots \end{bmatrix} \begin{bmatrix} B_1 \\ B_2 \\ \vdots \end{bmatrix} = \sum_{i=1}^n A_i B_i
611
+ \end{aligned}
612
+ </d-math>
613
+
614
+ <p>This means that we can compute the matrix product either by 1) multiplying <d-math>A</d-math> with each column of <d-math>B</d-math> individually, or 2) multiplying each column of <d-math>A</d-math> with the corresponding row of <d-math>B</d-math> and summing the results. In a neural network, the matrix multiplication is more often represented in the following format: <d-math>X \times W</d-math>, where:</p>
615
+
616
+ <ul>
617
+ <li>X represents the input or activation values</li>
618
+ <li>W represents the weight of the <code>nn.Linear</code></li>
619
+ </ul>
620
+
621
+ <p>In practice a small example of the operation looks like this:</p>
622
+
623
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
624
+
625
+ <p>Let’s see how we can parallelize this operation! In tensor parallelism, tensors will be split into N shards along a particular dimension and distributed across N GPUs. Matrices can be split either along the column dimension or the row dimension, leading to column and row parallelism. One thing we’ll see in the following is that choosing row or column sharding will require different communication primitives.</p>
626
+
627
+ <p>Our first option is to use column-wise sharding (also called <strong><em>column-linear</em></strong>): We'll copy the complete input matrices to each worker, requiring an operation called <strong><em>broadcast</em></strong>, and split the weight matrix into columns. The inputs are then multiplied with the partial weight matrices, and the results are finally combined using an <strong><em>all-gather</em></strong> operation.</p>
628
+
629
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
630
+
631
+ <p>The second option is called row-wise sharding (also called <strong><em>row-linear</em></strong>): As the attentive reader might guess, row-linear means that we split the weight matrix into chunks of rows. However, this also requires us to split the inputs, which needs a <strong><em>scatter</em></strong> operation rather than a broadcast as used in column-linear sharding. The results on each worker are already in the right shape but need to be summed for the final result, thus requiring an all-reduce operation in this scenario.</p>
632
+
633
+ <p>We see here our fourth distributed primitive: <strong><em>scatter</em></strong>!</p>
634
+
635
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
636
+
637
  <h3>Tensor Parallelism in a Transformer Block</h3>
638
 
639
+ <p>To come up with a strategy to follow, let’s move from a toy example to a real model building block. A Transformer model is made of two main building blocks: Feedforward layers (MLP) and Multi-Head Attention (MHA). We can apply tensor parallelism to both.</p>
640
+
641
+ <p>The Feedforward part can be parallelized by having a “Column linear” followed by a “Row Linear” which amounts to a broadcast to copy the input and an all-reduce in forward. Note that the broadcast isn’t needed in actual training where we can make sure inputs are already synced across TP ranks.</p>
642
+
643
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
644
+
645
+ <p>Now that we’ve found the most efficient schema for the Feedforward part of the transformer, let’s take a look at the multi-head attention block (MHA).</p>
646
+
647
+ <p>We can generally follow a similar approach where Q, K, and V matrices are split in a column-parallel fashion, and the output projection is split along the row dimension. With multi-head attention, the column-parallel approach has a very natural interpretation: each worker computes the attention for an individual or a subset of heads. The same approach works as well for <a href="https://arxiv.org/abs/1911.02150"><strong><em>multi-query</em></strong> (MQA)</a> or <a href="https://arxiv.org/abs/2305.13245"><strong><em>grouped query attention</em></strong> (GQA)</a> where key and values are shared between queries. </p>
648
+
649
+ <p>It's also worth noting that the tensor parallelism degree should not exceed the number of Q/K/V heads because we need intact heads per TP rank. And in case we’re using GQA, the TP degree should not exceed the number of K/V heads, otherwise additional communication is required to keep them in sync. For instance, LLaMA-3 8B has 8 Key/Value heads, so the tensor parallelism degree should be less than or equal to 8; otherwise, if TP=16 for example, we need to duplicate each K/V head and make sure they stay in sync.</p>
650
+
651
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
652
+
653
+ <p>Finally, note that there is a tradeoff in terms of communication, as we’ve added several distributed communication primitives directly in the computation path of our model. Unlike ZeRO, where we could prefetch, it can be harder to make these communications fully overlap with computation. </p>
654
+
655
+ <p><img alt="Forward pass in Tensor Parallelism" src="/assets/images/placeholder.png" /></p>
656
+
657
+ <p>Looking at the timeline of operations in tensor-parallel MLP (same applies for Attention), we can better understand the tradeoffs involved. In the forward of each decoder layer, we hit a synchronization point with the AllReduce operation that cannot be overlapped with computation. This <em>exposed communication</em> overhead is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied. </p>
658
+
659
+ <p>Tensor parallelism does help reduce activation memory for the matrix multiplications since the intermediate activations are sharded across GPUs. However, we still need to gather the full activations for operations like LayerNorm, which means we're not getting the full memory benefits we could. Additionally, it introduces significant communication requirements that heavily depend on the network infrastructure. The inability to hide this particular AllReduce behind computation means it directly adds to the critical path of forward propagation.</p>
660
+
661
+ <p><img alt="Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training." src="/assets/images/placeholder.png" /></p>
662
+
663
+ <p>Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training.</p>
664
+
665
+ <p>In practice, the communication overhead of tensor parallelism becomes particularly noticeable as we scale beyond 8 GPUs. While tensor parallelism within a single node can leverage fast NVLink interconnects, going across nodes requires slower network connections. As shown in the throughput plot above, we observe significant drops when moving from TP=8 to TP=16, and an even steeper decline from TP=16 to TP=32. This illustrates how communication costs can dominate at higher degrees of parallelism.</p>
666
+
667
+ <p>However, tensor parallelism provides important benefits for memory usage by distributing model parameters, gradients, optimizer states and activations (to some extent) across GPUs. Let's examine this effect on a 70B parameter model:</p>
668
+
669
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
670
+
671
+ <p>As we can see, increasing tensor parallelism reduces the memory needed for model parameters, gradients and optimizer states on each GPU. While tensor parallelism does help reduce activation memory in attention and feedforward layers by sharding the matrix multiplications across GPUs, we don't get the full memory benefits we could. This is because operations like layer normalization and dropout still require gathering the full activations on each GPU, partially negating the memory savings. We can do better by finding ways to parallelize these remaining operations as well.</p>
672
+
673
+ <aside>One interesting note about layer normalization in tensor parallel training - since each TP rank sees the same activations after the all-gather, the layer norm weights don't actually need an all-reduce to sync their gradients after the backward pass. They naturally stay in sync across ranks. However, for dropout operations, we must make sure to sync the random seed across TP ranks to maintain deterministic behavior.
674
+ </aside>
675
+
676
+ <p>This raises an interesting question - could we extend tensor parallelism to these remaining operations as well? Indeed, it's possible to parallelize layer norm, dropout and other operations too, which we'll explore next.</p>
677
+
678
  <h3>Sequence Parallelism</h3>
679
 
680
+ <p>In regions where we apply tensor parallelism (TP), like attention and feedforward layers, each GPU only needs to operate on a portion of the hidden dimension since the weights are sharded. However, operations like layer norm or dropout (which is not used much anymore in LLMs) require access to the full hidden dimension to compute correctly.</p>
681
+
682
+ <p>Rather than gathering the full hidden dimension on each GPU (which would defeat the memory benefits of TP), we can instead shard these operations along the sequence length dimension. This approach is called <strong>sequence parallelism (SP)</strong>.</p>
683
+
684
+ <aside>Note that the term Sequence Parallelism is a bit overloaded: the Sequence Parallelism in this section is tightly coupled to Tensor Parallelism and applies to the dropout and layer norm operations. However, when we move to longer sequences, the attention computation will become a bottleneck, which calls for techniques such as Ring-Attention, which are sometimes also called <em>Sequence Parallelism</em> but we’ll refer to them as <em>Context Parallelism</em> to differentiate the two approaches. So each time you see sequence parallelism, remember that it is used together with tensor parallelism (in contrast to context parallelism, which can be used independently).</aside>
685
+
686
+ <p>Sequence parallelism (SP) involves splitting the activations and computations for the parts of the model not handled by tensor parallelism (TP) such as Dropout and LayerNorm, but along the input sequence dimension rather than across hidden dimension. This is needed because these operations require access to the full hidden dimension to compute correctly. For example, LayerNorm needs the full hidden dimension to compute mean and variance:</p>
687
+
688
+ <d-math block>
689
+ \text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
690
+ </d-math>
691
+
692
+ <p>where <d-math>\mu = \text{mean}(x)</d-math> and <d-math>\sigma^2 = \text{var}(x)</d-math> are computed across hidden dimension <d-math>h</d-math>.</p>
693
+
694
+ <p>So even though these operations are computationally cheap, they still require significant activation memory since they need the complete hidden dimension. SP allows us to shard this <strong>memory</strong> burden across GPUs by splitting along the sequence dimension instead.</p>
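+
+ <p>The fact that the normalization statistics only involve the hidden dimension is exactly what makes sequence sharding safe here: each sequence chunk can be normalized independently and the result matches the unsharded computation. A minimal single-process sketch (SP degree 2 simulated with tensor chunks):</p>
+
+ <d-code block language="python">
+ import torch
+ import torch.nn.functional as F
+
+ torch.manual_seed(0)
+ b, s, h = 2, 8, 16
+ x = torch.randn(b, s, h)
+
+ # Reference: LayerNorm over the hidden dimension on the full sequence.
+ ref = F.layer_norm(x, normalized_shape=(h,))
+
+ # SP: each "rank" holds a (b, s/2, h) chunk and normalizes it independently;
+ # no communication is needed since mean and variance are computed over h only.
+ sp_out = torch.cat(
+     [F.layer_norm(chunk, normalized_shape=(h,)) for chunk in x.chunk(2, dim=1)],
+     dim=1,
+ )
+
+ print(torch.allclose(ref, sp_out))  # True
+ </d-code>
+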
695
+
696
+ <p>In practice we’ll go from the left diagram to the right:</p>
697
+
698
+ <p><img alt=" in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter
699
+ in backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather
700
+ SP region needs full hidden_dim" src="/assets/images/placeholder.png" /></p>
701
+
702
+ <p>in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter<br>in backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather<br>SP region needs full hidden_dim</p>
703
+
704
+ <p>The diagram shows how we transition between tensor-parallel and sequence-parallel regions using different collective operations (labeled "f" and "g"). The key challenge is managing these transitions efficiently while keeping memory usage low and maintaining correctness.</p>
705
+
706
+ <p>In the forward pass:</p>
707
+ <ul>
708
+ <li>"f" is a no-op (no operation) because activations are already duplicated across ranks</li>
709
+ <li>"f*" is an all-reduce to synchronize activations and ensure correctness</li>
710
+ </ul>
711
+ <p>In the backward pass:</p>
712
+ <ul>
713
+ <li>"f*" is a no-op because gradients are already duplicated across ranks</li>
714
+ <li>"f" is an all-reduce to synchronize gradients</li>
715
+ </ul>
716
+
717
+ <p>These operations "f" and "f*" are called <em>conjugate</em> pairs because they complement each other - when one is a no-op in forward, the other is an all-reduce in backward, and vice versa.</p>
718
+
719
+ <p>For sequence parallelism (SP), we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.</p>
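+
+ <p>In code, such a conjugate pair is usually implemented as a tiny autograd function whose forward and backward mirror each other. Below is a hedged sketch of the “f” operation in the style of Megatron-LM/picotron (assuming <code>torch.distributed</code> has already been initialized with the TP process group as its default group):</p>
+
+ <d-code block language="python">
+ import torch
+ import torch.distributed as dist
+
+ class CopyToTPRegion(torch.autograd.Function):
+     """The "f" operation: no-op in forward, all-reduce of the gradients in backward."""
+
+     @staticmethod
+     def forward(ctx, x):
+         # Activations are already replicated across TP ranks, so nothing to do.
+         return x
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         # Each TP rank computed a partial gradient w.r.t. the replicated input,
+         # so we sum them before passing the gradient further back.
+         dist.all_reduce(grad_output, op=dist.ReduceOp.SUM)
+         return grad_output
+
+ def copy_to_tp_region(x):
+     return CopyToTPRegion.apply(x)
+ </d-code>
+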
720
+
721
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
722
+
723
+ <p>So what is actually happening here? As a famous LLM would say, let’s take it step-by-step:</p>
724
+
725
+ <p><strong>Initial LayerNorm (SP Region)</strong></p>
726
+ <ul>
727
+ <li>Input tensors X1* and X2* (b,s/2,h) enter LayerNorm, already split across the sequence dimension</li>
728
+ <li>Each GPU computes LayerNorm independently on its sequence chunk, giving Y1* and Y2*</li>
729
+ </ul>
730
+ <p><strong>First Transition (SP → TP)</strong></p>
731
+ <ul>
732
+ <li>"g" operation (all-gather) combines Y1<em> and Y2</em> back to full sequence length</li>
733
+ <li> Restores Y (b,s,h) since column linear layer needs full hidden dimension h</li>
734
+ </ul>
735
+ <p><strong>First Linear Layer (TP Region)</strong></p>
736
+ <ul>
737
+ <li>A1 is a column-linear layer, so it splits Y along the hidden dimension</li>
738
+ <li>GeLU is applied independently on each GPU</li>
739
+ <li>Z1* is (b,s,h/2)</li>
740
+ </ul>
741
+ <p><strong>Second Linear Layer (TP Region)</strong></p>
742
+ <ul>
743
+ <li>B1 is a row-linear layer, so it restores the hidden dimension</li>
744
+ <li>W1 is (b,s,h)</li>
745
+ </ul>
746
+ <p><strong>Final Transition (TP → SP)</strong></p>
747
+ <ul>
748
+ <li>"g*" operation (reduce-scatter) which reduces for previous row-linear correctness while scattering along sequence dimension</li>
749
+ <li>W1* is (b,s/2,h)</li>
750
+ </ul>
751
+
752
+ <p>A key advantage of sequence parallelism is that it reduces the maximum activation size we need to store. In tensor parallelism alone, we had to store activations of shape (b,s,h) at various points. However, with sequence parallelism, the maximum activation size is reduced to <d-math>\frac{b \cdot s \cdot h}{tp}</d-math> since we always either split along the sequence or hidden dimensions.</p>
753
+
754
+ <p>It’s a bit difficult to keep track of all the parts that are sharded differently in TP and TP/SP - believe us, we find it hard to keep track of as well, so we made this small table to summarize how the activation (aka <code>hidden_states</code>) shapes change across the hidden dimension h and the sequence dimension s during a forward pass:</p>
755
+
756
+ <table>
757
+ <thead>
758
+ <tr>
759
+ <th>Region</th>
760
+ <th>TP only</th>
761
+ <th>TP with SP</th>
762
+ </tr>
763
+ </thead>
764
+ <tbody>
765
+ <tr>
766
+ <td>Enter TP (Column Linear)</td>
767
+ <td>h: sharded (weight_out is sharded)<br>s: full</td>
768
+ <td>h: sharded (weight_out is sharded)<br>s: <strong>all-gather</strong> to full</td>
769
+ </tr>
770
+ <tr>
771
+ <td>TP Region</td>
772
+ <td>h: sharded<br>s: full</td>
773
+ <td>h: sharded<br>s: full</td>
774
+ </tr>
775
+ <tr>
776
+ <td>Exit TP (Row Linear)</td>
777
+ <td>h: full (weight_out is full + <strong>all-reduce</strong> for correctness)<br>s: full</td>
778
+ <td>h: full (weight_out is full + <strong>reduce-scatter</strong> for correctness)<br>s: <strong>reduce-scatter</strong> to sharded</td>
779
+ </tr>
780
+ <tr>
781
+ <td>SP Region</td>
782
+ <td>h: full<br>s: full</td>
783
+ <td>h: full<br>s: sharded</td>
784
+ </tr>
785
+ </tbody>
786
+ </table>
787
+
788
+ <p>And for the embedding layer:</p>
789
+
790
+ <table>
791
+ <thead>
792
+ <tr>
793
+ <th>Region</th>
794
+ <th>Vanilla TP</th>
795
+ <th>TP with SP</th>
796
+ </tr>
797
+ </thead>
798
+ <tbody>
799
+ <tr>
800
+ <td>Embedding Layer (Row Linear sharded on vocab)</td>
801
+ <td>h: full (weight_out is full + <strong>all-reduce</strong> for correctness)<br>s: unchanged</td>
802
+ <td>h: full (weight_out is full + <strong>reduce-scatter</strong> for correctness)<br>s: <strong>reduce-scatter</strong> to sharded</td>
803
+ </tr>
804
+ </tbody>
805
+ </table>
806
+
807
+ <p>You can find an example implementation of both column- and row-linear TP in picotron:
808
+
809
+ <a href="https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py">https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py</a> </p>
810
+
811
+ <p>By using sequence parallelism, we can achieve even greater activation memory savings, allowing us to push our batch size and sequence length further than what would be possible with tensor parallelism alone. Let's see what that means for our previous 70B model example:</p>
812
+
813
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
814
+
815
+ <p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward pass of vanilla TP we had two all-reduces per transformer block, and in SP we have two all-gathers and two reduce-scatters per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into an all-gather + reduce-scatter (see [TODO: Appendix link]), they’re actually equivalent in terms of communication volume. The same reasoning applies for the backward pass, as we just use the conjugate of each operation (no-op ↔ all-reduce and all-gather ↔ reduce-scatter).</p>
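+
+ <p>This decomposition is easy to convince yourself of: the ring implementation of all-reduce is literally a reduce-scatter followed by an all-gather. A toy simulation (single process, four simulated ranks, collectives replaced by sums and concatenation):</p>
+
+ <d-code block language="python">
+ import torch
+
+ torch.manual_seed(0)
+ world_size = 4
+ grads = [torch.randn(8) for _ in range(world_size)]   # one gradient-like tensor per rank
+
+ # All-reduce: every rank ends up with the elementwise sum.
+ all_reduced = sum(grads)
+
+ # Step 1, reduce-scatter: rank i keeps only the reduced i-th chunk.
+ chunks = [g.chunk(world_size) for g in grads]
+ scattered = [sum(chunks[r][i] for r in range(world_size)) for i in range(world_size)]
+
+ # Step 2, all-gather: every rank concatenates all the reduced chunks.
+ gathered = torch.cat(scattered)
+
+ print(torch.allclose(all_reduced, gathered))  # True
+ </d-code>
+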
816
+
817
+ <p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops in each layer (2 for Attention and 2 for MLP). This is what the MLP profile looks like when using Tensor + Sequence Parallelism:</p>
818
+
819
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
820
+
821
+ <p>Besides the fact that TP requires communications in each layer, it also can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. This is why TP is usually done only within a node (TP≤8).</p>
822
+
823
+
824
+ <aside>Overlapping communication with computation for TP is an active area of research, with recent work like Domino <d-cite bibtex-key="wang2024domino"></d-cite> exploring novel techniques to maximize this overlap. For example, Megatron-LM/Nanotron implement a partial overlapping of all-gather with FC1 computation, and we expect to see more innovations in this space as the field continues to evolve.</aside>
825
+
826
+ <p>As you might expect, this communication overhead becomes increasingly problematic as we scale up tensor parallelism. To illustrate this, let’s check throughput as we scale TP with SP for a 3B model:</p>
827
+
828
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
829
+ <p>Impact of combined Tensor and Sequence Parallelism (TP/SP) on a 3B model’s performance and memory utilization with 4096 seqlen: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.</p>
830
+
831
+ <p>Let’s summarize our observations:</p>
832
+
833
+ <ul>
834
+ <li>for both methods we notice the biggest performance drop when we move from TP=8 to TP=16, because that’s when we move from only communicating within a single node (NVLink), to communicating inter-nodes (EFA)</li>
835
+ <li>the memory savings in activations when using TP with SP helps us fit far bigger batches than TP alone</li>
837
+ </ul>
838
+
839
+ <p><strong>We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.</strong></p>
840
+
841
+ <p>However, there are two limits to TP and SP: 1) if we scale the sequence length the activation memory will still blow up in the TP region and 2) if the model is too big to fit with TP=8 then we will see a massive slow-down due to the inter-node connectivity.</p>
842
+
843
+ <aside>Since LayerNorms in the SP region operate on different portions of the sequence, their gradients will differ across TP ranks. To ensure the weights stay synchronized, we need to allreduce their gradients during the backward pass, similar to how DP ensures weights stay in sync. This is a small communication overhead since LayerNorm has relatively few parameters.</aside>
844
+
845
+ <p>We can tackle problem 1) with Context parallelism and problem 2) with Pipeline parallelism. Let’s first have a look at Context parallelism!</p>
846
+
847
  <h2>Context Parallelism</h2>
848
 
849
  <h3>Introducing Context Parallelism</h3>
src/bibliography.bib CHANGED
@@ -403,4 +403,13 @@ url = {https://github.com/meta-llama/llama3/blob/main/MODEL_CARD.md}
403
  archivePrefix={arXiv},
404
  primaryClass={cs.LG},
405
  url={https://arxiv.org/abs/2205.05198},
 
 
 
 
 
 
 
 
 
406
  }
 
403
  archivePrefix={arXiv},
404
  primaryClass={cs.LG},
405
  url={https://arxiv.org/abs/2205.05198},
406
+ }
407
+ @misc{wang2024domino,
408
+ title={Domino: Eliminating Communication in LLM Training via Generic Tensor Slicing and Overlapping},
409
+ author={Guanhua Wang and Chengming Zhang and Zheyu Shen and Ang Li and Olatunji Ruwase},
410
+ year={2024},
411
+ eprint={2409.15241},
412
+ archivePrefix={arXiv},
413
+ primaryClass={cs.DC},
414
+ url={https://arxiv.org/abs/2409.15241},
415
  }
src/index.html CHANGED
@@ -468,16 +468,125 @@
468
 
469
  <h2>Data Parallelism</h2>
470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  <h4><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h4>
472
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
  <h4><strong>Second optimization:</strong> Bucketing gradients</h4>
474
 
475
- <h4><strong>Third optimization: I</strong>nterplay with gradient accumulation</h4>
476
-
 
 
 
 
 
 
 
 
 
 
 
 
477
  <h3>Revisit global batch size</h3>
478
-
 
 
 
 
 
 
 
 
 
 
 
 
 
479
  <h3>Our journey up to now</h3>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
480
 
 
 
 
 
 
 
 
 
 
 
 
481
  <h3>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h3>
482
 
483
  <h4>Memory usage revisited</h4>
@@ -488,12 +597,253 @@
488
 
489
  <h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
490
 
 
491
  <h2>Tensor Parallelism</h2>
492
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
  <h3>Tensor Parallelism in a Transformer Block</h3>
494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
  <h3>Sequence Parallelism</h3>
496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
  <h2>Context Parallelism</h2>
498
 
499
  <h3>Introducing Context Parallelism</h3>
 
468
 
469
  <h2>Data Parallelism</h2>
470
 
471
+ <p>The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replicas “model instances”) and run forward and backward passes on different micro batches of data in parallel on each GPU, hence the name Data Parallelism. </p>
472
+
473
+ <p>Using a different micro batch for each GPU means we’ll have different gradients in each GPU, so to keep the model instances in sync across different GPUs, the gradients from the model instances are averaged using an operation called “all-reduce”, which happens during the backward pass, before the optimizer step.</p>
474
+
475
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
476
+
477
+ <p>This involves our first “distributed communication” primitive: <strong><em>all-reduce</em></strong>, which handles the synchronization and communication between GPU instances and nodes.</p>
478
+
479
+ <aside>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in the Appendix [TODO Link].</aside>
480
+
481
+ <p>TODO: embed naive DP: <a href="https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60">https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60</a></p>
482
+
483
+ <p>TODO: embed bucket DP: <a href="https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171">https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171</a></p>
484
+
485
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
486
+
487
+ <p>A naive DP implementation would just wait for the backward pass to finish so that we have all gradients, then trigger an all-reduce over all DP ranks to sync these gradients. But such a sequential pattern of computation followed by communication is <strong>A BIG NO!</strong> We don’t want our GPUs to stay idle while communication is happening.</p>
488
+
489
+ <p>Instead, we should try to overlap communication and computation so that they happen at the same time whenever possible.</p>
490
+
491
+ <p>Let’s see three optimizations that are done in practice for this! </p>
492
+
493
  <h4><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h4>
494
+
495
+ <p>The main drawback of the naive DDP approach we’ve just described is that after the backward pass (<em>computation</em>), we have to wait for gradient synchronization (<em>communication</em>) before updating the parameters. Could we overlap this communication with our computation? The answer is yes!</p>
496
+
497
+ <p>As shown in the figure above, the gradients (red boxes) for a layer can be gathered and summed even before the gradients from earlier layers (red boxes to the left) have been computed. For example, as soon as the backward pass of the last layer is complete (last box on the right), those gradients can already be gathered and summed while the backward computations continue for earlier layers, moving toward the left.</p>
498
+
499
+ <p>This can be achieved in PyTorch by attaching an <em>all-reduce hook function</em> to each parameter. An all-reduce operation is triggered as soon as the gradient for that parameter is ready, while gradients for other parameters are still being computed. This approach overlaps most of the all-reduce operations with gradient calculations, thereby improving efficiency. Here's a simple function to attach a hook:</p>
500
+
501
+ <d-code block language="python">
502
+ def register_backward_hook(self, hook):
503
+ """
504
+ Registers a backward hook for all parameters of the model that
505
+ require gradients.
506
+ """
507
+ for p in self.module.parameters():
508
+ if p.requires_grad is True:
509
+ p.register_post_accumulate_grad_hook(hook)</d-code>
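+
+ <p>As a rough illustration of what such a hook could do (a sketch only; picotron’s actual implementation, linked above, is more careful about gradient scaling and when the synchronization is allowed to run), the hook receives the parameter once its gradient has been accumulated and can immediately launch an asynchronous all-reduce on it:</p>
+
+ <d-code block language="python">
+ import torch.distributed as dist
+
+ pending_handles = []
+
+ def allreduce_grad_hook(param):
+     # Called as soon as param.grad has been accumulated: average the gradient
+     # across DP ranks without blocking the rest of the backward pass.
+     param.grad /= dist.get_world_size()
+     handle = dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, async_op=True)
+     pending_handles.append(handle)  # must be waited on before optimizer.step()
+
+ # model.register_backward_hook(allreduce_grad_hook)  # attach it with the helper above
+ # loss.backward()
+ # for handle in pending_handles:
+ #     handle.wait()
+ </d-code>
+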
510
+
511
+ <p><img alt="image.png" src="/assets/images/placeholder.png"/></p>
512
+
513
+ <p>Overlapping computation and communication reduces the time spent waiting for gradient synchronization across the entire model. Gradient synchronization can occur (at least partially) in parallel with backward pass, significantly speeding up data parallelism. </p>
514
+
515
+ <p>This is our first example of “<em>overlapping computation and communication</em>”, which we will discuss several times in this blog post and which is an essential technique for maximizing scaling efficiency. Let's have a look at how we can further improve the DP efficiency!</p>
516
+
517
+
518
  <h4><strong>Second optimization:</strong> Bucketing gradients</h4>
519
 
520
+ <p>We can even go further with optimizing DP. For a given number of parameters to synchronize, GPU operations like collective communications are often more efficient when performing few calls on large tensors rather than many calls on smaller tensors. Therefore, instead of performing independent all-reduce for each gradient, we can group gradients into buckets and launch a single all-reduce for all the gradients within the same bucket. Think of it like packing items into boxes before shipping—it's more efficient to send a few big boxes than many small ones. By performing a single all-reduce operation for each bucket, we can significantly reduce communication overhead and speed up the communication operation.</p>
521
+
522
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
523
+
524
+ <h4><strong>Third optimization: </strong>Interplay with gradient accumulation</h4>
525
+
526
+ <p>As we’ve seen before, gradient accumulation works by performing multiple forward and backward passes before updating the parameters with <code>optimizer.step()</code>. When combining gradient accumulation with data parallelism, we should be careful when we want to synchronize gradients.</p>
527
+
528
+ <p>In a naive version, an all-reduce operation is automatically triggered after each backward pass during the accumulation, which is sub-optimal as a single reduce after the final step would have the same effect while reducing overhead.</p>
529
+
530
+ <p>In PyTorch, this is typically solved by wrapping the backward passes which don’t need reduction in the <a href="https://github.com/pytorch/pytorch/blob/5ea67778619c31b13644914deef709199052ee55/torch/nn/parallel/distributed.py#L1408-L1435"><code>model.no_sync()</code></a> context manager, which disables gradient synchronization.</p>
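+
+ <p>A sketch of how this typically looks in a training loop (a minimal example of ours, assuming <code>model</code> is wrapped in <code>DistributedDataParallel</code> and that <code>dataloader</code>, <code>loss_fn</code>, <code>optimizer</code> and <code>grad_acc_steps</code> are defined as usual):</p>
+
+ <d-code block language="python">
+ import contextlib
+
+ for step, (inputs, targets) in enumerate(dataloader):
+     is_sync_step = (step + 1) % grad_acc_steps == 0
+     # Skip the all-reduce on every micro-batch except the last one of the window.
+     ctx = contextlib.nullcontext() if is_sync_step else model.no_sync()
+     with ctx:
+         loss = loss_fn(model(inputs), targets) / grad_acc_steps
+         loss.backward()
+     if is_sync_step:
+         optimizer.step()
+         optimizer.zero_grad()
+ </d-code>
+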
531
+
532
+ <aside>When performing communication operations, tensors must be contiguous in memory. To avoid redundant memory copies during communication, ensure that tensors that will be communicated are stored contiguously in memory. Sometimes we need to allocate additional contiguous buffers of the size of activations or model parameters specifically for communication, which contributes to the peak memory usage during training.</aside>
533
+
534
  <h3>Revisit global batch size</h3>
535
+ <p>Let’s update our batch size equation with our newly learned Data Parallelism and Gradient Accumulation parameters:</p>
536
+
537
+ <d-math block>
538
+ bs = gbs = mbs \times grad\_acc \times dp
539
+ </d-math>
540
+ <p>Where <d-math>grad\_acc</d-math> is the number of gradient accumulation steps and DP is the number of parallel instances used for data parallelism.</p>
541
+
542
+ <p>Given a targeted global batch size, we can thus trade gradient accumulation steps for data-parallel processes to speed up training. In practice, people tend to maximize the number of data-parallel nodes (DP) over gradient accumulation as much as possible since it's inherently parallel, unlike the sequential nature of gradient accumulation. Gradient accumulation is then added on top of data parallelism to reach the target global batch size when we run out of GPUs to scale data parallelism further.</p>
543
+
544
+ <aside>A good resource for further reading on Data Parallelism is <a href="https://siboehm.com/articles/22/data-parallel-training">https://siboehm.com/articles/22/data-parallel-training</a>.
545
+ </aside>
546
+
547
+ <p>Being able to distribute the training over different samples gives us a first dimension of parallelization, thus making this 1D parallelism (we’ll progressively cover 4 more dimensions).</p>
548
+
549
  <h3>Our journey up to now</h3>
550
+ <p>Let’s quickly summarize what we’ve seen up to now and how to set up our first 1D parallel training with a draft recipe for an optimal data-parallel setup:</p>
551
+
552
+ <ol>
553
+ <li>We should first determine the best (global) batch size in tokens (<code>GBST</code>) either by consulting literature or running experiments measuring model convergence.</li>
554
+ <li>We then select a sequence length for training, again by either consulting literature or running experiments. Generally, 2-8k tokens work reliably well for the evaluations we have today (we won’t dive into training recipes here, but teams usually increase the sequence length at the end of training, adding some longer-context data samples to the mix, to reach the longer context sizes used today).</li>
555
+ <li>We now know the batch size (gbs). We can find the maximum local batch size (mbs) on a single GPU by increasing the local batch size until we run out of memory.</li>
556
+ <li>Finally, we determine the number of available GPUs for our target DP. The ratio of GBS to DP × MBS gives us the remaining number of gradient accumulation steps needed for the desired GBS.</li>
557
+ </ol>
558
+
559
+ <aside>For instance DeepSeek and Llama models are trained with a 4k tokens sequence length during the main pretraining phase.</aside>
560
+
561
+ <aside>The reason 2-8k tokens work well for pretraining is that longer documents are very rare on the web. See <a href="https://www.harmdevries.com/post/context-length/">Harm’s blogpost</a> for a detailed analysis.
562
+ </aside>
563
+
564
+ <p>If the gradient accumulation ratio is lower than one, i.e. we have too many GPUs, a.k.a. we are GPU-rich 🤑 (!), we can either choose not to use all our GPUs, explore a larger global batch size, or test if a lower MBS will speed up training. In the latter case we’ll end up prioritizing throughput over individual GPU compute efficiency, using a smaller MBS than possible in order to speed up training.</p>
565
+
566
+ <p>Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k. This means our batch size will be 1024 samples (we pick powers of two). We observe that a single GPU can only fit MBS=2 in memory and we have 128 GPUs available for training. This means with 4 gradient accumulation steps we’ll achieve our goal of 1024 samples or 4M tokens per training step. Now what if we suddenly have 512 GPUs available? We can achieve the same GBS and thus identical training by keeping MBS=2 and setting the gradient accumulation steps to 1, which results in faster training!</p>
567
+
568
+ <aside>Bear in mind that at the 512-GPU scale, depending on the network used, the communication operations will start to be bound by <em>ring latency</em> (the time required for a signal to propagate once around the ring), which means we can no longer fully overlap the DP communications. This will decrease our compute efficiency and hit our throughput. In this case we should start exploring other dimensions to parallelize on.
569
+ </aside>
570
+
571
+ <p>While data parallelism cleverly overlaps the all-reduce gradient synchronization with backward computation to save time, this benefit starts to break down at large scales. As we add more and more GPUs (hundreds or thousands), the overhead of coordinating between them grows significantly. The end result? We get less and less efficient returns from each additional GPU we add to the system:</p>
572
+
573
+ <p><img alt="image.png" src="/assets/images/placeholder.png"/></p>
574
+
575
+ <p>As expected, we can also see that the memory usage per GPU is not affected by adding more DP ranks for training.</p>
576
+
577
+ <p><strong>We’ve explored data parallelism, our first (simple) strategy to scale training across more GPUs. It works like gradient accumulation but parallelizes the forward and backward passes on micro batches, thus increasing throughput!</strong></p>
578
 
579
+ <p>The keen reader has probably already noted, however, that this assumes we can fit the forward pass of at least one input sample (mbs=1) into our GPU memory. This is not always the case! As we can see, larger models don’t fit into a single GPU, even with activation recomputation activated: </p>
580
+
581
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
582
+
583
+ <aside>Tip: you can quickly eyeball the minimal memory required for your model’s parameters by multiplying by 2 e.g. 70B → 140GB (=133GiB)</aside>
584
+
585
+ <p>Do we have other options for these larger models? Thankfully, we do have some solutions. They involve either moving some of these tensors to the CPU, or splitting the weights/gradients/optimizer-states tensors across GPU devices!</p>
586
+
587
+ <p>There are two main approaches to splitting: parallelism (tensor, context, or pipeline parallelism) and sharding (DeepSpeed ZeRO or PyTorch FSDP). Both approaches are somewhat orthogonal and can actually be combined! The sharding paradigm is closely related to DP so we’ll have a look at it first by investigating the ZeRO method!</p>
588
+
589
+
590
  <h3>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h3>
591
 
592
  <h4>Memory usage revisited</h4>
 
597
 
598
  <h4>ZeRO-3: Adding <strong>Parameter Partitioning</strong></h4>
599
 
600
+
601
  <h2>Tensor Parallelism</h2>
602
+
603
+ <p>So we have sharded the model’s parameters, gradients and optimizer states with ZeRO, but we hit a limit once activation memory overtakes our memory budget. Welcome Tensor Parallelism (TP), a method which shards weights, gradients, and optimizer states as well as activations, without the need to gather them all prior to the computation. Seems like a dream! Let’s first have a look at how Tensor Parallelism works with simple matrix multiplications.</p>
604
+
605
+ <p>Tensor Parallelism leverages the mathematical properties of matrix multiplication <d-math>A \times B</d-math>. To understand how it works, let's examine two fundamental equations that make this parallelization possible:</p>
606
+
607
+ <d-math block>
608
+ \begin{aligned}
609
+ &\text{1.} \quad A\cdot B = A \cdot \begin{bmatrix} B_1 & B_2 & \cdots \end{bmatrix} = \begin{bmatrix} AB_1 & AB_2 & \cdots \end{bmatrix} \\
610
+ &\text{2.} \quad A\cdot B =\begin{bmatrix} A_1 & A_2 & \cdots \end{bmatrix} \begin{bmatrix} B_1 \\ B_2 \\ \vdots \end{bmatrix} = \sum_{i=1}^n A_i B_i
611
+ \end{aligned}
612
+ </d-math>
613
+
614
+ <p>This means that we can compute the matrix product either by 1) multiplying <d-math>A</d-math> with each column of <d-math>B</d-math> individually, or 2) multiplying each column of <d-math>A</d-math> with the corresponding row of <d-math>B</d-math> and summing the results. In a neural network, the matrix multiplication is more often represented in the following format: <d-math>X \times W</d-math>, where:</p>
615
+
616
+ <ul>
617
+ <li>X represents the input or activation values</li>
618
+ <li>W represents the weight of the <code>nn.Linear</code></li>
619
+ </ul>
620
+
621
+ <p>In practice a small example of the operation looks like this:</p>
622
+
623
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
624
+
625
+ <p>Let’s see how we can parallelize this operation! In tensor parallelism, tensors will be split into N shards along a particular dimension and distributed across N GPUs. Matrices can be split either along the column dimension or the row dimension, leading to column and row parallelism. One thing we’ll see in the following is that choosing row or column sharding will require different communication primitives.</p>
626
+
627
+ <p>Our first option is to use column-wise sharding (also called <strong><em>column-linear</em></strong>): We'll copy the complete input matrices to each worker, requiring an operation called <strong><em>broadcast</em></strong>, and split the weight matrix into columns. The inputs are then multiplied with the partial weight matrices, and the results are finally combined using an <strong><em>all-gather</em></strong> operation.</p>
628
+
629
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
630
+
631
+ <p>The second option is called row-wise sharding (also called <strong><em>row-linear</em></strong>): As the attentive reader might guess, row-linear means that we split the weight matrix into chunks of rows. However, this also requires us to split the inputs, which needs a <strong><em>scatter</em></strong> operation rather than a broadcast as used in column-linear sharding. The results on each worker are already in the right shape but need to be summed for the final result, thus requiring an all-reduce operation in this scenario.</p>
632
+
633
+ <p>We see here our fourth distributed primitive: <strong><em>scatter</em></strong>!</p>
634
+
635
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
636
+
637
  <h3>Tensor Parallelism in a Transformer Block</h3>
638
 
639
+ <p>To come up with a strategy to follow, let’s move from a toy example to a real model building block. A Transformer model is made of two main building blocks: Feedforward layers (MLP) and Multi-Head Attention (MHA). We can apply tensor parallelism to both.</p>
640
+
641
+ <p>The Feedforward part can be parallelized by having a “Column linear” followed by a “Row Linear” which amounts to a broadcast to copy the input and an all-reduce in forward. Note that the broadcast isn’t needed in actual training where we can make sure inputs are already synced across TP ranks.</p>
642
+
643
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
644
+
645
+ <p>Now that we’ve found the most efficient schema for the Feedforward part of the transformer, let’s take a look at the multi-head attention block (MHA).</p>
646
+
647
+ <p>We can generally follow a similar approach where Q, K, and V matrices are split in a column-parallel fashion, and the output projection is split along the row dimension. With multi-head attention, the column-parallel approach has a very natural interpretation: each worker computes the attention for an individual or a subset of heads. The same approach works as well for <a href="https://arxiv.org/abs/1911.02150"><strong><em>multi-query</em></strong> (MQA)</a> or <a href="https://arxiv.org/abs/2305.13245"><strong><em>grouped query attention</em></strong> (GQA)</a> where key and values are shared between queries. </p>
648
+
649
+ <p>It's also worth noting that the tensor parallelism degree should not exceed the number of Q/K/V heads because we need intact heads per TP rank. And in case we’re using GQA, the TP degree should not exceed the number of K/V heads, otherwise additional communication is required to keep them in sync. For instance, LLaMA-3 8B has 8 Key/Value heads, so the tensor parallelism degree should be less than or equal to 8; otherwise, if TP=16 for example, we need to duplicate each K/V head and make sure they stay in sync.</p>
650
+
651
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
652
+
653
+ <p>Finally, note that there is a tradeoff in terms of communication, as we’ve added several distributed communication primitives directly in the computation path of our model. Unlike ZeRO, where we could prefetch, it can be harder to make these communications fully overlap with computation. </p>
654
+
655
+ <p><img alt="Forward pass in Tensor Parallelism" src="/assets/images/placeholder.png" /></p>
656
+
657
+ <p>Looking at the timeline of operations in tensor-parallel MLP (same applies for Attention), we can better understand the tradeoffs involved. In the forward of each decoder layer, we hit a synchronization point with the AllReduce operation that cannot be overlapped with computation. This <em>exposed communication</em> overhead is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied. </p>
658
+
659
+ <p>Tensor parallelism does help reduce activation memory for the matrix multiplications since the intermediate activations are sharded across GPUs. However, we still need to gather the full activations for operations like LayerNorm, which means we're not getting the full memory benefits we could. Additionally, it introduces significant communication requirements that heavily depend on the network infrastructure. The inability to hide this particular AllReduce behind computation means it directly adds to the critical path of forward propagation.</p>
660
+
661
+ <p><img alt="Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training." src="/assets/images/placeholder.png" /></p>
662
+
663
+ <p>Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training.</p>
664
+
665
+ <p>In practice, the communication overhead of tensor parallelism becomes particularly noticeable as we scale beyond 8 GPUs. While tensor parallelism within a single node can leverage fast NVLink interconnects, going across nodes requires slower network connections. As shown in the throughput plot above, we observe significant drops when moving from TP=8 to TP=16, and an even steeper decline from TP=16 to TP=32. This illustrates how communication costs can dominate at higher degrees of parallelism.</p>
666
+
667
+ <p>However, tensor parallelism provides important benefits for memory usage by distributing model parameters, gradients, optimizer states and activations (to some extent) across GPUs. Let's examine this effect on a 70B parameter model:</p>
668
+
669
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
670
+
671
+ <p>As we can see, increasing tensor parallelism reduces the memory needed for model parameters, gradients and optimizer states on each GPU. While tensor parallelism does help reduce activation memory in attention and feedforward layers by sharding the matrix multiplications across GPUs, we don't get the full memory benefits we could. This is because operations like layer normalization and dropout still require gathering the full activations on each GPU, partially negating the memory savings. We can do better by finding ways to parallelize these remaining operations as well.</p>
672
+
673
+ <aside>One interesting note about layer normalization in tensor parallel training - since each TP rank sees the same activations after the all-gather, the layer norm weights don't actually need an all-reduce to sync their gradients after the backward pass. They naturally stay in sync across ranks. However, for dropout operations, we must make sure to sync the random seed across TP ranks to maintain deterministic behavior.
674
+ </aside>
675
+
676
+ <p>This raises an interesting question - could we extend tensor parallelism to these remaining operations as well? Indeed, it's possible to parallelize layer norm, dropout and other operations too, which we'll explore next.</p>
677
+
678
  <h3>Sequence Parallelism</h3>
679
 
680
+ <p>In regions where we apply tensor parallelism (TP), like attention and feedforward layers, each GPU only needs to operate on a portion of the hidden dimension since the weights are sharded. However, operations like layer norm or dropout (which is not used much anymore in LLMs) require access to the full hidden dimension to compute correctly.</p>
681
+
682
+ <p>Rather than gathering the full hidden dimension on each GPU (which would defeat the memory benefits of TP), we can instead shard these operations along the sequence length dimension. This approach is called <strong>sequence parallelism (SP)</strong>.</p>
683
+
684
+ <aside>Note that the term Sequence Parallelism is a bit overloaded: the Sequence Parallelism in this section is tightly coupled to Tensor Parallelism and applies to the dropout and layer norm operations. However, when we move to longer sequences, the attention computation will become a bottleneck, which calls for techniques such as Ring-Attention, which are sometimes also called <em>Sequence Parallelism</em> but we’ll refer to them as <em>Context Parallelism</em> to differentiate the two approaches. So each time you see sequence parallelism, remember that it is used together with tensor parallelism (in contrast to context parallelism, which can be used independently).</aside>
685
+
686
+ <p>Sequence parallelism (SP) involves splitting the activations and computations for the parts of the model not handled by tensor parallelism (TP) such as Dropout and LayerNorm, but along the input sequence dimension rather than across hidden dimension. This is needed because these operations require access to the full hidden dimension to compute correctly. For example, LayerNorm needs the full hidden dimension to compute mean and variance:</p>
687
+
688
+ <d-math block>
689
+ \text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
690
+ </d-math>
691
+
692
+ <p>where <d-math>\mu = \text{mean}(x)</d-math> and <d-math>\sigma^2 = \text{var}(x)</d-math> are computed across hidden dimension <d-math>h</d-math>.</p>
693
+
694
+ <p>So even though these operations are computationally cheap, they still require significant activation memory since they need the complete hidden dimension. SP allows us to shard this <strong>memory</strong> burden across GPUs by splitting along the sequence dimension instead.</p>
695
+
696
+ <p>In practice we’ll go from the left diagram to the right:</p>
697
+
698
+ <p><img alt=" in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter
699
+ in backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather
700
+ SP region needs full hidden_dim" src="/assets/images/placeholder.png" /></p>
701
+
702
+ <p>in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter<br>in backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather<br>SP region needs full hidden_dim</p>
703
+
704
+ <p>The diagram shows how we transition between tensor-parallel and sequence-parallel regions using different collective operations (labeled "f" and "g"). The key challenge is managing these transitions efficiently while keeping memory usage low and maintaining correctness.</p>
705
+
706
+ <p>In the forward pass:</p>
707
+ <ul>
708
+ <li>"f" is a no-op (no operation) because activations are already duplicated across ranks</li>
709
+ <li>"f*" is an all-reduce to synchronize activations and ensure correctness</li>
710
+ </ul>
711
+ <p>In the backward pass:</p>
712
+ <ul>
713
+ <li>"f*" is a no-op because gradients are already duplicated across ranks</li>
714
+ <li>"f" is an all-reduce to synchronize gradients</li>
715
+ </ul>
716
+
717
+ <p>These operations "f" and "f*" are called <em>conjugate</em> pairs because they complement each other - when one is a no-op in forward, the other is an all-reduce in backward, and vice versa.</p>
718
+
719
+ <p>For sequence parallelism (SP), we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.</p>
720
+
721
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
722
+
723
+ <p>So what is actually happening here? As a famous LLM would say, let’s take it step-by-step:</p>
724
+
725
+ <p><strong>Initial LayerNorm (SP Region)</strong></p>
726
+ <ul>
727
+ <li>Input tensors X1* and X2* (b,s/2,h) enter LayerNorm, already split across the sequence dimension</li>
728
+ <li>Each GPU computes LayerNorm independently on its sequence chunk, giving Y1* and Y2*</li>
729
+ </ul>
730
+ <p><strong>First Transition (SP → TP)</strong></p>
731
+ <ul>
732
+ <li>"g" operation (all-gather) combines Y1<em> and Y2</em> back to full sequence length</li>
733
+ <li> Restores Y (b,s,h) since column linear layer needs full hidden dimension h</li>
734
+ </ul>
735
+ <p><strong>First Linear Layer (TP Region)</strong></p>
736
+ <ul>
737
+ <li>A1 is a column-linear layer, so it splits Y along the hidden dimension</li>
738
+ <li>GeLU is applied independently on each GPU</li>
739
+ <li>Z1* is (b,s,h/2)</li>
740
+ </ul>
741
+ <p><strong>Second Linear Layer (TP Region)</strong></p>
742
+ <ul>
743
+ <li>B1 is a row-linear layer, so it restores the hidden dimension</li>
744
+ <li>W1 is (b,s,h)</li>
745
+ </ul>
746
+ <p><strong>Final Transition (TP → SP)</strong></p>
747
+ <ul>
748
+ <li>"g*" operation (reduce-scatter) which reduces for previous row-linear correctness while scattering along sequence dimension</li>
749
+ <li>W1* is (b,s/2,h)</li>
750
+ </ul>
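+
+ <p>To make these shapes concrete, here is a shape-only sketch of one MLP block with TP=2 and SP, from the point of view of a single GPU (single process, random weights, the diagram’s simplification that the MLP intermediate size equals h, and the collectives faked locally):</p>
+
+ <pre><code class="language-python">
+ import torch
+
+ b, s, h, tp = 2, 8, 16, 2
+
+ x_star = torch.randn(b, s // tp, h)              # X1*: sequence-sharded input (SP region)
+ y_star = torch.nn.LayerNorm(h)(x_star)           # LayerNorm on the local sequence chunk
+
+ # "g" (all-gather along the sequence) -- the other rank's chunk is faked here
+ y = torch.cat([y_star, torch.randn(b, s // tp, h)], dim=1)   # Y: (b, s, h)
+
+ A1 = torch.randn(h, h // tp)                     # column-linear shard (splits hidden output)
+ z_star = torch.nn.functional.gelu(y @ A1)        # Z1*: (b, s, h/tp)
+
+ B1 = torch.randn(h // tp, h)                     # row-linear shard (restores hidden dim)
+ w = z_star @ B1                                  # W1: (b, s, h), a partial sum across ranks
+
+ # "g*" (reduce-scatter) would sum the partial W1s and keep only this rank's
+ # sequence shard, giving W1*: (b, s/tp, h)
+ print(x_star.shape, y.shape, z_star.shape, w.shape)
+ </code></pre>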
751
+
752
+ <p>A key advantage of sequence parallelism is that it reduces the maximum activation size we need to store. With tensor parallelism alone, we had to store activations of shape (b,s,h) at various points. With sequence parallelism, the maximum activation size is reduced to <d-math>\frac{b \cdot s \cdot h}{tp}</d-math> since we always split along either the sequence or the hidden dimension.</p>
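+ <p>As a quick back-of-the-envelope illustration (the numbers below are purely illustrative, not taken from our benchmarks), here is what this factor-of-<d-math>tp</d-math> saving means for a single bf16 <code>hidden_states</code> tensor:</p>
+
+ <pre><code class="language-python">
+ # per-tensor activation memory for one (b, s, h) hidden_states tensor in bf16
+ b, s, h, tp = 1, 4096, 4096, 8
+ bytes_per_elem = 2                       # bf16
+ full = b * s * h * bytes_per_elem        # TP only: some activations stay (b, s, h)
+ sharded = full // tp                     # TP + SP: always (b, s/tp, h) or (b, s, h/tp)
+ print(f"{full / 2**20:.0f} MiB vs {sharded / 2**20:.0f} MiB")   # 32 MiB vs 4 MiB
+ </code></pre>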
753
+
754
+ <p>It’s a bit difficult to keep track of all the parts that are sharded differently in TP and TP/SP - believe us, we find it hard to keep track of as well, so we made this small table summarizing how the shape of the activations (aka <code>hidden_states</code>) changes across the hidden dimension h and the sequence dimension s during a forward pass:</p>
755
+
756
+ <table>
757
+ <thead>
758
+ <tr>
759
+ <th>Region</th>
760
+ <th>TP only</th>
761
+ <th>TP with SP</th>
762
+ </tr>
763
+ </thead>
764
+ <tbody>
765
+ <tr>
766
+ <td>Enter TP (Column Linear)</td>
767
+ <td>h: sharded (weight_out is sharded)<br>s: full</td>
768
+ <td>h: sharded (weight_out is sharded)<br>s: <strong>all-gather</strong> to full</td>
769
+ </tr>
770
+ <tr>
771
+ <td>TP Region</td>
772
+ <td>h: sharded<br>s: full</td>
773
+ <td>h: sharded<br>s: full</td>
774
+ </tr>
775
+ <tr>
776
+ <td>Exit TP (Row Linear)</td>
777
+ <td>h: full (weight_out is full + <strong>all-reduce</strong> for correctness)<br>s: full</td>
778
+ <td>h: full (weight_out is full + <strong>reduce-scatter</strong> for correctness)<br>s: <strong>reduce-scatter</strong> to sharded</td>
779
+ </tr>
780
+ <tr>
781
+ <td>SP Region</td>
782
+ <td>h: full<br>s: full</td>
783
+ <td>h: full<br>s: sharded</td>
784
+ </tr>
785
+ </tbody>
786
+ </table>
787
+
788
+ <p>And for the embedding layer:</p>
789
+
790
+ <table>
791
+ <thead>
792
+ <tr>
793
+ <th>Region</th>
794
+ <th>Vanilla TP</th>
795
+ <th>TP with SP</th>
796
+ </tr>
797
+ </thead>
798
+ <tbody>
799
+ <tr>
800
+ <td>Embedding Layer (Row Linear sharded on vocab)</td>
801
+ <td>h: full (weight_out is full + <strong>all-reduce</strong> for correctness)<br>s: unchanged</td>
802
+ <td>h: full (weight_out is full + <strong>reduce-scatter</strong> for correctness)<br>s: <strong>reduce-scatter</strong> to sharded</td>
803
+ </tr>
804
+ </tbody>
805
+ </table>
806
+
807
+ <p>You can find an example implementation of both column- and row-linear TP in picotron:
808
+
809
+ <a href="https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py">https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py</a> </p>
810
+
811
+ <p>By using sequence parallelism, we can achieve even greater activation memory savings, allowing us to push our batch size and sequence length further than what would be possible with tensor parallelism alone. Let's see what that means for our previous 70B model example:</p>
812
+
813
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
814
+
815
+ <p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward pass of vanilla TP we had two all-reduces per transformer block, while in SP we have two all-gathers and two reduce-scatters per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into an all-gather + reduce-scatter (see [TODO: Appendix link]) they’re actually equivalent in terms of communication volume. The same reasoning applies to the backward pass, as we just use the conjugate of each operation (no-op ↔ all-reduce and all-gather ↔ reduce-scatter).</p>
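+ <p>To make this equivalence concrete, here is a hedged sketch (assuming <code>torch.distributed</code> is initialized and the first dimension divides evenly across ranks) that rebuilds an all-reduce from a reduce-scatter followed by an all-gather; the result matches <code>dist.all_reduce</code> and the communication volume is the same:</p>
+
+ <pre><code class="language-python">
+ import torch
+ import torch.distributed as dist
+
+ def all_reduce_decomposed(x):
+     world_size = dist.get_world_size()
+     shards = [c.contiguous() for c in x.chunk(world_size, dim=0)]
+     partial = torch.empty_like(shards[dist.get_rank()])
+     dist.reduce_scatter(partial, shards)                # each rank sums one shard
+     gathered = [torch.empty_like(partial) for _ in range(world_size)]
+     dist.all_gather(gathered, partial)                  # every rank rebuilds the full sum
+     return torch.cat(gathered, dim=0)                   # same result as dist.all_reduce(x)
+ </code></pre>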
816
+
817
+ <p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops in each layer (2 for Attention and 2 for MLP). This is what the MLP profiling looks like when using Tensor + Sequence Parallelism:</p>
818
+
819
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
820
+
821
+ <p>Besides the fact that TP requires communications in each layer, it also can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. This is why TP is usually done only within a node (TP≤8).</p>
822
+
823
+
824
+ <aside>Overlapping communication with computation for TP is an active area of research, with recent work like Domino <d-cite bibtex-key="wang2024domino"></d-cite> exploring novel techniques to maximize this overlap. For example, Megatron-LM/Nanotron implement a partial overlapping of all-gather with FC1 computation, and we expect to see more innovations in this space as the field continues to evolve.</aside>
825
+
826
+ <p>As you might expect, this communication overhead becomes increasingly problematic as we scale up tensor parallelism. To illustrate this, let’s check throughput as we scale TP with SP for a 3B model:</p>
827
+
828
+ <p><img alt="image.png" src="/assets/images/placeholder.png" /></p>
829
+ <p>Impact of combined Tensor and Sequence Parallelism (TP/SP) on a 3B model’s performance and memory utilization with 4096 seqlen: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.</p>
830
+
831
+ <p>Let’s summarize our observations:</p>
832
+
833
+ <ul>
834
+ <li>for both methods we notice the biggest performance drop when we move from TP=8 to TP=16, because that’s when we move from communicating only within a single node (NVLink) to communicating between nodes (EFA)</li>
835
+ <li>the memory savings in activations when using TP with SP helps us fit far bigger batches than TP alone</li>
837
+ </ul>
838
+
839
+ <p><strong>We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.</strong></p>
840
+
841
+ <p>However, there are two limits to TP and SP: 1) if we scale the sequence length, the activation memory will still blow up in the TP region, and 2) if the model is too big to fit with TP=8, then we will see a massive slowdown due to the inter-node connectivity.</p>
842
+
843
+ <aside>Since LayerNorms in the SP region operate on different portions of the sequence, their gradients will differ across TP ranks. To ensure the weights stay synchronized, we need to allreduce their gradients during the backward pass, similar to how DP ensures weights stay in sync. This is a small communication overhead since LayerNorm has relatively few parameters.</aside>
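+
+ <p>A hedged sketch of what this sync could look like (the function name and group handling are ours, not a specific library API): after the backward pass, sum the gradients of LayerNorm parameters over the TP group before the optimizer step.</p>
+
+ <pre><code class="language-python">
+ import torch
+ import torch.distributed as dist
+
+ def sync_sequence_parallel_grads(model, tp_group):
+     # LayerNorms in the SP region saw different sequence chunks on each TP rank,
+     # so their parameter gradients must be summed across the TP group.
+     for module in model.modules():
+         if isinstance(module, torch.nn.LayerNorm):
+             for p in module.parameters():
+                 if p.grad is not None:
+                     dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=tp_group)
+ </code></pre>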
844
+
845
+ <p>We can tackle problem 1) with Context parallelism and problem 2) with Pipeline parallelism. Let’s first have a look at Context parallelism!</p>
846
+
847
  <h2>Context Parallelism</h2>
848
 
849
  <h3>Introducing Context Parallelism</h3>