dummy_m4 / m4 /models /zero_checkpoint_to_hf.py
ysharma's picture
ysharma HF staff
Duplicate from HuggingFaceM4/m4-dialogue
e7d3e35
raw
history blame
3.05 kB
#!/usr/bin/env python
# This script combines the 2 steps of
# 1. calling zero_to_fp32.py to reconsolidate the sharded deepspeed checkpoint
# 2. then resaving it as HF checkpoint, which also takes care of sharding large checkpoints
#
# example usage:
#
# this will generate the converted checkpoint under save_dir/opt_step-40/unwrapped_model
#
# ./m4/models/zero_checkpoint_to_hf.py save_dir/opt_step-40
#
# or you can override the destination by passing an explicit target dir, e.g.:
#
# ./m4/models/zero_checkpoint_to_hf.py save_dir/opt_step-40 save_dir/opt_step-40/output_dir
import argparse
import sys
from pathlib import Path
import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
# auto-prepend the repo path to sys.path so that m4 modules are loaded from this repo without needing to set PYTHONPATH
repodir = str(Path(__file__).resolve().parents[2])
sys.path.insert(0, repodir)
import m4.models
from m4.testing_utils import read_json_file
def main():
    """Convert a DeepSpeed ZeRO checkpoint folder into a HF checkpoint.

    Command-line arguments:
        checkpoint_dir: path to the checkpoint folder, e.g. path/to/opt_step-100.
            The model config is read from ``<checkpoint_dir>/unwrapped_model/config.json``
            and the ZeRO shards from ``<checkpoint_dir>/accelerator_state``.
        output_dir (optional): path passed to ``save_pretrained``; defaults to
            ``<checkpoint_dir>/unwrapped_model``.

    Raises:
        ValueError: if the config's ``model_type`` has no config class or no
            modeling class registered in ``m4.models``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "checkpoint_dir", type=str, help="path to the desired checkpoint folder, e.g., path/to/opt_step-100"
    )
    parser.add_argument(
        "output_dir",
        type=str,
        nargs="?",
        help="path to pass to save_pretrained, defaults to 'unwrapped_model' relative to the checkpoint_dir argument",
    )
    args = parser.parse_args()

    checkpoint_dir = Path(args.checkpoint_dir)
    config_dir = checkpoint_dir / "unwrapped_model"
    ds_checkpoint_dir = checkpoint_dir / "accelerator_state"
    config_file_path = config_dir / "config.json"

    # Always work with a Path, whether the destination is the default or user-supplied.
    if args.output_dir is None:
        output_dir = checkpoint_dir / "unwrapped_model"
    else:
        output_dir = Path(args.output_dir)

    config = read_json_file(config_file_path)
    model_type = config["model_type"]

    # Resolve and validate both classes up front, so an unsupported model type
    # fails before the slow reconsolidation step rather than after it.
    config_class = m4.models._SUPPORTED_MODELS.get(model_type, None)
    if config_class is None:
        raise ValueError(f"{config['model_type']=} isn't supported by m4")
    modeling_class = m4.models.model_type_to_modeling_class.get(model_type, None)
    if modeling_class is None:
        raise ValueError(f"{config['model_type']=} has no modeling class registered in m4")

    print(f"Detected {config_class}")

    print("Reconsolidating fp32 model from checkpoint shards (can take a long time)")
    state_dict = get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)  # already on cpu

    # Keeping debug to use if you ever need to debug state dict
    # print("Saved State Dict")
    # for k, v in state_dict.items():
    #     print(f"{k} {v.shape}")

    print(f"Loading config from {config_dir}")
    model_config = config_class.from_pretrained(config_dir)

    print(f"Instantiating a {modeling_class} model in bf16")
    # No pretrained path is given: the weights come from the reconsolidated state dict.
    model = modeling_class.from_pretrained(
        None, config=model_config, state_dict=state_dict, torch_dtype=torch.bfloat16
    )

    # Keeping debug to use if you ever need to debug state dict
    # print("Model State Dict")
    # for k, v in model.state_dict().items():
    #     print(f"{k} {v.shape}")

    print(f"Saving model to {output_dir}")
    model.save_pretrained(output_dir)


if __name__ == "__main__":
    main()