DeepSpeed
DeepSpeed implements everything described in the ZeRO paper. Currently it provides full support for:
- Optimizer state partitioning (ZeRO stage 1)
- Gradient partitioning (ZeRO stage 2)
- Parameter partitioning (ZeRO stage 3)
- Custom mixed precision training handling
- A range of fast CUDA-extension-based optimizers
- ZeRO-Offload to CPU and Disk/NVMe
ZeRO-Offload has its own dedicated paper: ZeRO-Offload: Democratizing Billion-Scale Model Training. NVMe support is described in the paper ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning.
DeepSpeed ZeRO-2 is used primarily for training, as its features are of no use for inference.
DeepSpeed ZeRO-3 can be used for inference as well, since it allows huge models that would not fit on a single GPU to be loaded across multiple GPUs.
🤗 Accelerate integrates DeepSpeed via 2 options:
- Integration of the DeepSpeed features via a `deepspeed config file` specification in `accelerate config`. You just supply your custom config file or use our template. Most of this document is focused on this feature. This supports all the core features of DeepSpeed and gives the user a lot of flexibility. The user may have to change a few lines of code depending on the config.
- Integration via `deepspeed_plugin`. This supports a subset of the DeepSpeed features and uses default options for the rest of the configurations. The user need not change any code, which is good for those who are fine with most of the default settings of DeepSpeed.
What is integrated?
Training:
- DeepSpeed ZeRO training supports the full ZeRO stages 1, 2 and 3 as well as CPU/Disk offload of optimizer states, gradients and parameters. Below is a short description of data parallelism using ZeRO (Zero Redundancy Optimizer), along with a diagram from this blog post.
a. Stage 1: Shards optimizer states across data parallel workers/GPUs
b. Stage 2: Shards optimizer states + gradients across data parallel workers/GPUs
c. Stage 3: Shards optimizer states + gradients + model parameters across data parallel workers/GPUs
d. Optimizer Offload: Offloads the gradients + optimizer states to CPU/Disk, building on top of ZeRO Stage 2
e. Param Offload: Offloads the model parameters to CPU/Disk, building on top of ZeRO Stage 3
Note: With respect to Disk Offload, the disk should be an NVMe for decent speed, but it technically works on any disk.
Inference:
- DeepSpeed ZeRO Inference supports ZeRO stage 3 with ZeRO-Infinity. It uses the same ZeRO protocol as training, but it doesn’t use an optimizer or an LR scheduler, and only stage 3 is relevant. For more details see: deepspeed-zero-inference.
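To make the stage descriptions above more concrete, here is a rough estimate of per-GPU memory for model states, using the mixed-precision Adam accounting from the ZeRO paper: 2 bytes per parameter for fp16 weights, 2 for fp16 gradients and K = 12 for the fp32 master weights, momentum and variance. It is only a back-of-the-envelope sketch: activations, temporary buffers, fragmentation and CPU/NVMe offload are not modeled, and `zero_model_state_bytes` is a hypothetical helper, not part of DeepSpeed or Accelerate.

```python
"""Back-of-the-envelope ZeRO model-state memory estimate (hypothetical helper)."""

FP16_PARAM_BYTES = 2   # fp16 copy of each parameter
FP16_GRAD_BYTES = 2    # fp16 gradient of each parameter
ADAM_STATE_BYTES = 12  # K in the ZeRO paper: fp32 master weight + momentum + variance


def zero_model_state_bytes(num_params: int, num_gpus: int, stage: int) -> float:
    """Estimated bytes of model states held by each data-parallel GPU.

    stage 0: nothing sharded, stage 1: optimizer states sharded,
    stage 2: + gradients sharded, stage 3: + parameters sharded.
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    params = FP16_PARAM_BYTES * num_params
    grads = FP16_GRAD_BYTES * num_params
    optim = ADAM_STATE_BYTES * num_params
    if stage == 0:
        return params + grads + optim
    if stage == 1:
        return params + grads + optim / num_gpus
    if stage == 2:
        return params + (grads + optim) / num_gpus
    if stage == 3:
        return (params + grads + optim) / num_gpus
    raise ValueError(f"unsupported ZeRO stage: {stage}")


if __name__ == "__main__":
    # 7.5B parameters on 64 GPUs, the example setting used in the ZeRO paper
    for stage in range(4):
        gb = zero_model_state_bytes(7_500_000_000, 64, stage) / 1e9
        print(f"ZeRO stage {stage}: ~{gb:.1f} GB of model states per GPU")
```

For that setting the script prints roughly 120, 31.4, 16.6 and 1.9 GB, in line with the figures reported in the ZeRO paper. For ZeRO-3 inference, where there are no gradients or optimizer states, only the sharded fp16 weights (about 2 bytes per parameter divided by the number of GPUs) remain in this estimate.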
How it works?
Pre-Requisites: Install DeepSpeed version >=0.6.5. Please refer to the DeepSpeed Installation details for more information.
We will first look at the easy-to-use integration via `accelerate config`, followed by the more flexible and feature-rich `deepspeed config file` integration.
Accelerate DeepSpeed Plugin
On your machine(s) just run:
accelerate config
and answer the questions asked. It will ask whether you want to use a config file for DeepSpeed, to which you should answer no. Then answer the following questions to generate a basic DeepSpeed config. This will generate a config file that will be used automatically to properly set the default options when doing
accelerate launch my_script.py --args_to_my_script
For instance, here is how you would run the NLP example `examples/nlp_example.py` (from the root of the repo) with the DeepSpeed Plugin:
ZeRO Stage-2 DeepSpeed Plugin Example
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero_stage: 2
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
use_cpu: false
accelerate launch examples/nlp_example.py --mixed_precision fp16
ZeRO Stage-3 with CPU Offload DeepSpeed Plugin Example
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
use_cpu: false
accelerate launch examples/nlp_example.py --mixed_precision fp16
Currently, Accelerate supports the following config options through the CLI:
`zero_stage`: [0] Disabled, [1] optimizer state partitioning, [2] optimizer+gradient state partitioning and [3] optimizer+gradient+parameter partitioning
`gradient_accumulation_steps`: Number of training steps to accumulate gradients before averaging and applying them.
`gradient_clipping`: Enables gradient clipping with the given value.
`offload_optimizer_device`: [none] Disable optimizer offloading, [cpu] offload optimizer to CPU, [nvme] offload optimizer to NVMe SSD. Only applicable with ZeRO >= Stage-2.
`offload_param_device`: [none] Disable parameter offloading, [cpu] offload parameters to CPU, [nvme] offload parameters to NVMe SSD. Only applicable with ZeRO Stage-3.
`zero3_init_flag`: Decides whether to enable `deepspeed.zero.Init` for constructing massive models. Only applicable with ZeRO Stage-3.
`zero3_save_16bit_model`: Decides whether to save 16-bit model weights when using ZeRO Stage-3.
`mixed_precision`: `no` for FP32 training, `fp16` for FP16 mixed-precision training and `bf16` for BF16 mixed-precision training.
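The applicability rules in the list above can be captured in a small checker. This is a hypothetical sketch (`PluginOptions` and `check_plugin_options` are illustrative names, not Accelerate APIs): it reports invalid values and options that, per the list above, have no effect at the chosen stage, and it does not reproduce whatever validation Accelerate itself performs.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class PluginOptions:
    """Hypothetical container mirroring the CLI options listed above."""
    zero_stage: int = 2
    gradient_accumulation_steps: int = 1
    offload_optimizer_device: str = "none"
    offload_param_device: str = "none"
    zero3_init_flag: bool = False
    zero3_save_16bit_model: bool = False
    mixed_precision: str = "no"


def check_plugin_options(opts: PluginOptions) -> List[str]:
    """Return human-readable notes about invalid or ignored options."""
    notes: List[str] = []
    if opts.zero_stage not in (0, 1, 2, 3):
        notes.append(f"invalid zero_stage {opts.zero_stage}; expected 0-3")
    if opts.gradient_accumulation_steps < 1:
        notes.append("gradient_accumulation_steps must be >= 1")
    for name in ("offload_optimizer_device", "offload_param_device"):
        value = getattr(opts, name)
        if value not in ("none", "cpu", "nvme"):
            notes.append(f"invalid {name} {value!r}; expected none, cpu or nvme")
    if opts.mixed_precision not in ("no", "fp16", "bf16"):
        notes.append(f"invalid mixed_precision {opts.mixed_precision!r}")
    # Options that are only applicable at certain stages (per the list above).
    if opts.offload_optimizer_device != "none" and opts.zero_stage < 2:
        notes.append("offload_optimizer_device only applies with ZeRO stage >= 2")
    if opts.offload_param_device != "none" and opts.zero_stage != 3:
        notes.append("offload_param_device only applies with ZeRO stage 3")
    if opts.zero3_init_flag and opts.zero_stage != 3:
        notes.append("zero3_init_flag only applies with ZeRO stage 3")
    if opts.zero3_save_16bit_model and opts.zero_stage != 3:
        notes.append("zero3_save_16bit_model only applies with ZeRO stage 3")
    return notes


if __name__ == "__main__":
    # Values taken from the ZeRO Stage-2 plugin example above
    stage2 = PluginOptions(zero_stage=2, zero3_init_flag=True, mixed_precision="fp16")
    print(check_plugin_options(stage2))
```

Run on the Stage-2 plugin example above, the only note is that `zero3_init_flag` has no effect at stage 2; the Stage-3 CPU-offload example produces no notes.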
To be able to tweak more options, you will need to use a DeepSpeed config file.
DeepSpeed Config File
On your machine(s) just run:
accelerate config
and answer the questions asked. It will ask whether you want to use a config file for DeepSpeed, to which you answer yes and provide the path to the DeepSpeed config file. This will generate a config file that will be used automatically to properly set the default options when doing
accelerate launch my_script.py --args_to_my_script
For instance, here is how you would run the NLP example `examples/by_feature/deepspeed_with_config_support.py` (from the root of the repo) with a DeepSpeed Config File:
ZeRO Stage-2 DeepSpeed Config File Example
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: /home/ubuntu/accelerate/examples/configs/deepspeed_config_templates/zero_stage2_config.json
  zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
use_cpu: false
with the contents of `zero_stage2_config.json` being:
{
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto",
"torch_adam": true,
"adam_w_mode": true
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": "auto",
"contiguous_gradients": true
},
"gradient_accumulation_steps": 1,
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
accelerate launch examples/by_feature/deepspeed_with_config_support.py \
--config_name "gpt2-large" \
--tokenizer_name "gpt2-large" \
--dataset_name "wikitext" \
--dataset_config_name "wikitext-2-raw-v1" \
--block_size 128 \
--output_dir "./clm/clm_deepspeed_stage2_accelerate" \
--learning_rate 5e-4 \
--per_device_train_batch_size 24 \
--per_device_eval_batch_size 24 \
--num_train_epochs 3 \
--with_tracking \
--report_to "wandb"
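The `train_batch_size` and `train_micro_batch_size_per_gpu` entries in `zero_stage2_config.json` are left as `"auto"`. DeepSpeed requires the global batch size to equal the per-GPU micro-batch size times the gradient accumulation steps times the number of data-parallel processes. The sketch below shows that arithmetic for the launch command above, assuming the micro-batch comes from `--per_device_train_batch_size` and the world size from `num_processes: 2` in the accelerate config; `resolve_batch_sizes` is a hypothetical helper, not the code Accelerate runs.

```python
from typing import Dict


def resolve_batch_sizes(per_device_batch_size: int,
                        gradient_accumulation_steps: int,
                        world_size: int) -> Dict[str, int]:
    """Fill the two batch-size fields so they satisfy DeepSpeed's invariant:
    train_batch_size == micro_batch * gradient_accumulation_steps * world_size.
    """
    for name, value in (("per_device_batch_size", per_device_batch_size),
                        ("gradient_accumulation_steps", gradient_accumulation_steps),
                        ("world_size", world_size)):
        if value < 1:
            raise ValueError(f"{name} must be >= 1, got {value}")
    return {
        "train_micro_batch_size_per_gpu": per_device_batch_size,
        "train_batch_size": per_device_batch_size * gradient_accumulation_steps * world_size,
    }


if __name__ == "__main__":
    # Stage-2 example: --per_device_train_batch_size 24, no accumulation, 2 processes
    print(resolve_batch_sizes(24, 1, 2))
```

With 24 samples per device, no accumulation and 2 processes, the global batch is 48; the Stage-3 example below (32 per device) would give 64.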
ZeRO Stage-3 with CPU offload DeepSpeed Config File Example
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: /home/ubuntu/accelerate/examples/configs/deepspeed_config_templates/zero_stage3_offload_config.json
  zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
use_cpu: false
with the contents of `zero_stage3_offload_config.json` being:
{
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"sub_group_size": 1e9,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": "auto"
},
"gradient_accumulation_steps": 1,
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
accelerate launch examples/by_feature/deepspeed_with_config_support.py \
--config_name "gpt2-large" \
--tokenizer_name "gpt2-large" \
--dataset_name "wikitext" \
--dataset_config_name "wikitext-2-raw-v1" \
--block_size 128 \
--output_dir "./clm/clm_deepspeed_stage3_offload_accelerate" \
--learning_rate 5e-4 \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--num_train_epochs 3 \
--with_tracking \
--report_to "wandb"
Important code changes when using DeepSpeed Config File
DeepSpeed Optimizers and Schedulers. For more information on these, see the DeepSpeed Optimizers and DeepSpeed Schedulers documentation. We will look at the changes needed in the code when using these.
a. DS Optim + DS Scheduler: The case when both `optimizer` and `scheduler` keys are present in the DeepSpeed config file. In this situation, those will be used and the user has to use `accelerate.utils.DummyOptim` and `accelerate.utils.DummyScheduler` to replace the PyTorch/Custom optimizers and schedulers in their code. Below is the snippet from `examples/by_feature/deepspeed_with_config_support.py` showing this:
import torch
from accelerate.utils import DummyOptim, DummyScheduler
from transformers import get_scheduler

# Creates Dummy Optimizer if `optimizer` was specified in the config file else creates AdamW Optimizer
optimizer_cls = (
    torch.optim.AdamW
    if accelerator.state.deepspeed_plugin is None
    or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
    else DummyOptim
)
optimizer = optimizer_cls(optimizer_grouped_parameters, lr=args.learning_rate)

# Creates Dummy Scheduler if `scheduler` was specified in the config file else creates `args.lr_scheduler_type` Scheduler
if (
    accelerator.state.deepspeed_plugin is None
    or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
):
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )
else:
    lr_scheduler = DummyScheduler(
        optimizer, total_num_steps=args.max_train_steps, warmup_num_steps=args.num_warmup_steps
    )
b. Custom Optim + Custom Scheduler: The case when both `optimizer` and `scheduler` keys are absent in the DeepSpeed config file. In this situation, no code changes are needed from the user, and this is the case when using integration via DeepSpeed Plugin. In the above example we can see that the code remains unchanged if the `optimizer` and `scheduler` keys are absent in the DeepSpeed config file.
c. Custom Optim + DS Scheduler: The case when only the `scheduler` key is present in the DeepSpeed config file. In this situation, the user has to use `accelerate.utils.DummyScheduler` to replace the PyTorch/Custom scheduler in their code.
d. DS Optim + Custom Scheduler: The case when only the `optimizer` key is present in the DeepSpeed config file. This will result in an error because, when using a DS Optim, only a DS Scheduler is supported.
Notice the `auto` values in the above example DeepSpeed config files. These are automatically handled by the `prepare` method based on the model, dataloaders, dummy optimizer and dummy schedulers provided to the `prepare` method. Only the `auto` fields specified in the above examples are handled by the `prepare` method; the rest have to be explicitly specified by the user.
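The four cases above depend only on which of the two keys appear in the DeepSpeed config. Here is a minimal sketch of that decision, assuming the config has already been loaded into a Python dict; `choose_optimizer_scheduler_setup` is a hypothetical name, and the error raised for case d mirrors the behavior described above rather than Accelerate's exact message.

```python
from typing import Any, Dict, Tuple


def choose_optimizer_scheduler_setup(ds_config: Dict[str, Any]) -> Tuple[str, str]:
    """Return (optimizer_kind, scheduler_kind) for cases a-c; raise for case d.

    "dummy" means accelerate.utils.DummyOptim / DummyScheduler must be used in
    the training script; "custom" means a regular PyTorch/custom object.
    """
    has_optimizer = "optimizer" in ds_config
    has_scheduler = "scheduler" in ds_config
    if has_optimizer and has_scheduler:          # case a: DS Optim + DS Scheduler
        return "dummy", "dummy"
    if not has_optimizer and not has_scheduler:  # case b: Custom Optim + Custom Scheduler
        return "custom", "custom"
    if has_scheduler:                            # case c: Custom Optim + DS Scheduler
        return "custom", "dummy"
    # case d: DS Optim + Custom Scheduler is not supported
    raise ValueError(
        "DeepSpeed config defines `optimizer` but not `scheduler`; "
        "a custom scheduler cannot be combined with a DeepSpeed optimizer"
    )
```

Both example config files in this document define both keys, so they fall under case a.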
Saving and loading
Saving and loading of models is unchanged for ZeRO Stage-1 and Stage-2.
Under ZeRO Stage-3, `state_dict` contains just placeholders, since the model weights are partitioned across multiple GPUs. ZeRO Stage-3 has 2 options:
a. Saving the entire 16-bit model weights to load directly later on using `model.load_state_dict(torch.load("pytorch_model.bin"))`. For this, either set `zero_optimization.stage3_gather_16bit_weights_on_model_save` to True in the DeepSpeed config file or set `zero3_save_16bit_model` to True in the DeepSpeed Plugin. Note that this option requires consolidating the weights on one GPU, so it can be slow and memory demanding; only use this feature when needed. Below is the snippet from `examples/by_feature/deepspeed_with_config_support.py` showing this:
unwrapped_model = accelerator.unwrap_model(model)

# New Code #
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
# The model name saved is `pytorch_model.bin`
unwrapped_model.save_pretrained(
    args.output_dir,
    is_main_process=accelerator.is_main_process,
    save_function=accelerator.save,
    state_dict=accelerator.get_state_dict(model),
)
b. To get 32-bit weights, first save the model using `model.save_checkpoint()`. Below is the snippet from `examples/by_feature/deepspeed_with_config_support.py` showing this:
import logging

success = model.save_checkpoint(PATH, ckpt_id, checkpoint_state_dict)
status_msg = f"checkpointing: PATH={PATH}, ckpt_id={ckpt_id}"
if success:
    logging.info(f"Success {status_msg}")
else:
    logging.warning(f"Failure {status_msg}")
This will create ZeRO model and optimizer partitions along with a `zero_to_fp32.py` script in the checkpoint directory. One can use this script to do offline consolidation. It requires no configuration files or GPUs. Here is an example of its usage:
$ cd /path/to/checkpoint_dir
$ ./zero_to_fp32.py . pytorch_model.bin
Processing zero checkpoint at global_step1
Detected checkpoint of type zero stage 3, world_size: 2
Saving fp32 state dict to pytorch_model.bin (total_numel=60506624)
To get a 32-bit model for saving/inference, one can do the following:
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint

unwrapped_model = accelerator.unwrap_model(model)
fp32_model = load_state_dict_from_zero_checkpoint(unwrapped_model, checkpoint_dir)
If only interested in the state_dict, one can do the following:
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)
Note that all these functions require ~2x memory (general RAM) of the size of the final checkpoint.
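To illustrate why a ZeRO Stage-3 `state_dict` holds only placeholders and what offline consolidation has to do, here is a toy model of parameter partitioning: each parameter is padded so it divides evenly, every rank keeps one slice, and consolidation concatenates the slices and drops the padding. This is a conceptual sketch loosely modeled on ZeRO-3, with plain Python lists standing in for tensors; it is not DeepSpeed's checkpoint format (which stores per-rank files and fp32 optimizer partitions), and `partition`, `consolidate` and `consolidation_ram_bytes` are hypothetical helpers. The last one turns the ~2x RAM rule of thumb above into bytes, assuming 4-byte fp32 values.

```python
import math
from typing import Dict, List, Tuple

Params = Dict[str, List[float]]


def partition(params: Params, world_size: int) -> Tuple[List[Params], Dict[str, int]]:
    """Pad each parameter to a multiple of world_size and give each rank one slice."""
    shards: List[Params] = [{} for _ in range(world_size)]
    numels: Dict[str, int] = {}
    for name, values in params.items():
        numels[name] = len(values)
        shard_len = math.ceil(len(values) / world_size)
        padded = list(values) + [0.0] * (shard_len * world_size - len(values))
        for rank in range(world_size):
            shards[rank][name] = padded[rank * shard_len:(rank + 1) * shard_len]
    return shards, numels


def consolidate(shards: List[Params], numels: Dict[str, int]) -> Params:
    """Offline consolidation: join every rank's slice and strip the padding."""
    full: Params = {}
    for name, numel in numels.items():
        joined = [x for shard in shards for x in shard[name]]
        full[name] = joined[:numel]
    return full


def consolidation_ram_bytes(total_numel: int, bytes_per_elem: int = 4,
                            factor: float = 2.0) -> float:
    """Rough RAM needed: ~2x the size of the final fp32 checkpoint."""
    return factor * total_numel * bytes_per_elem


if __name__ == "__main__":
    model = {"weight": [1.0, 2.0, 3.0, 4.0, 5.0], "bias": [0.5, -0.5]}
    shards, numels = partition(model, world_size=2)
    print(shards)  # each rank holds only part of every parameter
    print(consolidate(shards, numels) == model)
    # total_numel taken from the zero_to_fp32.py output above
    print(consolidation_ram_bytes(60506624) / 1e6, "MB")
```

For the `total_numel=60506624` checkpoint shown above, the rule of thumb works out to roughly 484 MB of RAM.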
ZeRO Inference
DeepSpeed ZeRO Inference supports ZeRO stage 3 with ZeRO-Infinity. It uses the same ZeRO protocol as training, but it doesn't use an optimizer or an LR scheduler, and only stage 3 is relevant. With the Accelerate integration, one just has to prepare the model and dataloader as shown below:
model, eval_dataloader = accelerator.prepare(model, eval_dataloader)
Few caveats to be aware of
- Current integration doesn’t support Pipeline Parallelism of DeepSpeed.
- Current integration doesn’t support `mpu`, limiting the tensor parallelism which is supported in Megatron-LM.
- Current integration doesn’t support multiple models for a given `accelerator` object.
Internals
class accelerate.DeepSpeedPlugin
( hf_ds_config: typing.Any = None, gradient_accumulation_steps: int = None, gradient_clipping: float = None, zero_stage: int = None, is_train_batch_min: str = True, offload_optimizer_device: bool = None, offload_param_device: bool = None, zero3_init_flag: bool = None, zero3_save_16bit_model: bool = None )
This plugin is used to integrate DeepSpeed.
deepspeed_config_process
( prefix = '', mismatches = None, config = None, must_match = True, **kwargs )
Process the DeepSpeed config with the values from the kwargs.
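A simplified sketch of what processing a config "with the values from the kwargs" can look like: walk the nested config, address each leaf by a dotted path built from `prefix`, replace `"auto"` leaves with the matching value, and, when `must_match` is set, collect explicit values that disagree. `fill_auto_values` and the dotted-key convention are assumptions made for illustration; this is not Accelerate's implementation.

```python
from typing import Any, Dict, List, Optional


def fill_auto_values(config: Dict[str, Any], values: Dict[str, Any], prefix: str = "",
                     mismatches: Optional[List[str]] = None,
                     must_match: bool = True) -> List[str]:
    """Fill "auto" entries of a nested DeepSpeed-style config in place.

    `values` maps dotted paths (e.g. "optimizer.params.lr") to concrete values.
    Returns a list of mismatches between explicit config values and `values`.
    """
    if mismatches is None:
        mismatches = []
    for key, value in config.items():
        path = prefix + key
        if isinstance(value, dict):
            # Recurse into nested sections, extending the dotted prefix
            fill_auto_values(value, values, path + ".", mismatches, must_match)
        elif path in values:
            if value == "auto":
                config[key] = values[path]  # replacing a value does not resize the dict
            elif must_match and value != values[path]:
                mismatches.append(f"{path}: config has {value!r}, expected {values[path]!r}")
    return mismatches


if __name__ == "__main__":
    ds_config = {
        "optimizer": {"type": "AdamW", "params": {"lr": "auto", "weight_decay": "auto"}},
        "gradient_accumulation_steps": 1,
        "train_micro_batch_size_per_gpu": "auto",
    }
    problems = fill_auto_values(ds_config, {
        "optimizer.params.lr": 5e-4,
        "optimizer.params.weight_decay": 0.0,
        "gradient_accumulation_steps": 2,
        "train_micro_batch_size_per_gpu": 24,
    })
    print(ds_config)
    print(problems)
```

Raising an error whenever the returned list is non-empty would flag a config whose explicit values contradict the training script, such as a different number of gradient accumulation steps.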
class accelerate.utils.DummyOptim
( params, lr = 0.001, weight_decay = 0, **kwargs )
Dummy optimizer that presents the model parameters or param groups. It is primarily used to follow the conventional training loop when the optimizer config is specified in the DeepSpeed config file.
class accelerate.utils.DummyScheduler
( optimizer, total_num_steps = None, warmup_num_steps = 0, **kwargs )
Dummy scheduler that stands in for a real scheduler and holds the optimizer along with `total_num_steps` and `warmup_num_steps`. It is primarily used to follow the conventional training loop when the scheduler config is specified in the DeepSpeed config file.
class accelerate.utils.DeepSpeedEngineWrapper
( engine )
Internal wrapper for deepspeed.runtime.engine.DeepSpeedEngine. This is used to follow the conventional training loop.
class accelerate.utils.DeepSpeedOptimizerWrapper
( optimizer )
Internal wrapper around a deepspeed optimizer.
class accelerate.utils.DeepSpeedSchedulerWrapper
( scheduler, optimizers )
Internal wrapper around a deepspeed scheduler.
Main DeepSpeed Resources
Papers:
- ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
- ZeRO-Offload: Democratizing Billion-Scale Model Training
- ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning
Finally, please remember that 🤗 Accelerate only integrates DeepSpeed; therefore, if you have any problems or questions with regard to DeepSpeed usage, please file an issue on the DeepSpeed GitHub.