#!/usr/bin/env python

# This script combines the 2 steps of
# 1. calling zero_to_fp32.py to reconsolidate the sharded deepspeed checkpoint
# 2. then resaving it as an HF checkpoint, which also takes care of sharding large checkpoints
#
# example usage:
#
# this will generate the converted checkpoint under save_dir/opt_step-40/unwrapped_model
#
# ./m4/models/zero_checkpoint_to_hf.py save_dir/opt_step-40
#
# or you can override the destination by passing an explicit target dir, e.g.:
#
# ./m4/models/zero_checkpoint_to_hf.py save_dir/opt_step-40 save_dir/opt_step-40/output_dir
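
# Expected layout of the checkpoint folder (inferred from the paths used below):
#
#   opt_step-40/
#       accelerator_state/   - the ZeRO-partitioned deepspeed checkpoint shards
#       unwrapped_model/     - holds config.json and is the default save destination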

import argparse
import sys
from pathlib import Path

import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# auto-append the repo path to load m4 modules from instead of needing to set PYTHONPATH
repodir = str(Path(__file__).resolve().parents[2])
sys.path.insert(0, repodir)

import m4.models
from m4.testing_utils import read_json_file


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "checkpoint_dir", type=str, help="path to the desired checkpoint folder, e.g., path/to/opt_step-100"
    )
    parser.add_argument(
        "output_dir",
        type=str,
        nargs="?",
        help="path to pass to save_pretrained, defaults to 'unwrapped_model' relative to the checkpoint_dir argument",
    )
    args = parser.parse_args()

    checkpoint_dir = Path(args.checkpoint_dir)
    config_dir = checkpoint_dir / "unwrapped_model"
    ds_checkpoint_dir = checkpoint_dir / "accelerator_state"
    config_file_path = config_dir / "config.json"

    if args.output_dir is None:
        output_dir = checkpoint_dir / "unwrapped_model"
    else:
        output_dir = args.output_dir
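    # note: the default output_dir above is config_dir itself, so the converted weights
    # are written alongside the existing config.json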

    config = read_json_file(config_file_path)
    config_class = m4.models._SUPPORTED_MODELS.get(config["model_type"], None)
    if config_class is None:
        raise ValueError(f"{config['model_type']=} isn't supported by m4")
    modeling_class = m4.models.model_type_to_modeling_class.get(config["model_type"], None)

    print(f"Detected {config_class}")

    print("Reconsolidating fp32 model from checkpoint shards (can take a long time)")
    state_dict = get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)  # already on cpu

    # Keeping debug to use if you ever need to debug state dict
    # print("Saved State Dict")
    # for k, v in state_dict.items():
    #     print(f"{k} {v.shape}")

    kwargs = {}
    print(f"Loading config from {config_dir}")
    model_config = config_class.from_pretrained(config_dir)

    print(f"Instantiating a {modeling_class} model in bf16")
    model = modeling_class.from_pretrained(
        None, config=model_config, state_dict=state_dict, torch_dtype=torch.bfloat16
    )

    # Keeping debug to use if you ever need to debug state dict
    # print("Model State Dict")
    # for k, v in model.state_dict().items():
    #     print(f"{k} {v.shape}")

    print(f"Saving model to {output_dir}")
    model.save_pretrained(output_dir)
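
    # The result under output_dir is a standard HF checkpoint (sharded automatically when
    # large), so it can be reloaded later with the same modeling class. A minimal reload
    # sketch, assuming the converted checkpoint is still at output_dir:
    #
    #   model = modeling_class.from_pretrained(output_dir, torch_dtype=torch.bfloat16)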