Saving Memory Using Padding-Free Transformer Layers during Finetuning

Community Article Published March 9, 2024

For long-sequence training, attention computation can become a memory bottleneck, as a naive implementation requires $O(N^2)$ memory, where $N$ is the sequence length. Recently, FlashAttention [1,2] has been proposed. It is IO-aware and uses online softmax [3] to reduce data movement [4] between GPU memory (typically HBM on datacenter GPUs) and the GPU's on-chip cache (SRAM). FlashAttention also reduces the memory required for attention computation from $O(N^2)$ to $O(N)$, i.e. from quadratic to linear.
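Online softmax [3] is the building block that lets FlashAttention process attention scores block by block instead of materializing the full score matrix: the softmax normalizer is accumulated in one streaming pass, and the running sum is rescaled whenever a larger maximum appears. The sketch below shows only that scalar recurrence for one row of scores, checked against the usual two-pass softmax. It is not FlashAttention's tiled GPU kernel, and the function names are our own.

```python
import math


def two_pass_softmax(xs):
    """Reference softmax: one pass for the maximum, one for the normalizer."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def online_softmax(xs):
    """Softmax whose normalizer is accumulated in a single streaming pass.

    m is the running maximum and d the running sum of exp(x - m). When a new
    maximum arrives, the old sum is rescaled by exp(m_old - m_new), so d ends
    up equal to the two-pass normalizer without a separate max pass.
    """
    m, d = -math.inf, 0.0
    for x in xs:
        m_new = max(m, x)
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    # Final normalization (FlashAttention instead rescales an output accumulator).
    return [math.exp(x - m) / d for x in xs]


scores = [1.0, 3.0, -2.0, 0.5, 3.0]
print(online_softmax(scores))
```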

The current integration of FlashAttention [1,2] for training in most libraries (including HuggingFace transformers) is based on non-intrusive changes: most implementations simply replace the naive attention module with FlashAttention [1,2]. Though this is easy to implement, it is suboptimal when a batch contains sequences of variable lengths (i.e. when the batch contains padding). All operations except attention are applied independently to each token position in the transformer, and FlashAttention [1,2] completely avoids computation and memory for pad tokens. It is therefore possible to drop all redundant computation and memory spent on pad tokens from the transformer model, essentially creating a Padding-Free transformer model when using FlashAttention. This is also done in the original FlashAttention training codebase. Note that this is an exact implementation of the model, with no approximations.
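Every position-wise operation (LayerNorm, the linear projections, the MLP and dropout) does work and caches activations in proportion to the number of token positions it sees, so the share of pad positions in a padded batch is the share of that work a padding-free layout avoids. A small sketch of that ratio, $1 - \sum s_i / (b \cdot \max\{s_i\})$, for an arbitrary example batch of lengths we picked:

```python
def padding_overhead(seq_lens):
    """Compare token positions in padded vs. packed layouts of one batch.

    A padded batch has b * max(s_i) positions (shape (b, max{s_i}, h));
    a packed, padding-free batch has sum(s_i) positions (shape (sum s_i, h)).
    """
    padded = len(seq_lens) * max(seq_lens)
    packed = sum(seq_lens)
    wasted_fraction = 1 - packed / padded
    return padded, packed, wasted_fraction


padded, packed, wasted = padding_overhead([512, 100, 37, 260])
print(f"{padded} padded positions, {packed} real tokens, {wasted:.1%} wasted")
```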

Similar optimizations are used in HuggingFace TGI for inference efficiency. Note that this is not a problem when the batch needs no padding, i.e. when all examples in a batch have equal length or when examples are densely packed (as is the case when pretraining models).

In this blog, we derive the theoretical activation memory consumption for naive attention, for FlashAttention with padded transformer blocks (the current implementation in the HuggingFace transformers library) and for Padding-Free transformer blocks.

Let's assume the input to a transformer layer is a batch of embeddings of shape $(b, \max\{s_i\}, h)$, where $b$, $s_i$ and $h$ denote the batch size, the unpadded sequence length of the $i^{th}$ example in the batch and the hidden size of the transformer model, respectively. For training, each transformer layer needs to cache the activations of each operation (computed in the forward pass) for the backward pass. We assume 16-bit precision for training (2 bytes per tensor element). For simplicity, we also assume Multi-Head Attention [6] with $a$ attention heads, though the same idea applies to Multi-Query Attention [7] and Grouped-Query Attention [8].

Naive Attention

  1. Input LayerNorm

    The input LayerNorm receives an input of shape $(b, \max\{s_i\}, h)$, which needs to be cached for the backward pass. The mean and variance also need to be cached; each has shape $(b, \max\{s_i\})$. Since $\max\{s_i\}bh \gg 2\max\{s_i\}b$, we can ignore these $2\max\{s_i\}b$ elements for the mean and variance. Total activation memory for this operation is $2 \times \max\{s_i\}bh$ bytes.

  2. Q, K and V projection

    The input to the QKV projection matrices (shared among the Q, K and V projections) needs to be cached. It has $\max\{s_i\}bh$ elements, taking $2 \times \max\{s_i\}bh$ bytes. The outputs of the Q, K and V projections also need to be cached; each has $\max\{s_i\}bh$ elements, taking $2 \times \max\{s_i\}bh$ bytes. Total activation memory for this operation is $2 \times \max\{s_i\}bh + 3 \times 2 \times \max\{s_i\}bh = 8 \times \max\{s_i\}bh$ bytes.

  3. Attention softmax computation

    The output of the softmax, which has $(\max\{s_i\})^2ab$ elements, also needs to be cached. Total activation memory for this operation is $2 \times (\max\{s_i\})^2ab$ bytes.

  4. Attention softmax dropout

    The attention softmax is followed by dropout, which requires saving a mask of $(\max\{s_i\})^2ab$ elements. Each element takes a byte since PyTorch doesn't support bit tensors, probably for ease of implementation, since GPUs are generally byte-addressable. Total memory for this operation is $(\max\{s_i\})^2ab$ bytes.

  5. Multiplication of softmax dropout output with V

    The softmax dropout output, which has $(\max\{s_i\})^2ab$ elements, also needs to be cached. Total activation memory for this operation is $2 \times (\max\{s_i\})^2ab$ bytes.

  6. Linear Projection after attention

    We cache the output of the above multiplication, which is the input to the projection matrix. It has $\max\{s_i\}bh$ elements. Total activation memory for this operation is $2 \times \max\{s_i\}bh$ bytes.

  7. Dropout

    Only the dropout mask needs to be cached. Total memory for this operation is $\max\{s_i\}bh$ bytes.

  8. Post Attention LayerNorm

    Same as the input LayerNorm. Memory requirement is $2 \times \max\{s_i\}bh$ bytes.

  9. MLP

    We assume here that the feedforward hidden dimension is $f = 4h$, as is typical for a standard transformer. The inputs to both linear layers and the input to the GELU activation function need to be cached. The input to the first linear layer takes $2 \times \max\{s_i\}bh$ bytes, while the input to GELU and the input to the second linear layer take $8 \times \max\{s_i\}bh$ bytes each. The required memory for the MLP block is $18 \times \max\{s_i\}bh$ bytes.

  10. Dropout after MLP

    Memory required is the same as in point (7) above, i.e. $\max\{s_i\}bh$ bytes.

Summing these up, the total activation memory per layer is given by:

$$M_{naive} = \max\{s_i\}bh\left(34 + \frac{5a \times \max\{s_i\}}{h}\right)$$

This matches the per-layer activation memory derived in [5].
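As a sanity check on the sum, the per-step costs from points (1) to (10) can be tallied programmatically. In this sketch (the step names are our own labels), each step contributes a coefficient of $\max\{s_i\}bh$ bytes and a coefficient of $(\max\{s_i\})^2ab$ bytes; the coefficients should add up to 34 and 5, reproducing $M_{naive}$.

```python
# Per-step activation memory of one naive-attention layer, from points (1)-(10).
# Each entry: (coefficient of s*b*h bytes, coefficient of a*s^2*b bytes),
# where s = max{s_i}. Step names are illustrative labels.
NAIVE_STEPS = {
    "input_layernorm": (2, 0),
    "qkv_projection": (8, 0),
    "attention_softmax": (0, 2),
    "softmax_dropout_mask": (0, 1),
    "softmax_dropout_times_v": (0, 2),
    "attention_output_projection": (2, 0),
    "attention_dropout_mask": (1, 0),
    "post_attention_layernorm": (2, 0),
    "mlp": (18, 0),
    "mlp_dropout_mask": (1, 0),
}


def naive_layer_bytes(s, b, h, a):
    """Sum the per-step terms for max sequence length s, batch b, hidden h, heads a."""
    linear = sum(lin for lin, _ in NAIVE_STEPS.values())
    quadratic = sum(quad for _, quad in NAIVE_STEPS.values())
    return linear * s * b * h + quadratic * a * s * s * b


def naive_closed_form(s, b, h, a):
    """M_naive = s * b * h * (34 + 5 * a * s / h)."""
    return s * b * h * (34 + 5 * a * s / h)


print(naive_layer_bytes(2048, 8, 6144, 48), naive_closed_form(2048, 8, 6144, 48))
```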

Transformer Layer with FlashAttention

FlashAttention [1,2] has been integrated into the HuggingFace transformers API. At the time of writing, the implementation performs an unpad operation just before the FlashAttention kernel is executed. This operation converts the inputs Q, K and V from shape $(b, \max\{s_i\}, h)$ to shape $\left(\sum s_i, h\right)$ (the examples in the batch are concatenated one after another into a 2D tensor) and then launches the FlashAttention kernel. After the attention computation, the output is padded back to shape $(b, \max\{s_i\}, h)$.
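The unpad and pad steps can be sketched in plain Python on a toy batch, with strings standing in for hidden-state vectors. The hypothetical `unpad` keeps the positions marked in the attention mask, records their flat indices so the output can be scattered back, and builds cumulative sequence-length offsets (the `cu_seqlens` convention used by FlashAttention-2's variable-length interface), so example $i$ occupies rows `cu_seqlens[i]` up to but excluding `cu_seqlens[i + 1]`. This is an illustrative sketch, not the HuggingFace or FlashAttention code.

```python
def unpad(batch, mask):
    """Pack the real tokens of a padded batch into one flat sequence.

    batch: b rows of max{s_i} tokens; mask: b rows of 1 (real) / 0 (pad).
    Returns the packed tokens (sum s_i of them), the flat index of each kept
    position in the padded layout, and cumulative sequence-length offsets.
    """
    max_len = len(mask[0])
    packed, indices, cu_seqlens = [], [], [0]
    for i, (row, row_mask) in enumerate(zip(batch, mask)):
        for j, (token, keep) in enumerate(zip(row, row_mask)):
            if keep:
                packed.append(token)
                indices.append(i * max_len + j)
        cu_seqlens.append(len(packed))
    return packed, indices, cu_seqlens


def pad(packed, indices, b, max_len, pad_token=None):
    """Scatter packed tokens back into a (b, max_len) padded layout."""
    flat = [pad_token] * (b * max_len)
    for token, idx in zip(packed, indices):
        flat[idx] = token
    return [flat[i * max_len:(i + 1) * max_len] for i in range(b)]


batch = [["a1", "a2", "a3"], ["b1", "b2", None]]
mask = [[1, 1, 1], [1, 1, 0]]
packed, indices, cu_seqlens = unpad(batch, mask)
print(packed, cu_seqlens)
```

In the padded FlashAttention integration this round trip happens around every attention call; the Padding-Free layer instead keeps the packed layout throughout the model.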

FlashAttention [1,2] avoids materializing the quadratic $QK^T$ matrix in memory and uses online softmax [3], thereby removing the need to cache the activations in points (3), (4) and (5). Instead, we only need to materialize the output matrix of shape $\left(\sum s_i, a, \frac{h}{a}\right)$, the 2 softmax statistics, each of shape $\left(\sum s_i, a\right)$, and the random number generator state for the dropout (from which the dropout mask is regenerated in the backward pass), which we ignore here. For the algorithm in detail, refer to the FlashAttention papers [1,2]. We also need to cache the boolean attention mask used for padding and unpadding. We ignore it in our calculations, though, since it is the same for every layer and can be cached once for the entire transformer model rather than in every layer. Thus the memory required for attention becomes

$$2\sum s_i a\left(\frac{h}{a}\right) + 4\sum s_i a = 2\sum s_i (h + 2a) \text{ bytes.}$$

Thus the total activation memory per layer with FlashAttention [1,2] is:

$$M_{flash} = \max\{s_i\}bh\left(34 + \frac{\sum s_i}{\max\{s_i\}b}\left[1 + \frac{2a}{h}\right]\right)$$

Padding-Free Transformer Layer

Since all operations in the transformer layer except attention are the same for each token position, we can avoid the pad and unpad operations entirely and further reduce the activation memory required by the transformer layer. This requires only minor changes to the HuggingFace transformers implementation. In this implementation of the transformer, no memory is wasted on pad token positions at all! The input to the entire transformer model has shape $\left(\sum s_i, h\right)$, and the memory is given by

$$M_{padding\_free} = \left(\sum s_i\right)h\left(35 + \frac{2a}{h}\right)$$

Note that $M_{flash} = M_{padding\_free}$ when there is no padding, i.e. when $s_i = \max\{s_i\}\ \forall i \in \{1, 2, \ldots, b\}$. This optimization is similar to running a transformer model with nested tensors. While there has been significant effort to sidestep this problem with approaches like binning examples by context length, these lead to model performance degradation, especially during finetuning.
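Plugging a concrete batch into the three per-layer formulas makes the comparison tangible. The sketch below evaluates $M_{naive}$, $M_{flash}$ and $M_{padding\_free}$ exactly as written above; the example lengths are arbitrary, and $h = 6144$, $a = 48$ are taken from the 20B model in the table below. Expanding the formulas gives $M_{flash} - M_{padding\_free} = 34h\left(b\max\{s_i\} - \sum s_i\right)$ bytes, which vanishes exactly when there is no padding.

```python
def per_layer_bytes(seq_lens, h, a):
    """Per-layer activation memory (bytes) from M_naive, M_flash, M_padding_free."""
    b, s_max, s_sum = len(seq_lens), max(seq_lens), sum(seq_lens)
    naive = s_max * b * h * (34 + 5 * a * s_max / h)
    flash = s_max * b * h * (34 + s_sum / (s_max * b) * (1 + 2 * a / h))
    padding_free = s_sum * h * (35 + 2 * a / h)
    return naive, flash, padding_free


h, a = 6144, 48
for lens in ([4096, 1000, 300, 2500], [2048] * 4):
    naive, flash, free = per_layer_bytes(lens, h, a)
    print(lens, [round(x / 2**30, 3) for x in (naive, flash, free)], "GiB")
```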

Motivation for using Padding-Free Transformer Layer

Now, we analyze the memory consumption of the 3 transformer layer implementations. We assume a dataset whose sequence lengths follow a discrete uniform distribution, i.e. $S_i \sim U\{1, 2, \ldots, N\}$, where $S_i$ is the random variable denoting the sequence length of the $i^{th}$ sample in the batch and $N$ is the maximum sequence length for the dataset and the model. We sample batches of $b$ examples each, with sequence lengths $(S_1, S_2, \ldots, S_b)$, and compute the expectations $\mathbb{E}[M_{naive}]$, $\mathbb{E}[M_{flash}]$ and $\mathbb{E}[M_{padding\_free}]$ under this distribution. To do so, we consider another random variable $K = \max\{S_i\}$. The cumulative distribution function of $K$ can be derived as:

$$P(K \le k) = P(\max\{S_i\} \le k) = P(S_1 \le k, S_2 \le k, \ldots, S_b \le k)$$

Since the examples in a batch are i.i.d., we have

$$P(K \le k) = \left[P(S_i \le k)\right]^b = \left(\frac{k}{N}\right)^b$$

and thus the probability mass function of $K$ is:

$$P(K = k) = P(K \le k) - P(K \le k - 1) = \left(\frac{k}{N}\right)^b - \left(\frac{k - 1}{N}\right)^b$$

We can use computational methods or Faulhaber's formula [9] with this result to calculate the expected memory usage of the 3 methods. The following table reports the theoretical memory consumption derived from these equations for a 20B parameter model. We find that the Padding-Free transformer layer saves $\sim43\%$ of the activation memory compared to the padded FlashAttention layer, and also avoids a lot of redundant FLOPs. We leave the analysis of FLOPs out of this blog, but it is easily derivable.
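Each per-layer formula is linear in $K$, $K^2$ and $\sum_i S_i$, so by linearity of expectation the three expectations reduce to three moments of the batch. Writing $\mathbb{E}[K]$ and $\mathbb{E}[K^2]$ as tail sums, using $P(K \ge k) = 1 - \left(\frac{k-1}{N}\right)^b$, leaves sums of powers $\sum_j j^b$ and $\sum_j j^{b+1}$, which is where Faulhaber's formula [9] applies:

$$
\begin{aligned}
\mathbb{E}[M_{naive}] &= bh\left(34\,\mathbb{E}[K] + \frac{5a}{h}\,\mathbb{E}[K^2]\right)\\
\mathbb{E}[M_{flash}] &= 34bh\,\mathbb{E}[K] + (h + 2a)\,\mathbb{E}\Big[\sum_{i=1}^{b} S_i\Big]\\
\mathbb{E}[M_{padding\_free}] &= (35h + 2a)\,\mathbb{E}\Big[\sum_{i=1}^{b} S_i\Big]
\end{aligned}
$$

with

$$
\mathbb{E}\Big[\sum_{i=1}^{b} S_i\Big] = \frac{b(N+1)}{2},\qquad
\mathbb{E}[K] = \sum_{k=1}^{N} P(K \ge k) = N - \sum_{j=1}^{N-1}\left(\frac{j}{N}\right)^b,
$$

$$
\mathbb{E}[K^2] = \sum_{k=1}^{N} (2k - 1)\,P(K \ge k) = N^2 - \sum_{j=1}^{N-1} (2j + 1)\left(\frac{j}{N}\right)^b.
$$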

| Sequence Length | Naive Attention | FlashAttention | Padding-Free Transformer |
|---|---|---|---|
| 512 | 1.085 GiB | 0.721 GiB | 0.411 GiB |
| 1024 | 2.919 GiB | 1.441 GiB | 0.821 GiB |
| 2048 | 8.837 GiB | 2.882 GiB | 1.642 GiB |
| 4096 | 29.674 GiB | 5.763 GiB | 3.283 GiB |
| 8192 | 107.347 GiB | 11.524 GiB | 6.566 GiB |
| 16384 | 406.693 GiB | 23.048 GiB | 13.132 GiB |
| 32768 | 1581.386 GiB | 46.096 GiB | 26.263 GiB |

Table: Expected activation memory per transformer layer for different attention implementations, where the first column is the maximum sequence length $N$ of the dataset, for a 20B parameter model with context length 8192, hidden size $h = 6144$, FFN hidden size $f = 24576$ and $a = 48$ attention heads.
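The table can be reproduced numerically from the PMF of $K$. The sketch below computes $\mathbb{E}[K]$, $\mathbb{E}[K^2]$ and $\mathbb{E}[\sum_i S_i]$ by direct summation and plugs them into the three formulas as written in this post. Two inputs are our assumptions, since the table does not state them: a batch size of $b = 8$ and GiB ($2^{30}$ bytes) as the unit. With these, the sketch agrees with the table's values to the printed precision.

```python
def expected_layer_memory_gib(N, b, h, a):
    """Expected per-layer activation memory (GiB) when S_i ~ U{1, ..., N} i.i.d.

    Returns (naive, flash, padding_free), using the PMF of K = max{S_i},
    P(K = k) = (k/N)^b - ((k-1)/N)^b, and the formulas M_naive, M_flash and
    M_padding_free from this post.
    """
    e_k = e_k2 = 0.0
    for k in range(1, N + 1):
        p = (k / N) ** b - ((k - 1) / N) ** b
        e_k += k * p
        e_k2 += k * k * p
    e_sum = b * (N + 1) / 2  # E[S_1 + ... + S_b]
    naive = b * h * (34 * e_k + 5 * a * e_k2 / h)
    flash = 34 * b * h * e_k + (h + 2 * a) * e_sum
    padding_free = (35 * h + 2 * a) * e_sum
    gib = 2**30
    return naive / gib, flash / gib, padding_free / gib


# 20B model from the table; b = 8 is an assumed batch size (not stated in the table).
for N in (512, 1024, 2048, 4096, 8192, 16384, 32768):
    naive, flash, free = expected_layer_memory_gib(N, b=8, h=6144, a=48)
    print(f"{N:>6} {naive:10.3f} {flash:8.3f} {free:8.3f}  saving vs flash: {1 - free / flash:.0%}")
```

Varying `b` shows how the gap depends on batch size: a larger batch makes $\max\{S_i\}$ more likely to be close to $N$, so a larger share of the padded tensor is padding.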

Conclusion

In this blog, we presented a way to completely avoid the computation and memory spent on pad tokens when finetuning transformer models with FlashAttention. Our changes are easy to integrate into the HuggingFace transformers ecosystem for finetuning. We also derived equations for the theoretical activation memory consumption of each approach. The method doesn't involve writing any low-level device code; the only non-native PyTorch code we use is FlashAttention, which is already available.

References

  1. Dao, Tri, et al. "FlashAttention: Fast and memory-efficient exact attention with IO-awareness." Advances in Neural Information Processing Systems 35 (2022): 16344-16359.
  2. Dao, Tri. "FlashAttention-2: Faster attention with better parallelism and work partitioning." arXiv preprint arXiv:2307.08691 (2023).
  3. Milakov, Maxim, and Natalia Gimelshein. "Online normalizer calculation for softmax." arXiv preprint arXiv:1805.02867 (2018).
  4. Ivanov, Andrei, et al. "Data movement is all you need: A case study on optimizing transformers." Proceedings of Machine Learning and Systems 3 (2021): 711-732.
  5. Korthikanti, Vijay Anand, et al. "Reducing activation recomputation in large transformer models." Proceedings of Machine Learning and Systems 5 (2023).
  6. Vaswani, Ashish, et al. "Attention is all you need." Advances in Neural Information Processing Systems 30 (2017).
  7. Shazeer, Noam. "Fast transformer decoding: One write-head is all you need." arXiv preprint arXiv:1911.02150 (2019).
  8. Ainslie, Joshua, et al. "GQA: Training generalized multi-query transformer models from multi-head checkpoints." arXiv preprint arXiv:2305.13245 (2023).
  9. Knuth, Donald E. "Johann Faulhaber and sums of powers." Mathematics of Computation 61.203 (1993): 277-294.