# NOTE(review): removed scrape artifacts ("Spaces:" / "Runtime error" header
# lines) that preceded the code and were not part of the module.
# Standard library
import json
import os

# Third-party
import torch
import torch.distributed as dist
def create_ds_config(args, config, cfgdir):
    """Write a per-rank DeepSpeed JSON config file.

    Builds a DeepSpeed configuration dict from the experiment ``config``
    (AdamW optimizer, fp16 vs fp32, optional gradient clipping, ZeRO
    stage 1/2) and writes it to ``<cfgdir>/deepspeed_config_<rank>.json``.
    The chosen path is stored back onto ``config.deepspeed_config``.

    Args:
        args: Unused; kept for call-site compatibility.
        config: Experiment configuration with attribute access
            (presumably an OmegaConf node -- TODO confirm) providing
            ``trainer``, ``data`` and ``model`` settings.
        cfgdir: Directory in which the config file is written.

    Raises:
        AssertionError: If ``config.trainer.optimizer`` is not ``adamw``.
    """
    # One file per rank so concurrent ranks never clobber each other's config.
    config.deepspeed_config = os.path.join(
        cfgdir, f"deepspeed_config_{dist.get_rank()}.json"
    )

    opt_lower = config.trainer.optimizer.lower()
    # NOTE(review): assert is stripped under `python -O`; kept as an assert
    # to preserve the AssertionError contract for existing callers.
    assert opt_lower == 'adamw', "deepspeed only support adamw"

    ds_config = {
        # Global batch = per-GPU micro batch * grad accumulation * world size.
        "train_batch_size": config.data.params.batch_size
        * config.trainer.accumulate_grad_batches
        * dist.get_world_size(),
        "train_micro_batch_size_per_gpu": config.data.params.batch_size,
        "steps_per_print": 10,
        "optimizer": {
            # DeepSpeed's Adam with adam_w_mode=True behaves as AdamW.
            "type": "Adam",
            "adam_w_mode": True,
            "params": {
                "lr": config.model.base_learning_rate,
                "weight_decay": config.model.weight_decay,
                "bias_correction": True,
                "betas": [0.9, 0.999],
                "eps": 1e-8,
            },
        },
    }

    if 'fp32' in config.model.params.deepspeed:
        ds_config["fp16"] = {"enabled": False}
    else:
        # loss_scale == 0 selects DeepSpeed's dynamic loss scaling.
        ds_config["fp16"] = {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": config.trainer.initial_scale,
            "loss_scale_window": 128,
        }

    if config.trainer.clip_grad > 0.0:
        ds_config["gradient_clipping"] = config.trainer.clip_grad

    # The deepspeed spec string ends in the ZeRO stage, e.g. "..._2".
    zero_opt = int(config.model.params.deepspeed.split('_')[-1])
    if zero_opt == 1:
        ds_config["zero_optimization"] = {"stage": zero_opt}
    elif zero_opt == 2:
        ds_config["zero_optimization"] = {
            "stage": 2,
            "offload_optimizer": {
                "device": "cpu",
            },
            "contiguous_gradients": True,
            "overlap_comm": True,
        }

    # Open the file only for the final write (the original opened it before
    # building the dict, which left a truncated/empty file if anything above
    # raised mid-build).
    with open(config.deepspeed_config, mode="w") as writer:
        writer.write(json.dumps(ds_config, indent=2))