# Saving Memory Using Padding-Free Transformer Layers during Finetuning

For long-sequence training, attention computation can become a memory bottleneck: a naive implementation requires $O(N^2)$ memory, where $N$ is the sequence length. FlashAttention [1,2] addresses this by optimizing IO and using online softmax [3] to reduce data movement [4] between GPU memory (typically HBM for datacenter GPUs) and the GPU cache. The FlashAttention algorithm also reduces the memory requirement of attention computation from $O(N^2)$ to $O(N)$, i.e. from quadratic to linear.
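The online softmax idea from [3] can be sketched in a few lines: a single pass maintains a running maximum and a running normalizer, rescaling the normalizer whenever the maximum changes, so the scores never need a separate full pass to find the max. This is an illustrative scalar version (the function name is my own), not the fused GPU kernel:

```python
import math

def online_softmax(scores):
    """One-pass softmax: maintain a running max (m) and running
    normalizer (d) so no second pass over the scores is needed."""
    m = float("-inf")  # running maximum
    d = 0.0            # running sum of exp(score - m)
    for x in scores:
        m_new = max(m, x)
        # rescale the old normalizer when the running max changes
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return [math.exp(x - m) / d for x in scores]
```

FlashAttention applies this same rescaling trick blockwise so the $N \times N$ score matrix is never materialized.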

The current integration of FlashAttention [1,2] for training in most libraries (including HuggingFace transformers) is based on non-intrusive changes to these libraries: most implementations simply replace the naive attention module with FlashAttention [1,2]. Though this is easy to implement, it is suboptimal for batches containing sequences of variable lengths (i.e. batches with padding). All operations in the transformer except attention are applied independently to each token position. Since FlashAttention [1,2] completely avoids any computation and memory for pad tokens, it is possible to drop all redundant computation and memory for padding tokens from the rest of the transformer model as well, essentially creating a Padding-Free transformer model when using FlashAttention. This is also done in the original FlashAttention training codebase. It should be noted that this is an exact implementation of the model, with no approximations.

Similar optimizations are done in HuggingFace TGI for inference efficiency. It should be noted that padding is not a problem when it isn't needed in the first place, i.e. when all examples in a batch have equal length or when using dense packing of examples (as is the case for pretraining models).

In this blog, we derive the theoretical memory consumption of naive attention, FlashAttention with padded transformer blocks (the current implementation in the HuggingFace transformers library), and the Padding-Free transformer blocks.

Let's assume an input batch of embeddings of shape $(b, \max\{s_i\}, h)$ as input to a transformer layer, where $b$, $s_i$ and $h$ denote the batch size, the unpadded sequence length of the $i^{th}$ example in the batch, and the hidden size of the transformer model, respectively. For training, each transformer layer needs to cache the activations of each operation (computed in the forward pass) for the backward pass. We assume 16-bit precision for training (2 bytes per value in a tensor). We also assume Multi-Head Attention [6] with $a$ attention heads for simplicity, though the same idea also applies to Multi-Query Attention [7] and Grouped-Query Attention [8].
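To make the accounting concrete, here is a small helper (illustrative; the names and the example batch are my own) that counts the bytes needed to cache one 16-bit activation tensor, comparing the padded layout $(b, \max\{s_i\}, h)$ with the unpadded layout $(\sum s_i, h)$:

```python
def activation_bytes(shape, bytes_per_elem=2):
    """Bytes needed to cache one activation tensor at 16-bit precision."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_elem

# hypothetical batch: 4 sequences of unequal length, h = 6144
seq_lens = [700, 1024, 512, 300]
b, h = len(seq_lens), 6144
padded = activation_bytes((b, max(seq_lens), h))   # (b, max{s_i}, h) layout
unpadded = activation_bytes((sum(seq_lens), h))    # (sum s_i, h) layout
print(padded, unpadded)
```

For this batch the padded layout stores $4 \times 1024$ token positions while only $2536$ of them are real tokens, so roughly 38% of every cached activation is padding.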

## Naive Attention

The input LayerNorm receives an input of shape $(b, \max\{s_i\}, h)$, which needs to be cached for the backward pass. The mean and variance also need to be cached, each of shape $(b, \max\{s_i\})$. Since $\max\{s_i\}bh \gg 2\max\{s_i\}b$, we can ignore the $2\max\{s_i\}b$ elements for the mean and variance. Total activation memory for this operation is $$2 \times \max\{s_i\}bh \quad \text{bytes.}$$

The input to the QKV projection matrices (shared among the Q, K and V projections) needs to be cached. It has $\max\{s_i\}bh$ elements, taking $2 \times \max\{s_i\}bh$ bytes. The outputs of the Q, K and V projections also need to be cached, each of which has $\max\{s_i\}bh$ elements taking $2 \times \max\{s_i\}bh$ bytes. Total activation memory for this operation is $$2 \times \max\{s_i\}bh + 3 \times 2 \times \max\{s_i\}bh = 8 \times \max\{s_i\}bh \quad \text{bytes.}$$

The output of the softmax, which has $(\max\{s_i\})^2ab$ elements, also needs to be cached. Total activation memory for this operation is $$2 \times (\max\{s_i\})^2ab \quad \text{bytes.}$$

The attention softmax is followed by a dropout, which requires saving a mask of $(\max\{s_i\})^2ab$ elements. Each element takes a byte, since PyTorch doesn't allow bit tensors; the reason is probably ease of implementation, since GPUs are generally byte-addressable. Total memory for this operation is $$(\max\{s_i\})^2ab \quad \text{bytes.}$$
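This one-byte-per-element cost adds up quickly because the mask is quadratic in sequence length. A quick check (helper name and batch size are my own; $a = 48$ matches the model used in the table below):

```python
def attention_dropout_mask_bytes(s_max, a, b):
    """Attention-dropout mask of shape (b, a, s_max, s_max):
    one byte per element, since bool tensors are byte-sized."""
    return b * a * s_max * s_max

# hypothetical batch of 4 at 2048 tokens with 48 heads
print(attention_dropout_mask_bytes(2048, 48, 4) / 2**30, "GiB")
```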

The softmax dropout output, which has $(\max\{s_i\})^2ab$ elements, also needs to be cached. Total activation memory for this operation is $$2 \times (\max\{s_i\})^2ab \quad \text{bytes.}$$

We cache the output of the multiplication of the dropped-out attention weights with V, which is the input to the output projection matrix. It has $\max\{s_i\}bh$ elements. Total activation memory for this operation is $$2 \times \max\{s_i\}bh \quad \text{bytes.}$$

For the dropout after the output projection, only the dropout mask needs to be cached. Total memory for this operation is $$\max\{s_i\}bh \quad \text{bytes.}$$

The second LayerNorm behaves the same as the first; its memory requirement is $$2 \times \max\{s_i\}bh \quad \text{bytes.}$$

We assume here that the feedforward hidden dimension is $f = 4h$, as is typical for a standard transformer. The inputs to each linear layer and the input to the GELU activation function need to be cached. These take $2 \times \max\{s_i\}bh$, $8 \times \max\{s_i\}bh$ and $8 \times \max\{s_i\}bh$ bytes respectively. The required memory for the MLP block is $$18 \times \max\{s_i\}bh \quad \text{bytes.}$$

The dropout after the MLP block requires the same memory as the dropout after the attention output projection, i.e. $$\max\{s_i\}bh \quad \text{bytes.}$$

Summing these up, the total activation memory per layer is given by: $$M_{naive} = \max\{s_i\}bh\left(34 + \frac{5a \times \max\{s_i\}}{h}\right)$$
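For quick experimentation, this formula translates directly into code (the function name is my own):

```python
def naive_activation_bytes(seq_lens, h, a):
    """Per-layer activation memory of naive attention, in bytes:
    M_naive = max{s_i} * b * h * (34 + 5 * a * max{s_i} / h)."""
    b, s_max = len(seq_lens), max(seq_lens)
    return s_max * b * h * (34 + 5 * a * s_max / h)
```

The second term grows linearly with $\max\{s_i\}$, which is why the quadratic softmax caches dominate at long context lengths.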

## Transformer Layer with FlashAttention

FlashAttention [1,2] has been integrated into the HuggingFace transformers API. The implementation at the time of writing this blog does an unpad operation just before the FlashAttention kernel is executed. This operation converts the Q, K, V inputs of shape $(b, \max\{s_i\}, h)$ to shape $\left(\sum s_i, h\right)$ (each example in the batch is concatenated one after the other, resulting in a 2D tensor) and launches the FlashAttention kernel. After the attention computation, the output is padded back to the shape $(b, \max\{s_i\}, h)$.
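A sketch of this unpad/pad round trip in plain Python (list-based for clarity; real implementations operate on tensors, and all names here are my own). It also builds the cumulative sequence lengths that variable-length attention kernels typically consume:

```python
def unpad(batch, seq_lens):
    """Drop pad positions: (b, max_s, h) nested lists -> (sum s_i, h),
    plus cumulative sequence lengths marking example boundaries."""
    flat, cu_seqlens = [], [0]
    for row, n in zip(batch, seq_lens):
        flat.extend(row[:n])  # keep only the real tokens
        cu_seqlens.append(cu_seqlens[-1] + n)
    return flat, cu_seqlens

def pad(flat, cu_seqlens, max_s, pad_vec):
    """Inverse: scatter (sum s_i, h) back to (b, max_s, h)."""
    out = []
    for i in range(len(cu_seqlens) - 1):
        seq = flat[cu_seqlens[i]:cu_seqlens[i + 1]]
        out.append(seq + [pad_vec] * (max_s - len(seq)))
    return out
```

For example, a batch of lengths `[2, 1]` padded to 3 flattens to 3 rows with `cu_seqlens == [0, 2, 3]`, and `pad` restores the original layout exactly.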

FlashAttention [1,2] avoids materializing the quadratic $QK^T$ matrix in memory and uses online softmax [3], dropping the need to cache the softmax output described earlier. Instead, we only need to materialize the output matrix, which has shape $\left(\sum s_i, a, \frac{h}{a}\right)$, the 2 softmax statistics, both of shape $\left(\sum s_i, a\right)$, and the random number generator state for the dropout, which we ignore here. For the algorithm in detail, refer to the FlashAttention [1,2] papers. We also need to cache the boolean attention mask used for padding and unpadding, but we ignore it in the calculations since it is the same for every layer and can be cached once for the entire transformer model rather than per layer. Thus the memory required for attention becomes $$2\sum s_i a\left(\frac{h}{a}\right) + 4\sum s_i a = 2\sum s_i(h + 2a) \quad \text{bytes.}$$

Thus we have the total activation memory per layer with FlashAttention [1,2] as follows: $$M_{flash} = \max\{s_i\}bh\left(34 + \frac{\sum s_i}{\max\{s_i\}b}\left[1 + \frac{2a}{h}\right]\right)$$
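The FlashAttention memory formula can likewise be evaluated numerically (the function name is my own):

```python
def flash_activation_bytes(seq_lens, h, a):
    """Per-layer activation memory with padded FlashAttention, in bytes:
    M_flash = max{s_i} * b * h * (34 + sum(s_i)/(max{s_i}*b) * (1 + 2a/h))."""
    b, s_max, s_sum = len(seq_lens), max(seq_lens), sum(seq_lens)
    return s_max * b * h * (34 + s_sum / (s_max * b) * (1 + 2 * a / h))
```

Note that the quadratic term from the naive formula is gone; the remaining attention cost scales with the number of real tokens $\sum s_i$, while the $34$ non-attention terms still pay for padding.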

## Padding-Free Transformer Layer

Since all operations in the transformer layer (except attention) are the same for each token position, we can avoid the padding and unpadding operations altogether and thus reduce the activation memory of the transformer layer even further. This requires minor changes to the HuggingFace transformers implementation. In this implementation of the transformer, there is no wasted memory for pad token positions at all! In this case, the input to the entire transformer model has shape $\left(\sum s_i, h\right)$. The memory in this case is given by $$M_{padding\_free} = \left(\sum s_i\right)h\left(35 + \frac{2a}{h}\right)$$
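A numerical version of the padding-free formula (the function name is my own):

```python
def padding_free_activation_bytes(seq_lens, h, a):
    """Per-layer activation memory of the Padding-Free transformer, in bytes:
    M_padding_free = (sum s_i) * h * (35 + 2a/h)."""
    return sum(seq_lens) * h * (35 + 2 * a / h)
```

Every term now depends only on $\sum s_i$, the number of real tokens, so ragged batches cost exactly what their content requires.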

It should be noted that $M_{flash} = M_{padding\_free}$ when there is no padding, i.e. when $s_i = \max\{s_i\}\ \forall i \in \{1, 2, ..., b\}$. This optimization is similar to running a transformer model with nested tensors. While there has been significant effort to mitigate the padding problem with approaches like binning examples by context length, these can degrade model performance, especially during finetuning.

## Motivation for Using the Padding-Free Transformer Layer

Now we analyze the memory consumption of the 3 transformer layer implementations. We assume a dataset of sequences whose lengths follow a discrete uniform distribution, i.e. $S_i \sim U\{1, 2, ..., N\}$, where $S_i$ is the random variable denoting the sequence length of the $i^{th}$ sample in the batch and $N$ is the maximum sequence length for the dataset and the model. We sample batches of $b$ examples each, with sequence lengths $(S_1, S_2, ..., S_b)$. We compute the expectations $\mathbb{E}[M_{naive}]$, $\mathbb{E}[M_{flash}]$ and $\mathbb{E}[M_{padding\_free}]$ under the discrete uniform distribution. To do so, we consider another random variable $K = \max\{S_i\}$.
The Cumulative Distribution Function for $K$ can be derived as: $$P(K \le k) = P(\max\{S_i\} \le k) = P(S_1 \le k, S_2 \le k, ..., S_b \le k)$$ Now, using the fact that the examples in a batch are i.i.d., we have $$\implies P(K \le k) = \left[P(S_i \le k)\right]^b = \left(\frac{k}{N}\right)^b$$ and thus we have the Probability Mass Function for $K$ as: $$P(K = k) = P(K \le k) - P(K \le k - 1) = \left(\frac{k}{N}\right)^b - \left(\frac{k-1}{N}\right)^b$$ We can use computational methods or Faulhaber's formula [9] together with this result to calculate the expected memory usage of the 3 methods. We report the theoretical memory consumption derived using these equations for a 20B parameter model in the following table. We find that using a Padding-Free version of the transformer layer saves $\sim 43\%$ activation memory and also saves a lot of redundant FLOPs. We leave the analysis of FLOPs out of this blog, but it is easily derivable.
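The "computational methods" route can be sketched as follows, expanding the three memory formulas by linearity of expectation: $\mathbb{E}[M_{naive}]$ needs $\mathbb{E}[K]$ and $\mathbb{E}[K^2]$ from the PMF above, while $\mathbb{E}\left[\sum S_i\right] = b(N+1)/2$. The function name is my own, and a batch size $b$ must be supplied:

```python
def expected_activation_bytes(N, b, h, a):
    """Expected per-layer activation bytes under S_i ~ U{1,...,N},
    using P(K = k) = (k/N)^b - ((k-1)/N)^b for K = max{S_i}."""
    pmf = [(k / N) ** b - ((k - 1) / N) ** b for k in range(1, N + 1)]
    EK = sum(k * p for k, p in zip(range(1, N + 1), pmf))       # E[K]
    EK2 = sum(k * k * p for k, p in zip(range(1, N + 1), pmf))  # E[K^2]
    Esum = b * (N + 1) / 2                                      # E[sum s_i]
    naive = 34 * b * h * EK + 5 * a * b * EK2
    flash = 34 * b * h * EK + Esum * (h + 2 * a)
    pfree = Esum * (35 * h + 2 * a)
    return naive, flash, pfree
```

For any parameters with $N > 1$ and $b > 1$ this reproduces the ordering in the table below: padding-free uses the least memory, naive attention the most.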

| Sequence Length | Naive Attention | FlashAttention | Padding-Free Transformer |
|---|---|---|---|
| 512 | 1.085 GB | 0.721 GB | 0.411 GB |
| 1024 | 2.919 GB | 1.441 GB | 0.821 GB |
| 2048 | 8.837 GB | 2.882 GB | 1.642 GB |
| 4096 | 29.674 GB | 5.763 GB | 3.283 GB |
| 8192 | 107.347 GB | 11.524 GB | 6.566 GB |
| 16384 | 406.693 GB | 23.048 GB | 13.132 GB |
| 32768 | 1581.386 GB | 46.096 GB | 26.263 GB |

Table: Memory usage per transformer layer for the different attention implementations at different context lengths, for a 20B parameter model with context length $(N = 8192)$, hidden size $(h = 6144)$, FFN hidden size $(f = 24576)$ and attention heads $(a = 48)$.

## Conclusion

In this blog, we presented a way to completely avoid the computation and memory requirements of pad tokens during finetuning of transformer models using FlashAttention, and derived equations for the corresponding theoretical memory consumption. Our changes are easily integrable into the HuggingFace transformers ecosystem for finetuning. The method doesn't involve writing any low-level device code; the only non-native PyTorch code we use is FlashAttention, which is already available.

## References

- Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention with io-awareness." Advances in Neural Information Processing Systems 35 (2022): 16344-16359.
- Dao, Tri. "Flashattention-2: Faster attention with better parallelism and work partitioning." arXiv preprint arXiv:2307.08691 (2023).
- Milakov, Maxim, and Natalia Gimelshein. "Online normalizer calculation for softmax." arXiv preprint arXiv:1805.02867 (2018).
- Ivanov, Andrei, et al. "Data movement is all you need: A case study on optimizing transformers." Proceedings of Machine Learning and Systems 3 (2021): 711-732.
- Korthikanti, Vijay Anand, et al. "Reducing activation recomputation in large transformer models." Proceedings of Machine Learning and Systems 5 (2023).
- Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems 30 (2017).
- Shazeer, Noam. "Fast transformer decoding: One write-head is all you need." arXiv preprint arXiv:1911.02150 (2019).
- Ainslie, Joshua, et al. "Gqa: Training generalized multi-query transformer models from multi-head checkpoints." arXiv preprint arXiv:2305.13245 (2023).
- Knuth, Donald E. "Johann Faulhaber and sums of powers." Mathematics of Computation 61.203 (1993): 277-294.