AWS Trainium & Inferentia documentation

Distributed Training with optimum-neuron

AWS Trainium instances are well suited to training models. An instance can contain up to 16 Neuron devices; each device has 2 Neuron cores and 32GB of memory (16GB per core). For example, a trn1.32xlarge instance has 16 devices, hence 32 cores x 16GB = 512GB of memory.

But there is a caveat: by default, each Neuron core is an independent data-parallel worker. This means that the model, the gradient state and the optimizer state, which together amount to approximately 4 times the model size, must fit in a single Neuron core (16GB) for training to be possible. Even when they do, the activations must also fit in the memory that remains.
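
As a rough illustration of this arithmetic, the sketch below estimates the per-core footprint with the 4x rule of thumb and compares it against the 16GB available on a Neuron core. The function name, the bf16 assumption (2 bytes per parameter) and the parameter counts are illustrative assumptions, and activations are ignored, so the result is a lower bound rather than a guarantee.

```python
GB = 1e9  # decimal gigabytes, to keep the arithmetic simple
CORE_MEMORY_GB = 16  # memory per Neuron core, as stated above


def training_footprint_gb(num_params, bytes_per_param=2):
    """Rough size of model + gradients + optimizer state (about 4x the model size).

    bytes_per_param=2 assumes bf16 weights; this is an assumption, not a rule.
    """
    model_size_gb = num_params * bytes_per_param / GB
    return 4 * model_size_gb


# Hypothetical parameter counts, used only to exercise the estimate.
for name, num_params in [("1.1B-parameter model", 1.1e9), ("7B-parameter model", 7e9)]:
    footprint = training_footprint_gb(num_params)
    fits = footprint <= CORE_MEMORY_GB
    print(f"{name}: ~{footprint:.1f}GB per core, fits without sharding: {fits}")
```

Under these assumptions the smaller model leaves some room for activations on a single core, while the larger one does not fit at all without sharding.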

To alleviate that, optimum-neuron supports parallelism features enabling you to harness the full power of your Trainium instance:

  1. ZeRO-1: an optimization of data parallelism that shards the optimizer state (which usually represents half of the memory needed on the device) across the data-parallel ranks.
  2. Tensor Parallelism: a technique that shards each of your model parameters along a given dimension across multiple devices. The number of devices to shard your parameters over is called the tensor_parallel_size.
  3. Pipeline Parallelism: coming soon!
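
To make the idea of sharding a parameter along a given dimension more concrete, here is a toy sketch in plain Python. It splits the columns of a small weight matrix into tensor_parallel_size shards, computes each partial output separately, and checks that gathering the partial outputs gives the same result as the unsharded product. The helper names are hypothetical and this is not how optimum-neuron implements Tensor Parallelism; it only illustrates the principle.

```python
def matmul(x, w):
    """Multiply a vector x (length n) by a matrix w (n rows, m columns)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def shard_columns(w, tensor_parallel_size):
    """Split the columns of w into equally sized contiguous shards, one per rank."""
    num_cols = len(w[0])
    assert num_cols % tensor_parallel_size == 0, "columns must divide evenly"
    width = num_cols // tensor_parallel_size
    return [
        [row[rank * width:(rank + 1) * width] for row in w]
        for rank in range(tensor_parallel_size)
    ]


x = [1.0, 2.0]
w = [
    [1.0, 2.0, 3.0, 4.0],
    [5.0, 6.0, 7.0, 8.0],
]

# Each "rank" holds only its own shard and computes a slice of the output.
shards = shard_columns(w, tensor_parallel_size=2)
partial_outputs = [matmul(x, shard) for shard in shards]

# Concatenating the slices (an all-gather in a real setup) recovers the full output.
gathered = [value for part in partial_outputs for value in part]
assert gathered == matmul(x, w)
```

Each rank stores half of the weights here, which is where the memory saving comes from; the price is the communication needed to gather the partial outputs.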

The good news is that it is possible to combine those techniques, and optimum-neuron makes it very easy!

All the example scripts provided in the optimum-neuron repo have those features implemented via the NeuronTrainer.

How to enable ZeRO-1?

Whether you use the NeuronTrainer or decide to have your own training script that uses the NeuronAccelerator, it is very easy to enable the ZeRO-1 optimization.

Via the NeuronTrainer

from optimum.neuron import NeuronTrainingArguments, NeuronTrainer

# To enable ZeRO-1, set the `zero_1` argument to `True` in the training arguments.
training_args = NeuronTrainingArguments(
    ...,
    zero_1=True,
)

trainer = NeuronTrainer(
    model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()

Since the example scripts use the NeuronTrainer, you can enable ZeRO-1 when using them by adding the --zero_1 flag to your command line.

For example:

torchrun --nproc_per_node=2 examples/language-modeling/run_clm.py \
    --model_name_or_path TinyLlama/TinyLlama-1.1B-Chat-v0.6 \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --per_device_train_batch_size 1 \
    --block_size 1024 \
    --bf16 \
    --zero_1 \
    --output_dir my_training/

Via the NeuronAccelerator

There is a little bit more work to do when not using the NeuronTrainer:

  1. (Optional) Wrap the optimizer class to make it lazy. When ZeRO-1 is enabled the original optimizer is overridden to use a sharded version of it. Hence, it is possible to load the original optimizer lazily so that the optimizer state is not materialized until it is actually sharded.
from torch.optim import AdamW
from optimum.neuron.distributed import make_optimizer_constructor_lazy

lazy_adamw = make_optimizer_constructor_lazy(AdamW)
  2. Set the zero_1 argument to True when instantiating the NeuronAccelerator.
from optimum.neuron import NeuronAccelerator

accelerator = NeuronAccelerator(
    ...,
    zero_1=True,
)

model = ...
lazy_optimizer = lazy_adamw(...) # Actually instantiate the optimizer.

model, optimizer = accelerator.prepare(model, lazy_optimizer)

How to enable Tensor Parallelism?

Just as for ZeRO-1, it is possible to apply Tensor Parallelism either with the NeuronTrainer or the NeuronAccelerator.

When doing Tensor Parallelism, you have different settings:

  1. The tensor_parallel_size. Ideally it should be the smallest value for which the model fits.
  2. Whether or not sequence parallelism should be enabled. Sequence parallelism shards the activations on the sequence axis outside of the tensor parallel regions. It is useful because it saves memory by sharding the activations.
  3. Whether or not parallelization of the embedding layer should be done. By default it is done because it offers multiple benefits:
  • Parallelizing the embedding layer saves memory, which can enable fitting a bigger batch size and/or sequence length.
  • For language models, where the embedding layer weights and the language-modeling head weights are usually tied, the language-modeling head ends up parallelized as well and does not need to all-gather its output, since that output is fed to a cross-entropy loss compatible with parallelism. This saves expensive communication.
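
The advice on choosing the tensor_parallel_size can be sketched as a small search. The helper below is hypothetical and rests on loud assumptions: the roughly 4x footprint (model, gradients, optimizer state) is split evenly across the tensor-parallel ranks, the size must divide the number of cores, and activations are ignored. In practice other constraints may apply (for example divisibility of the number of attention heads), so treat the result as a starting point to validate experimentally.

```python
CORE_MEMORY_GB = 16  # memory per Neuron core, as stated earlier in this guide


def smallest_tensor_parallel_size(model_size_gb, num_cores=32, budget_gb=CORE_MEMORY_GB):
    """Return the smallest divisor of num_cores whose per-core footprint fits the budget.

    Assumes 4 * model_size_gb is sharded evenly over the tensor-parallel ranks.
    Returns None when the model does not fit even when sharded over every core.
    """
    for tp_size in range(1, num_cores + 1):
        if num_cores % tp_size != 0:
            continue  # keep only sizes that evenly divide the number of workers
        if 4 * model_size_gb / tp_size <= budget_gb:
            return tp_size
    return None


# Hypothetical model sizes in GB (weights only).
for model_size_gb in (2.2, 14.0, 140.0):
    print(model_size_gb, "->", smallest_tensor_parallel_size(model_size_gb))
```

Passing a budget_gb lower than 16 is a simple way to reserve room for activations, which this estimate does not account for.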

On top of that, it is very important to make sure that the original model is loaded in an efficient manner: the training script is going to be called by torchrun, which will dispatch it to workers, one worker per core. If each worker (there are 32 of them in a trn1.32xlarge instance) loads the full model weights, loading can take a lot of time and the instance can run out of memory very quickly.

To prevent that, optimum-neuron provides a context manager, distributed.lazy_load_for_parallelism(), that loads the model lazily: only the parameters of the corresponding model shard are materialized in each worker.

Via the NeuronTrainer

from optimum.neuron import NeuronTrainingArguments, NeuronTrainer
from optimum.neuron.distributed import lazy_load_for_parallelism

# Specify the `tensor_parallel_size` in the training arguments.
training_args = NeuronTrainingArguments(
    ...,
    tensor_parallel_size=8,
    disable_embedding_parallelization=False, # It is `False` by default.
    disable_sequence_parallel=False, # It is `False` by default.
)

with lazy_load_for_parallelism(tensor_parallel_size=training_args.tensor_parallel_size):
    model = ...


trainer = NeuronTrainer(
    model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()

Since the example scripts use the NeuronTrainer, you can enable Tensor Parallelism when using them by adding the --tensor_parallel_size argument, and optionally the --disable_embedding_parallelization and --disable_sequence_parallel flags, to your command line.

For example:

torchrun --nproc_per_node=2 examples/language-modeling/run_clm.py \
    --model_name_or_path TinyLlama/TinyLlama-1.1B-Chat-v0.6 \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --per_device_train_batch_size 1 \
    --block_size 1024 \
    --bf16 \
    --tensor_parallel_size 2 \
    --output_dir my_training/

Via the NeuronAccelerator

Just as for ZeRO-1, it is possible to wrap the optimizer class to make it lazy. Since the model parameters are going to be sharded, it is not needed to materialize the optimizer state prior to model parallelization: the wrapper makes sure that it stays unmaterialized.

from torch.optim import AdamW
from optimum.neuron import NeuronAccelerator
from optimum.neuron.accelerate.utils import ModelParallelismPlugin
from optimum.neuron.distributed import lazy_load_for_parallelism, make_optimizer_constructor_lazy

tensor_parallel_size = 8
mp_plugin = ModelParallelismPlugin(
    tensor_parallel_size,
    parallelize_embeddings=True,
    sequence_parallel_enabled=True,
    checkpoint_dir=None, # Can be specified when resuming from checkpoint.
)

accelerator = NeuronAccelerator(
    ...,
    mp_plugin=mp_plugin,
)

with lazy_load_for_parallelism(tensor_parallel_size=tensor_parallel_size):
    model = ...

lazy_adamw = make_optimizer_constructor_lazy(AdamW)
lazy_optimizer = lazy_adamw(...) # Actually instantiate the optimizer.

model, optimizer = accelerator.prepare(model, lazy_optimizer)

Checkpoint consolidation

Since Tensor Parallelism shards the model weights across different workers, only sharded checkpoints are saved during training. These sharded checkpoints must be consolidated before they can be shared or used outside of the specific training configuration they were created under.

The Optimum CLI provides a way of doing that very easily via the optimum-cli neuron consolidate command:

optimum-cli neuron consolidate --help

usage: optimum-cli neuron consolidate [-h] [-f {pytorch,safetensors}] checkpoint_dir output_dir

positional arguments:
  checkpoint_dir        The path to the directory containing the checkpoints.
  output_dir            The path to the output directory containing the consolidated checkpoint.

optional arguments:
  -h, --help            show this help message and exit
  -f {pytorch,safetensors}, --format {pytorch,safetensors}
                        The format used to save the consolidated checkpoint.

All you need to do is specify the sharded checkpoints directory and the output directory that will contain the consolidated checkpoints, and the command takes care of the rest. It is also possible to specify the output format of the consolidated checkpoints. By default they are exported to the safetensors format, which is the recommended format.
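
If consolidation needs to be scripted, for instance at the end of a training job, the command can be assembled from Python. The sketch below only builds the argument list from the usage shown above; the helper name and the directory names are hypothetical placeholders.

```python
import shlex


def build_consolidate_command(checkpoint_dir, output_dir, fmt="safetensors"):
    """Build the argument list for `optimum-cli neuron consolidate`."""
    if fmt not in ("pytorch", "safetensors"):
        raise ValueError(f"unsupported format: {fmt!r}")
    return ["optimum-cli", "neuron", "consolidate", "--format", fmt, checkpoint_dir, output_dir]


# Placeholder paths, to be replaced with your own directories.
command = build_consolidate_command("path/to/sharded_checkpoints", "path/to/consolidated")
print(shlex.join(command))

# To actually run it on a machine where optimum-neuron is installed:
# import subprocess
# subprocess.run(command, check=True)
```

Passing the arguments as a list, rather than as a single shell string, avoids quoting issues when the paths contain spaces.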

Example:

Training with Tensor Parallelism just completed and the output dir is called my_training. The directory looks like the following:

my_training/
├── README.md
├── all_results.json
├── checkpoint-10
│   ├── config.json
│   ├── scheduler.pt
│   ├── special_tokens_map.json
│   ├── tensor_parallel_shards
│   ├── tokenizer.json
│   ├── tokenizer.model
│   ├── tokenizer_config.json
│   ├── trainer_state.json
│   └── training_args.bin
├── config.json
├── special_tokens_map.json
├── tensor_parallel_shards
│   ├── tp_rank_00_pp_rank_00
│   ├── tp_rank_01_pp_rank_00
│   ├── tp_rank_02_pp_rank_00
│   ├── tp_rank_03_pp_rank_00
│   ├── tp_rank_04_pp_rank_00
│   ├── tp_rank_05_pp_rank_00
│   ├── tp_rank_06_pp_rank_00
│   └── tp_rank_07_pp_rank_00
├── tokenizer.json
├── tokenizer.model
├── tokenizer_config.json
├── train_results.json
├── trainer_state.json
└── training_args.bin

It is possible to consolidate the sharded checkpoints in my_training/tensor_parallel_shards, which correspond to the sharded checkpoints saved at the end of the training, by running the following command:

optimum-cli neuron consolidate my_training my_training_consolidated_checkpoint

The sharded checkpoints are saved under a directory called tensor_parallel_shards. The optimum-cli neuron consolidate command accepts as input either a directory that contains a tensor_parallel_shards directory, or the tensor_parallel_shards directory itself.
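
The two accepted input forms can be mirrored in your own tooling with a small path helper. The function below is a hypothetical sketch of that behaviour, not the CLI's actual implementation; it exercises both forms against a temporary directory laid out like the example above.

```python
import tempfile
from pathlib import Path

SHARDS_DIR_NAME = "tensor_parallel_shards"


def resolve_shards_dir(checkpoint_dir):
    """Return the shards directory, given either it or a directory containing it."""
    path = Path(checkpoint_dir)
    if path.name == SHARDS_DIR_NAME:
        return path
    candidate = path / SHARDS_DIR_NAME
    if candidate.is_dir():
        return candidate
    raise FileNotFoundError(f"no {SHARDS_DIR_NAME} directory found in {path}")


with tempfile.TemporaryDirectory() as tmp:
    # Recreate the relevant part of the example layout.
    shards = Path(tmp) / "my_training" / SHARDS_DIR_NAME
    shards.mkdir(parents=True)

    from_parent = resolve_shards_dir(shards.parent)  # directory containing the shards
    from_shards = resolve_shards_dir(shards)  # the shards directory itself
    same_dir = from_parent == from_shards

print(same_dir)
```

Both calls resolve to the same directory, so downstream code only has to handle one case.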