#!/usr/bin/env python

#
# This tool downloads the requested DeepSpeed checkpoints of a run from S3,
# converts any that haven't yet been converted to the HF format, and uploads
# the converted weights back to S3.
#
# Example:
#
# ./convert-checkpoints.py run-name opt-step-num [opt-step-num ...] repo-path
#

import argparse
import subprocess
import sys
from pathlib import Path

import boto3
from botocore.exceptions import ClientError

# S3 bucket that holds the experiment checkpoints; the helper functions
# below and main() all operate on this bucket
BUCKET_NAME = "m4-exps"


def check_s3_directory(directory_path):
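    """Return True if at least one object exists under the given prefix in the checkpoints bucket."""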
    s3 = boto3.client("s3")

    # Add a trailing slash to the directory path
    if not directory_path.endswith("/"):
        directory_path += "/"

    # Check if any objects exist with the given directory prefix
    response = s3.list_objects_v2(Bucket=BUCKET_NAME, Prefix=directory_path)

    # If any objects are found, the directory exists
    if "Contents" in response:
        return True

    return False


def check_s3_file(file_key):
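    """Return True if an object with the given key exists in the checkpoints bucket."""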
    s3 = boto3.client("s3")

    try:
        s3.head_object(Bucket=BUCKET_NAME, Key=file_key)
        return True
    except ClientError:
        # head_object raises a ClientError (404) when the key doesn't exist
        return False


def run_cmd(cmd, check=True):
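    """Run a command, returning its stripped stdout; raise EnvironmentError carrying stderr on failure."""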
    try:
        response = subprocess.run(
            cmd,
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            check=check,
            encoding="utf-8",
        ).stdout.strip()
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr) from exc

    return response


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("run_name", type=str, help="run name")
    parser.add_argument("opt_step_num_list", nargs="+", help="list of opt-steps to download")
    parser.add_argument("repo_path", type=str, help="repo path")

    parser.add_argument("-f", "--force", action="store_true", help="force rebuilding of all checkpoints")
    return parser.parse_args()


def exit(msg):
    print(msg)
    sys.exit(1)


def cmd_retry_loop(cmd, max_retries=5):
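    """Run `cmd` via run_cmd, retrying up to `max_retries` times on s5cmd InvalidDigest (MD5 checksum) failures."""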
    # s5cmd will fail with an error like this when the MD5 checksum doesn't match on upload
    # (it won't retry on its own):
    # ERROR "cp data4.tar s3://m4-datasets/cm4-test/data4.tar": InvalidDigest: The Content-MD5
    # you specified was invalid. status code: 400, request id: SZEHBJ4QQ33JSMH7, host id:
    # XTeMYKd2KECiVKbFnwVbXo3LgnuA2OHWk5S+tHKAOKO95Os/pje2ZEbCfO5pojQtCTFOovvnVME=

    tries = 0
    while tries < max_retries:
        tries += 1
        try:
            response = run_cmd(cmd)
            print(response)
            return response
        except EnvironmentError as e:
            # retry only on checksum mismatches; any other failure is fatal
            if "InvalidDigest" not in str(e):
                raise
            print(f"MD5 checksum failed, retry {tries}")
    raise EnvironmentError(f"command failed after {max_retries} retries: {cmd}")


def main():
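    """Download the requested checkpoints, converting to HF format (and uploading back) any that need it."""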
    args = get_args()

    run_name = args.run_name
    opt_step_num_list = args.opt_step_num_list
    repo_path = Path(args.repo_path)
    zero_checkpoint_to_hf_path = repo_path / "m4/models/zero_checkpoint_to_hf.py"
    bucket_name = BUCKET_NAME
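    # each checkpoint lives under <run_name>/opt_step-<N>/ in the bucket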
    opt_step_s3_file_keys = [f"{run_name}/opt_step-{opt_step_num}" for opt_step_num in opt_step_num_list]

    if not check_s3_directory(run_name):
        exit(f"The run {run_name} doesn't exist at s3://{bucket_name} - aborting")

    # Check each folder in real time to allow for overlapping jobs starting at different times
    for opt_step_s3_file_key in opt_step_s3_file_keys:
        print(f"\n*** Checking {opt_step_s3_file_key}")
        if not check_s3_directory(opt_step_s3_file_key):
            print(f"The checkpoint {opt_step_s3_file_key} does not exist - skipping")
            continue
        unwrapped_model_s3_file_key = f"{opt_step_s3_file_key}/unwrapped_model"
        bin_s3_file_key = f"{unwrapped_model_s3_file_key}/pytorch_model.bin"
        index_s3_file_key = f"{unwrapped_model_s3_file_key}/pytorch_model.bin.index.json"
        is_not_converted = args.force or (
            not check_s3_file(bin_s3_file_key) and not check_s3_file(index_s3_file_key)
        )
        opt_step_dirname = opt_step_s3_file_key.split("/")[-1]
        cluster_opt_step_dir = (
            "/fsx/m4/experiments/local_experiment_dir/s3_async_temporary_checkpoint_folder/"
            f"{run_name}/{opt_step_dirname}"
        )
        if is_not_converted:
            print(
                f"The checkpoint needs converting, launching download for {opt_step_s3_file_key} - it could take"
                " a long time"
            )
            cmd = f"s5cmd sync s3://{bucket_name}/{opt_step_s3_file_key}/* {cluster_opt_step_dir}".split()
            download_response_opt_step_dir = cmd_retry_loop(cmd, max_retries=5)
            print(f"download_response_opt_step_dir: {download_response_opt_step_dir}")
        else:
            print(
                "The checkpoint has been converted already, downloading only the unwrapped checkpoint and"
                " tokenizer dir"
            )
            unwrapped_model_dir = f"{cluster_opt_step_dir}/unwrapped_model"
            tokenizer_dir = f"{cluster_opt_step_dir}/tokenizer"
            cmd_model = (
                f"s5cmd sync s3://{bucket_name}/{opt_step_s3_file_key}/unwrapped_model/* {unwrapped_model_dir}".split()
            )
            cmd_tokenizer = f"s5cmd sync s3://{bucket_name}/{opt_step_s3_file_key}/tokenizer/* {tokenizer_dir}".split()
            download_response_model = cmd_retry_loop(cmd_model, max_retries=5)
            print(f"download_response_model: {download_response_model}")
            download_response_tokenizer = cmd_retry_loop(cmd_tokenizer, max_retries=5)
            print(f"download_response_tokenizer: {download_response_tokenizer}")

        print(f"opt_step_dirname: {opt_step_dirname} downloaded to cluster_opt_step_dir: {cluster_opt_step_dir}")

        if is_not_converted:
            print(f"Converting {cluster_opt_step_dir}")
            # run via the current interpreter so the script needn't be executable
            convert_cmd = [sys.executable, str(zero_checkpoint_to_hf_path), cluster_opt_step_dir]
            conversion_response = run_cmd(convert_cmd)
            print(f"conversion_response: {conversion_response}")
            print(f"upload converted checkpoint: {cluster_opt_step_dir}")
            upload_cmd = (
                f"s5cmd sync {cluster_opt_step_dir}/unwrapped_model/"
                f" s3://{bucket_name}/{opt_step_s3_file_key}/unwrapped_model/ ".split()
            )
            upload_response = cmd_retry_loop(upload_cmd, max_retries=5)
            print(f"upload_response: {upload_response}")
            print(
                f"Uploaded {cluster_opt_step_dir}/unwrapped_model to"
                f" s3://{bucket_name}/{opt_step_s3_file_key}/unwrapped_model"
            )


if __name__ == "__main__":
    main()