#!/usr/bin/env python

# This script combines the 2 steps of
# 1. calling zero_to_fp32.py to reconsolidate the sharded deepspeed checkpoint
# 2. then resaving it as HF checkpoint, which also takes care of sharding large checkpoints
#
# example usage:
#
# this will generate the converted checkpoint under save_dir/opt_step-40/unwrapped_model
#
# ./m4/models/zero_checkpoint_to_hf.py save_dir/opt_step-40
#
# or you can override the destination by passing an explicit target dir, e.g.:
#
# ./m4/models/zero_checkpoint_to_hf.py save_dir/opt_step-40 save_dir/opt_step-40/output_dir

import argparse
import sys
from pathlib import Path

import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint


# auto-append the repo path to load m4 modules from instead of needing to set PYTHONPATH
repodir = str(Path(__file__).resolve().parents[2])
sys.path.insert(0, repodir)

# These imports must come after the sys.path tweak above so `m4` resolves from the repo.
import m4.models  # noqa: E402
from m4.testing_utils import read_json_file  # noqa: E402


def main():
    """Convert a DeepSpeed ZeRO checkpoint folder into a HF ``save_pretrained`` checkpoint.

    Reads the model config from ``<checkpoint_dir>/unwrapped_model/config.json``,
    reconsolidates the ZeRO shards found under ``<checkpoint_dir>/accelerator_state``
    into a single fp32 state dict, instantiates the matching m4 model in bf16 with
    those weights, and saves it with ``save_pretrained``.

    Raises:
        ValueError: if the config's ``model_type`` is missing, is not a supported
            m4 config type, or has no registered modeling class.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "checkpoint_dir", type=str, help="path to the desired checkpoint folder, e.g., path/to/opt_step-100"
    )
    parser.add_argument(
        "output_dir",
        type=str,
        nargs="?",
        help="path to pass to save_pretrained, defaults to 'unwrapped_model' relative to the checkpoint_dir argument",
    )
    args = parser.parse_args()

    checkpoint_dir = Path(args.checkpoint_dir)
    config_dir = checkpoint_dir / "unwrapped_model"
    ds_checkpoint_dir = checkpoint_dir / "accelerator_state"
    config_file_path = config_dir / "config.json"

    # Default destination is the same `unwrapped_model` folder the config is read from.
    output_dir = config_dir if args.output_dir is None else Path(args.output_dir)

    config = read_json_file(config_file_path)
    # A missing `model_type` falls through to the "isn't supported" error below
    # instead of surfacing as a bare KeyError.
    model_type = config.get("model_type")

    config_class = m4.models._SUPPORTED_MODELS.get(model_type, None)
    if config_class is None:
        raise ValueError(f"{model_type=} isn't supported by m4")

    modeling_class = m4.models.model_type_to_modeling_class.get(model_type, None)
    if modeling_class is None:
        # Without this check the failure would be an opaque AttributeError on
        # `None.from_pretrained` after the (slow) reconsolidation step.
        raise ValueError(f"{model_type=} has no registered modeling class in m4")

    print(f"Detected {config_class}")
    print("Reconsolidating fp32 model from checkpoint shards (can take a long time)")
    state_dict = get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)  # already on cpu

    # Keeping debug to use if you ever need to debug state dict
    # print("Saved State Dict")
    # for k, v in state_dict.items():
    #     print(f"{k} {v.shape}")

    print(f"Loading config from {config_dir}")
    model_config = config_class.from_pretrained(config_dir)

    print(f"Instantiating a {modeling_class} model in bf16")
    model = modeling_class.from_pretrained(
        None, config=model_config, state_dict=state_dict, torch_dtype=torch.bfloat16
    )
    # Drop our reference to the fp32 weights so they can be garbage-collected
    # before saving, lowering peak host memory for large models.
    del state_dict

    # Keeping debug to use if you ever need to debug state dict
    # print("Model State Dict")
    # for k, v in model.state_dict().items():
    #     print(f"{k} {v.shape}")

    print(f"Saving model to {output_dir}")
    model.save_pretrained(output_dir)


if __name__ == "__main__":
    main()