File size: 3,046 Bytes
e7d3e35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python

# This script combines two steps:
# 1. reconsolidating the sharded DeepSpeed ZeRO checkpoint into fp32 weights (via deepspeed's zero_to_fp32 utilities)
# 2. re-saving the result as an HF checkpoint, which also takes care of sharding large checkpoints
#
# example usage:
#
# this will generate the converted checkpoint under save_dir/opt_step-40/unwrapped_model
#
# ./m4/models/zero_checkpoint_to_hf.py save_dir/opt_step-40
#
# or you can override the destination by passing an explicit target dir, e.g.:
#
# ./m4/models/zero_checkpoint_to_hf.py save_dir/opt_step-40 save_dir/opt_step-40/output_dir

import argparse
import sys
from pathlib import Path

import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint


# Put the repository root (three path levels above this resolved file) at the very front of
# sys.path, so the local `m4` package can be imported without exporting PYTHONPATH.
repodir = str(Path(__file__).resolve().parents[2])
sys.path[:0] = [repodir]

import m4.models
from m4.testing_utils import read_json_file


if __name__ == "__main__":
    # CLI: a required checkpoint folder and an optional destination for save_pretrained.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "checkpoint_dir", type=str, help="path to the desired checkpoint folder, e.g., path/to/opt_step-100"
    )
    parser.add_argument(
        "output_dir",
        type=str,
        nargs="?",
        help="path to pass to save_pretrained, defaults to 'unwrapped_model' relative to the checkpoint_dir argument",
    )
    args = parser.parse_args()

    # Checkpoint layout read by this script:
    #   <checkpoint_dir>/unwrapped_model/config.json  - model config
    #   <checkpoint_dir>/accelerator_state/           - DeepSpeed ZeRO checkpoint shards
    checkpoint_dir = Path(args.checkpoint_dir)
    config_dir = checkpoint_dir / "unwrapped_model"
    ds_checkpoint_dir = checkpoint_dir / "accelerator_state"
    config_file_path = config_dir / "config.json"

    # By default the converted weights are written into the same folder the config is read from.
    # Normalize to Path so both branches yield the same type.
    output_dir = config_dir if args.output_dir is None else Path(args.output_dir)

    # Resolve both the config class and the modeling class before reconsolidating the shards,
    # so an unsupported model type fails immediately instead of after the slow step below.
    config = read_json_file(config_file_path)
    model_type = config["model_type"]
    config_class = m4.models._SUPPORTED_MODELS.get(model_type, None)
    if config_class is None:
        raise ValueError(f"config['model_type']={model_type!r} isn't supported by m4")
    modeling_class = m4.models.model_type_to_modeling_class.get(model_type, None)
    if modeling_class is None:
        raise ValueError(f"config['model_type']={model_type!r} has no modeling class registered in m4")

    print(f"Detected {config_class}")

    print("Reconsolidating fp32 model from checkpoint shards (can take a long time)")
    state_dict = get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)  # already on cpu

    # Keeping debug to use if you ever need to debug state dict
    # print("Saved State Dict")
    # for k, v in state_dict.items():
    #     print(f"{k} {v.shape}")

    print(f"Loading config from {config_dir}")
    model_config = config_class.from_pretrained(config_dir)

    print(f"Instantiating a {modeling_class} model in bf16")
    model = modeling_class.from_pretrained(
        None, config=model_config, state_dict=state_dict, torch_dtype=torch.bfloat16
    )
    # Drop this script's reference to the fp32 state dict before saving. It can then be
    # garbage-collected (if nothing else holds it), lowering peak host memory during save.
    del state_dict

    # Keeping debug to use if you ever need to debug state dict
    # print("Model State Dict")
    # for k, v in model.state_dict().items():
    #     print(f"{k} {v.shape}")

    print(f"Saving model to {output_dir}")
    model.save_pretrained(output_dir)