PyTorch training on Apple silicon
Previously, training models on a Mac was limited to the CPU. Starting with PyTorch v1.12, you can train models on Apple silicon GPUs for significantly faster training. PyTorch supports this by integrating Apple's Metal Performance Shaders (MPS) as a backend. The MPS backend implements PyTorch operations as custom Metal shaders and places these modules on an mps device.
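A minimal sketch of selecting the mps device in plain PyTorch. It uses torch.backends.mps.is_available() and falls back to the CPU when no MPS-capable GPU is present, so it also runs on other machines. The pick_device helper name is illustrative only and is not part of PyTorch.

```python
import torch


def pick_device() -> torch.device:
    # Hypothetical helper: prefer the MPS backend when this PyTorch build
    # supports it and the machine exposes a usable GPU; otherwise use the CPU.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()

# Tensors are created directly on the selected device, and modules are moved
# there with .to(device), so the forward pass runs on that device.
x = torch.ones(3, device=device)
model = torch.nn.Linear(3, 1).to(device)
y = model(x)
print(device, y.device)
```

On an Apple silicon Mac this is expected to select mps; elsewhere it quietly uses the CPU.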
Some PyTorch operations are not implemented in MPS yet and will throw an error. To avoid this, set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU kernels for those operations instead (you'll still see a UserWarning).
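If you launch training from Python rather than a shell, the same variable can be set through os.environ. This is a sketch under one assumption: to be safe, the variable is set before torch is first imported, since PyTorch may read it only once when it loads.

```python
import os

# Ask PyTorch to run operations that lack an MPS kernel on the CPU
# (emitting a UserWarning) instead of raising an error.
# Set this before the first `import torch` in the process, to be safe.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  (imported after the variable is set on purpose)

print(os.environ["PYTORCH_ENABLE_MPS_FALLBACK"], torch.__version__)
```

From a shell, running export PYTORCH_ENABLE_MPS_FALLBACK=1 before starting Python has the same effect.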
If you run into any other errors, please open an issue in the PyTorch repository, because the Trainer only integrates the MPS backend and does not implement the MPS operations itself.
With the mps device set, you can:
- train larger networks or batch sizes locally
- reduce data retrieval latency because the GPU’s unified memory architecture allows direct access to the full memory store
- reduce costs because you don’t need to train on cloud-based GPUs or add additional local GPUs
Get started by making sure you have PyTorch v1.12 or later installed. MPS acceleration requires macOS 12.3 or later.
pip install torch torchvision torchaudio
TrainingArguments uses the mps device by default if it's available, which means you don't need to explicitly set the device. For example, you can run the run_glue.py script with the MPS backend automatically enabled without making any changes.
export TASK_NAME=mrpc
python examples/pytorch/text-classification/run_glue.py \
--model_name_or_path google-bert/bert-base-cased \
--task_name $TASK_NAME \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--output_dir /tmp/$TASK_NAME/ \
--overwrite_output_dir
Backends for distributed setups like gloo and nccl are not supported by the mps device, which means you can only train on a single GPU with the MPS backend.
You can learn more about the MPS backend in the Introducing Accelerated PyTorch Training on Mac blog post.