Reducing Memory Usage


Section under construction. Feel free to contribute!

Truncation

Sequence lengths in the dataset can vary widely, and by default, TRL does not modify the data. When data is batched, sequences are padded to match the longest one in the batch, which can cause high memory usage, even if most sequences are relatively short.

[Figure: truncation of prompt and completion]
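
To see this effect, here is a minimal sketch of batch padding. The gpt2 tokenizer is an arbitrary choice, used purely for illustration:

from transformers import AutoTokenizer

# gpt2 is an arbitrary choice for demonstration; it has no pad token by default.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# One short and one long sequence: the whole batch is padded to the longest one.
batch = tokenizer(
    ["Short prompt.", "A much longer prompt " * 50],
    padding=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # both rows share the length of the longest sequence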

To reduce memory usage, it’s important to truncate sequences to a reasonable length. Even trimming a few tokens from the longest sequences can yield significant memory savings by minimizing unnecessary padding. Truncation should always be applied; the limit doesn’t need to be overly restrictive, but it should be chosen based on your dataset’s actual length distribution.
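
One way to choose a sensible limit is to inspect the token-length distribution of your data. The sketch below is illustrative: the tokenizer name and texts are placeholders for your own model and dataset column.

import numpy as np
from transformers import AutoTokenizer

# Hypothetical tokenizer and texts; substitute your own model and dataset.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
texts = ["First example.", "A second, somewhat longer example sentence."]

lengths = [len(tokenizer(text)["input_ids"]) for text in texts]
# A limit near the 95th-99th percentile keeps most data while capping outliers.
print(np.percentile(lengths, [50, 95, 99]))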

DPO

In DPO, truncation is first applied to the prompt and the completion via the max_prompt_length and max_completion_length parameters. The max_length parameter is then used to truncate the resulting prompt-completion sequence.

[Figure: truncation of prompt and completion]

To set the truncation parameters, use the following code snippet:

from trl import DPOConfig

training_args = DPOConfig(..., max_prompt_length=..., max_completion_length=..., max_length=...)
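
For example, a configuration with illustrative limits (the specific values and output path here are assumptions; derive the limits from your own length statistics):

from trl import DPOConfig

# Illustrative values only; choose limits from your dataset's length distribution.
training_args = DPOConfig(
    output_dir="./dpo-output",   # hypothetical output path
    max_prompt_length=512,       # truncate prompts to 512 tokens
    max_completion_length=512,   # truncate completions to 512 tokens
    max_length=1024,             # then truncate the full sequence to 1024 tokens
)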