Training MoE on AWS Trainium

Community Article Published May 23, 2024

Introduction

This article explains how to train a mixture of experts (MoE) model using AWS Trainium.

The code we will use is based on that used to develop KARAKURI LM 8x7B Chat v0.1, the world's first MoE model trained on Trainium.

What is Trainium?

AWS Trainium is an accelerator developed by AWS specifically for training machine learning models. According to AWS, Trainium can reduce training costs by up to 50% compared to comparable GPU-based EC2 instances.

For more details about Trainium, please refer to this blog post by AWS.

What is MoE?

In the context of Transformer models, MoE is a technique that splits the feedforward layer into multiple independent units (experts) and, for each token, activates only a small subset of them. Because only the selected experts run in a given forward pass, the compute cost grows much more slowly than the total parameter count, which is what normally drives up computational and memory requirements as models scale.
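To make this concrete, below is a minimal, self-contained sketch of a sparse MoE feed-forward block with top-2 routing, written in PyTorch. It is purely illustrative: the class and parameter names (SimpleMoE, n_experts, top_k, and so on) are placeholders and are not taken from neuronx-nemo-megatron or the Mixtral implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleMoE(nn.Module):
    """Illustrative sparse MoE feed-forward block with top-k routing."""

    def __init__(self, d_model=512, d_ff=2048, n_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        # The router scores each token against every expert.
        self.router = nn.Linear(d_model, n_experts, bias=False)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model))
            for _ in range(n_experts)
        ])

    def forward(self, x):
        # x: (num_tokens, d_model); in a real model, flatten (batch, seq) first.
        logits = self.router(x)                              # (tokens, n_experts)
        weights, indices = torch.topk(logits, self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)                 # normalize over the chosen experts
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            token_idx, slot_idx = (indices == e).nonzero(as_tuple=True)
            if token_idx.numel() > 0:
                # Only the tokens routed to expert e are processed by it.
                out[token_idx] += weights[token_idx, slot_idx].unsqueeze(-1) * expert(x[token_idx])
        return out

Each token is processed by only top_k of the n_experts feed-forward blocks, so the per-token compute stays close to that of a dense model with top_k experts' worth of parameters, even though the layer holds n_experts of them.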

For more details on MoE, please refer to this blog post by Hugging Face.

Distributed Training Library

On AWS Trainium, a commonly used distributed training library is the AWS Neuron Reference for NeMo Megatron (neuronx-nemo-megatron).

The library used in this article is a modified version of neuronx-nemo-megatron. Specifically, it adds an implementation of the sparse MoE layers used in the Mixtral models. It has also been adapted to support Hugging Face Datasets, making it easier to use datasets from the Hugging Face Hub.


1. Setting Up the Environment

First, we will set up the AWS infrastructure. Configure the VPC and create the ParallelCluster by following the AWS documentation.

2. Installing Required Tools

Next, we will install the required tools by following the official AWS tutorial.

Connecting to the Head Node

First, connect to the head node via SSH:

ssh -i YOUR_KEY.pem ubuntu@HEAD_NODE_IP_ADDRESS

Activating the Virtual Environment

Activate the virtual environment:

cd ~
source ./aws_neuron_venv_pytorch/bin/activate

Cloning the Repository

Clone the repository:

cd ~
git clone https://github.com/karakuri-ai/neuronx-nemo-megatron.git
cd neuronx-nemo-megatron

Note that while the AWS tutorial uses aws-neuron/neuronx-nemo-megatron, we will use karakuri-ai/neuronx-nemo-megatron.

Building the Package and Installing Dependencies

Build the package and install the dependencies:

pip install wheel
./build.sh

pip install ./build/*.whl
pip install -r requirements.txt protobuf==3.20.3

cd ~
python -c "from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import compile_helper; compile_helper()"

3. Preparing the Dataset

Download and preprocess the dataset:

cd ~/neuronx-nemo-megatron/nemo/examples/nlp/language_modeling
python create_mixtral_sft_dataset.py

Here we use the No Robots dataset. If you wish to use a different dataset, modify the create_mixtral_sft_dataset.py script.
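As a rough illustration of the kind of change involved, the snippet below loads a chat-style dataset from the Hugging Face Hub with the datasets library and dumps one split to JSONL. The output file name and the JSONL format are assumptions made for the example; the format the preprocessing actually expects should be taken from create_mixtral_sft_dataset.py itself.

from datasets import load_dataset

# Placeholder: any chat-style dataset on the Hub can be substituted here.
dataset = load_dataset("HuggingFaceH4/no_robots")

# Inspect the available splits and columns before adapting the preprocessing script.
print(dataset)

# Illustrative only: dump the first split to JSONL; check
# create_mixtral_sft_dataset.py for the format it actually expects.
split_name = list(dataset.keys())[0]
dataset[split_name].to_json("custom_sft.jsonl", orient="records", lines=True)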

4. Training the Model

Converting Checkpoints (HF -> NeMo)

Convert the checkpoints from Hugging Face format to NeMo format:

cd ~/neuronx-nemo-megatron/nemo/examples/nlp/language_modeling/checkpoint_conversion
python convert_hf_checkpoint_to_nemo_mixtral.py \
  --path_to_checkpoint /path/to/hf_checkpoint \
  --config_file /path/to/hf_checkpoint/config.json \
  --model_bin_file /path/to/hf_checkpoint/pytorch_model.bin.index.json \
  --output_path /path/to/nemo_checkpoint \
  --tp_degree 8 \
  --pp_degree 8 \
  --save_bf16 True \
  --num_shards 19

If you run this conversion outside the ParallelCluster environment, upload the converted checkpoint to the S3 bucket linked to FSx so that the cluster can access it.
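If you prefer to script the upload, a minimal boto3 sketch could look like this; the bucket name, key prefix, and local directory are placeholders for your own FSx-linked bucket and checkpoint path (the AWS CLI's aws s3 sync command works just as well).

import os
import boto3

s3 = boto3.client("s3")

local_dir = "/path/to/nemo_checkpoint"    # converted checkpoint directory (placeholder)
bucket = "your-fsx-linked-bucket"         # placeholder: the S3 bucket linked to FSx
prefix = "checkpoints/mixtral_8x7b_nemo"  # placeholder key prefix

# Walk the checkpoint directory and upload every file, preserving relative paths.
for root, _, files in os.walk(local_dir):
    for name in files:
        local_path = os.path.join(root, name)
        key = f"{prefix}/{os.path.relpath(local_path, local_dir)}"
        s3.upload_file(local_path, bucket, key)
        print(f"uploaded s3://{bucket}/{key}")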

Editing Paths

Edit mixtral_8x7b.sh to specify the paths to the dataset and checkpoints.

Pre-compiling the Model

To pre-compile the model, run the following command:

sbatch --nodes 2 compile.slurm ./mixtral_8x7b.sh

Starting the Training

Start the training process:

sbatch --nodes 2 run.slurm ./mixtral_8x7b.sh

Converting Checkpoints (NeMo -> HF)

Once training is complete, convert the checkpoints back to Hugging Face format:

cd ~/neuronx-nemo-megatron/nemo/examples/nlp/language_modeling/checkpoint_conversion
python convert_nemo_checkpoint_to_hf_mixtral.py \
  --path_to_checkpoints /path/to/nemo_checkpoint \
  --config_file /path/to/hf_config_file \
  --output_path /path/to/hf_checkpoint \
  --is_xser True \
  --dtype bfloat16

5. Inference

Inference can be performed using GPUs or AWS Inferentia2. For implementation details on Inferentia2-based inference, please refer to the sample code provided by AWS.
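For orientation, here is a rough sketch of Inferentia2 inference using Hugging Face Optimum Neuron rather than the AWS sample code mentioned above. It assumes an inf2 instance with optimum-neuron installed and a version that supports Mixtral-style models; the model ID, batch size, sequence length, core count, and cast type are illustrative values rather than recommendations.

from optimum.neuron import NeuronModelForCausalLM
from transformers import AutoTokenizer

model_id = "karakuri-ai/karakuri-lm-8x7b-chat-v0.1"  # illustrative; replace with the actual Hub repository

# export=True compiles the model for Neuron cores on first load; the settings
# below (batch size, sequence length, cores, cast type) are placeholders.
model = NeuronModelForCausalLM.from_pretrained(
    model_id,
    export=True,
    batch_size=1,
    sequence_length=2048,
    num_cores=8,
    auto_cast_type="bf16",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello, how can I help you today?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))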


Conclusion

KARAKURI LM 8x7B Chat v0.1, the first MoE model trained on Trainium, is now available on the Hugging Face Hub. At the time of its release, the model achieved top-tier performance among open models on MT-Bench-jp, a benchmark for evaluating Japanese multi-turn conversation capabilities.

Additionally, we have released a demo showcasing inference using Inferentia2.