File size: 7,023 Bytes
217780a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
#!/usr/bin/env python

#
# This tool uploads any new DeepSpeed checkpoints found at the given path to S3 (as well as various non-checkpoint files, such as logs)
#
# Example:
#
# ./s3-upload-checkpoints.py checkpoints-path
#
# Use `-h` for more options
#


import argparse
import subprocess
import sys
import time
from pathlib import Path


# Repository root: resolved as parents[2] of this script's absolute path, i.e. this assumes the
# script lives two directories below the repo root — TODO confirm if the script is moved.
repo_path = Path(__file__).resolve().parents[2]
# Path to the ZeRO-checkpoint -> HF conversion script.
# NOTE(review): not referenced anywhere else in this script as visible here — possibly a leftover.
zero_checkpoint_to_hf_path = repo_path / "m4/models/zero_checkpoint_to_hf.py"

# Maximum number of `s5cmd sync` attempts for one checkpoint when the upload fails with an
# MD5 mismatch (s5cmd reports "InvalidDigest" and does not retry on its own).
RETRIES = 5

# Glob patterns (relative to the checkpoints dir) of non-checkpoint dirs/files that are uploaded as
# well, into the experiment's top-level S3 prefix.
include_patterns = ["tb_run_*", "logs", "config.yaml"]


# Uploads may be launched by overlapping slurm jobs running on different nodes, so we can't rely on
# the PID of a running process. Instead we use a control file, since the filesystem is shared.
#
# If that file exists it means that either:
#
# 1. the upload is still running, or
# 2. the upload got aborted (e.g. cpu-oom)
#
# To detect aborted uploads we check whether the control file is older than a reasonable time to
# perform such an upload.
control_file_name = "started-upload-checkpoint"
# Marker written into a checkpoint dir once its upload has completed.
finished_uploading_file_name = "finished-upload-checkpoint"
# Age (in seconds) after which a control file is treated as stale, i.e. the upload was probably
# aborted. TODO: tune this — 2h per checkpoint should be plenty.
reasonable_upload_time_in_secs = 2 * 60 * 60


def run_cmd(cmd, check=True):
    """Run ``cmd`` without a shell and return its stripped stdout.

    Args:
        cmd: the program and its arguments as a list of strings.
        check: when True (the default), a non-zero exit status is treated as an error.

    Returns:
        The command's stdout (decoded as UTF-8) with surrounding whitespace stripped.

    Raises:
        EnvironmentError: when ``check`` is True and the command exits non-zero. The message is
            the command's stderr, because callers inspect its text (e.g. for ``"InvalidDigest"``).
            The original ``subprocess.CalledProcessError`` is chained as ``__cause__`` so the exit
            status and command line are not lost.
    """
    try:
        completed = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=check,
            encoding="utf-8",
        )
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr) from exc

    return completed.stdout.strip()


def get_args():
    """Build the CLI parser and return the parsed ``argparse.Namespace``.

    Positional ``checkpoints_path`` is the experiment dir holding ``opt_step-*`` checkpoints;
    ``--force`` and ``--skip-conversion-check`` are plain on/off switches.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("checkpoints_path", type=str, help="base dir with checkpoints")

    # (option strings, help text) for every boolean switch, registered in this order
    switches = (
        (("-f", "--force"), "force uploading of all checkpoints"),
        (("--skip-conversion-check",), "skip checkpoint conversion is done check"),
    )
    for option_strings, help_text in switches:
        cli.add_argument(*option_strings, action="store_true", help=help_text)

    parsed = cli.parse_args()
    return parsed


def exit(msg):
    """Print ``msg`` to stdout, then terminate with a ``SystemExit`` carrying no code (status 0).

    Note: this module-level name shadows the ``exit`` builtin inside this script.
    """
    print(msg)
    # equivalent to sys.exit(): SystemExit() with code None
    raise SystemExit


def should_process(path, force, control_file_path, finished_uploading_file_path, args):
    """Heuristics to decide whether to upload this ``opt_step-XXX`` checkpoint or not.

    Checks, in order:
      1. the checkpoint must contain a ``finished-saving`` marker, otherwise it is still being saved;
      2. ``force`` short-circuits every remaining check;
      3. an existing ``finished_uploading_file_path`` means it was already uploaded;
      4. unless ``args.skip_conversion_check`` is set, a converted model
         (``unwrapped_model/pytorch_model.bin.index.json`` or ``unwrapped_model/pytorch_model.bin``)
         must exist;
      5. an existing ``control_file_path`` means another job is uploading it, unless the file is
         older than ``reasonable_upload_time_in_secs``, in which case that upload is assumed aborted.

    Args:
        path: the checkpoint directory (a ``pathlib.Path``).
        force: upload even if checks 3-5 would skip it (check 1 still applies).
        control_file_path: marker that exists while an upload is in progress.
        finished_uploading_file_path: marker that exists once an upload has completed.
        args: parsed CLI args; only ``args.skip_conversion_check`` is read here.

    Returns:
        True if the checkpoint should be uploaded now, False otherwise. A one-line ``[Y]``/``[N]``
        reason is printed in every case.
    """
    # marker written by trainer.py once the checkpoint is fully saved
    finished_saving_path = path / "finished-saving"
    if not finished_saving_path.exists():
        print(f"[N] {path} isn't finished saving. Skipping")
        return False

    if force:
        print(f"[Y] Forced to re-process {path}")
        return True

    if finished_uploading_file_path.exists():
        print(f"[N] {path} has already been uploaded. Skipping")
        return False

    if not args.skip_conversion_check:
        # the converted model may be saved either sharded (index json) or as a single file
        converted_model_path_1 = path / "unwrapped_model" / "pytorch_model.bin.index.json"
        converted_model_path_2 = path / "unwrapped_model" / "pytorch_model.bin"
        if not converted_model_path_1.exists() and not converted_model_path_2.exists():
            print(f"[N] {path} doesn't have a converted model. Skipping")
            return False

    if not control_file_path.exists():
        print(f"[Y] {path} is a new checkpoint. Uploading")
        return True

    # A control file exists: either another job is uploading right now, or a previous upload died
    # without cleaning up. Only an old enough control file is treated as the latter.
    if control_file_path.stat().st_mtime < time.time() - reasonable_upload_time_in_secs:
        print(f"[Y] {path} looks stale - probably aborted job. Re-uploading")
        return True

    print(
        f"[N] {path} either another job is uploading it or less than"
        f" {reasonable_upload_time_in_secs} secs has passed since it was launched. Skipping"
    )
    return False


def _sync_checkpoint(checkpoint_dir, bucket_name, bucket_path):
    """Sync one checkpoint dir to ``s3://<bucket_name>/<bucket_path>/``, retrying MD5 mismatches.

    s5cmd fails (and won't retry) when the Content-MD5 of an upload doesn't match, e.g.:

        ERROR "cp data4.tar s3://m4-datasets/cm4-test/data4.tar": InvalidDigest: The Content-MD5
        you specified was invalid. status code: 400, request id: ..., host id: ...

    so such failures are retried here, up to ``RETRIES`` attempts in total.

    Raises:
        EnvironmentError: immediately on any non-InvalidDigest s5cmd failure, or once every
            attempt has failed with InvalidDigest.
    """
    # argv list (not str.split) so paths containing spaces stay intact
    cmd = ["s5cmd", "sync", f"{checkpoint_dir}/", f"s3://{bucket_name}/{bucket_path}/"]
    for attempt in range(1, RETRIES + 1):
        try:
            response = run_cmd(cmd)
        except EnvironmentError as e:
            if "InvalidDigest" not in str(e):
                raise
            print(f"MD5 checksum failed, upload retry {attempt}")
            continue
        print(response)
        return
    raise EnvironmentError(f"Upload of {checkpoint_dir} failed MD5 verification {RETRIES} times, giving up")


def _sync_non_checkpoint_files(checkpoints_path, bucket_name, exp_name):
    """Sync every entry of ``checkpoints_path`` matching ``include_patterns`` to ``s3://<bucket_name>/<exp_name>/``."""
    upload_paths = [match for pattern in include_patterns for match in checkpoints_path.glob(pattern)]
    for upload_path in upload_paths:
        print(f"Launching upload for {upload_path}")
        cmd = ["s5cmd", "sync", str(upload_path), f"s3://{bucket_name}/{exp_name}/"]
        print(f"running {cmd}")
        print(run_cmd(cmd))


def main():
    """Upload every eligible ``opt_step-*`` checkpoint, then the non-checkpoint files, to S3.

    Each checkpoint is uploaded to ``s3://m4-exps/<exp_name>/<opt_step>/``, where ``exp_name`` is
    the basename of the checkpoints dir. While an upload runs, the checkpoint's control file
    exists. The finished-upload marker is written only after a successful sync. A failed sync
    removes the control file (so the next run retries it right away) and propagates the error.

    Raises:
        FileNotFoundError: if ``checkpoints_path`` is not an existing directory.
        EnvironmentError: if an s5cmd sync fails (see ``_sync_checkpoint`` / ``run_cmd``).
    """
    args = get_args()

    checkpoints_path = Path(args.checkpoints_path)
    # is_dir() is False for missing paths too
    if not checkpoints_path.is_dir():
        raise FileNotFoundError(f"can't find a directory '{checkpoints_path}'")

    # sorted for a deterministic processing order
    checkpoint_dirs = sorted(checkpoints_path.glob("opt_step-*"))
    if not checkpoint_dirs:
        exit("No checkpoints found, exiting")

    exp_name = checkpoints_path.name
    bucket_name = "m4-exps"

    # Check each folder in real time to allow for overlapping jobs starting at different times
    for checkpoint_dir in checkpoint_dirs:
        print(f"\n*** Checking {checkpoint_dir}")

        control_file_path = checkpoint_dir / control_file_name
        finished_uploading_file_path = checkpoint_dir / finished_uploading_file_name

        if not should_process(checkpoint_dir, args.force, control_file_path, finished_uploading_file_path, args):
            continue

        bucket_path = f"{exp_name}/{checkpoint_dir.name}"

        print(f"Launching upload for {checkpoint_dir} - it could take a long time")
        # we could use flock here, to avoid a race condition, but it'd be pointless since each
        # cronjob is likely to run on a different node and flock only works within a single node
        control_file_path.touch()
        try:
            _sync_checkpoint(checkpoint_dir, bucket_name, bucket_path)
        finally:
            # Success or failure, this upload is no longer in progress. On failure this lets the
            # next run retry immediately instead of waiting for the stale-control-file timeout.
            if control_file_path.exists():
                control_file_path.unlink()

        # NOTE: no post-upload integrity check — large files don't have sha256 checksums on S3.
        # Only reached when the sync succeeded.
        finished_uploading_file_path.touch()

    print("\n*** Uploading non-checkpoint files")
    _sync_non_checkpoint_files(checkpoints_path, bucket_name, exp_name)


# Entry point: run the uploader only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()