Spaces:
Runtime error
Runtime error
File size: 7,023 Bytes
217780a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
#!/usr/bin/env python
#
# This tool uploads any new deepspeed checkpoints found at given path to s3 (and also various non-checkpoint files, like logs)
#
# Example:
#
# ./s3-upload-checkpoints.py checkpoints-path
#
# Use `-h` for more options
#
import argparse
import subprocess
import sys
import time
from pathlib import Path
# third ancestor directory of this file's resolved path - presumably the repository root (TODO confirm)
repo_path = Path(__file__).resolve().parents[2]
# location of zero_checkpoint_to_hf.py under repo_path (not referenced by the code visible in this file)
zero_checkpoint_to_hf_path = repo_path / "m4/models/zero_checkpoint_to_hf.py"
# maximum number of attempts for a single checkpoint upload
RETRIES = 5
# what dir/file glob patterns to include in the upload besides checkpoints
include_patterns = ["tb_run_*", "logs", "config.yaml"]
# we have to deal with potentially overlapping slurm jobs running on different nodes, so we can't
# rely on PIDs of a running process. Will use a control file instead as the filesystem is shared.
#
# If that file is there it means:
#
# 1. either the upload is still running
# 2. or the upload got aborted (e.g. cpu-oom)
#
# to detect aborted uploads we check whether the control file is older than a reasonable time to
# perform such an upload
control_file_name = "started-upload-checkpoint"
# marker file created inside a checkpoint dir once its upload has completed
finished_uploading_file_name = "finished-upload-checkpoint"
# this should be fine-tuned - but surely 2h per checkpoint is plenty
reasonable_upload_time_in_secs = 2 * 60 * 60
def run_cmd(cmd, check=True):
    """Run an external command and return its stripped standard output.

    Args:
        cmd: the command and its arguments as a list of strings (run without a shell).
        check: when true, a non-zero exit status raises ``EnvironmentError``;
            when false, the exit status is ignored.

    Returns:
        The command's stdout decoded as UTF-8, with surrounding whitespace stripped.

    Raises:
        EnvironmentError: if ``check`` is true and the command exits with a non-zero
            status. The message is the command's stderr and the underlying
            ``subprocess.CalledProcessError`` is attached as ``__cause__``.
    """
    try:
        completed = subprocess.run(
            cmd,
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            check=check,
            encoding="utf-8",
        )
    except subprocess.CalledProcessError as exc:
        # the stderr text is kept as the message because callers match on it
        # (e.g. "InvalidDigest"); chaining preserves the exit code and command for debugging
        raise EnvironmentError(exc.stderr) from exc
    return completed.stdout.strip()
def get_args():
    """Parse the command line of the uploader.

    Returns:
        An ``argparse.Namespace`` with ``checkpoints_path`` (str), ``force`` (bool)
        and ``skip_conversion_check`` (bool).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("checkpoints_path", type=str, help="base dir with checkpoints")
    # disabled:
    # cli.add_argument("experiment_name", type=str, help="experiment name as a s3 sub-dir")

    # both options are plain on/off switches
    switch = {"action": "store_true"}
    cli.add_argument("-f", "--force", help="force uploading of all checkpoints", **switch)
    cli.add_argument("--skip-conversion-check", help="skip checkpoint conversion is done check", **switch)

    parsed = cli.parse_args()
    return parsed
def exit(msg):
    """Print ``msg`` and stop the program by raising ``SystemExit`` with no exit code."""
    print(msg)
    raise SystemExit()
def should_process(path, force, control_file_path, finished_uploading_file_path, args):
    """Heuristics to decide whether to upload this opt_step-XXX checkpoint or not.

    Args:
        path: ``pathlib.Path`` of the checkpoint directory.
        force: when true, any fully saved checkpoint is uploaded regardless of its upload state.
        control_file_path: path of the marker file that exists while an upload is in progress
            (or was aborted).
        finished_uploading_file_path: path of the marker file that exists once an upload completed.
        args: parsed command line arguments; only ``skip_conversion_check`` is read.

    Returns:
        True if the checkpoint should be uploaded now, False if it should be skipped.
        The reason for the decision is printed, prefixed with ``[Y]`` or ``[N]``.
    """
    # check if checkpoint is fully saved
    finished_saving_path = path / "finished-saving"  # marker file name is defined in trainer.py
    if not finished_saving_path.exists():
        print(f"[N] {path} isn't finished saving. Skipping")
        return False

    if force:
        print(f"[Y] Forced to re-process {path}")
        return True

    # check if already uploaded
    if finished_uploading_file_path.exists():
        print(f"[N] {path} has already been uploaded. Skipping")
        return False

    # check conversion is completed: either a sharded index or a single weights file must exist
    if not args.skip_conversion_check:
        unwrapped_model_path = path / "unwrapped_model"
        converted_model_path_1 = unwrapped_model_path / "pytorch_model.bin.index.json"
        converted_model_path_2 = unwrapped_model_path / "pytorch_model.bin"
        if not converted_model_path_1.exists() and not converted_model_path_2.exists():
            print(f"[N] {path} doesn't have a converted model. Skipping")
            return False

    # no control file means nobody has started uploading this checkpoint yet
    if not control_file_path.exists():
        print(f"[Y] {path} is a new checkpoint. Uploading")
        return True

    # complicated checks - has another job already started uploading? or did it crash?
    # a control file older than the allowed upload time is treated as a leftover of an aborted job
    if control_file_path.stat().st_mtime < time.time() - reasonable_upload_time_in_secs:
        print(f"[Y] {path} looks stale - probably aborted job. Re-uploading")
        return True

    print(
        f"[N] {path} either another job is uploading it or less than"
        f" {reasonable_upload_time_in_secs} secs has passed since it was launched. Skipping"
    )
    return False
def _upload_with_retries(cmd):
    """Run the checkpoint sync command, retrying only on MD5 (InvalidDigest) failures.

    Args:
        cmd: the upload command as a list of strings, passed to ``run_cmd``.

    Returns:
        True if the command succeeded, False if all ``RETRIES`` attempts failed
        with an ``InvalidDigest`` error.

    Raises:
        EnvironmentError: if the command fails for any reason other than ``InvalidDigest``.
    """
    # s5cmd will fail with an error like this when MD5 checksum doesn't match on upload (it won't retry)
    # ERROR "cp data4.tar s3://m4-datasets/cm4-test/data4.tar": InvalidDigest: The Content-MD5
    # you specified was invalid. status code: 400, request id: SZEHBJ4QQ33JSMH7, host id:
    # XTeMYKd2KECiVKbFnwVbXo3LgnuA2OHWk5S+tHKAOKO95Os/pje2ZEbCfO5pojQtCTFOovvnVME=
    for attempt in range(1, RETRIES + 1):
        try:
            response = run_cmd(cmd)
        except EnvironmentError as e:
            if "InvalidDigest" not in str(e):
                # not a checksum problem - retrying blindly would only hide the failure
                raise
            print(f"MD5 checksum failed, upload retry {attempt}")
        else:
            print(response)
            return True
    return False


def main():
    """Upload every eligible ``opt_step-*`` checkpoint, then the non-checkpoint files.

    For each checkpoint dir under the given checkpoints path that ``should_process``
    accepts, the dir is synced to ``s3://m4-exps/<checkpoints dir name>/<opt_step dir name>/``.
    A control file is created before the upload starts and removed afterwards; the
    finished marker is created only when the upload succeeded. Finally every path
    matching ``include_patterns`` is synced to ``s3://m4-exps/<checkpoints dir name>/``.

    Raises:
        FileNotFoundError: if the checkpoints path is not an existing directory.
        EnvironmentError: if an upload command fails with anything but a checksum error.
    """
    args = get_args()

    checkpoints_path = Path(args.checkpoints_path)
    if not checkpoints_path.is_dir():
        raise FileNotFoundError(f"can't find a directory '{checkpoints_path}'")

    checkpoint_dirs = list(checkpoints_path.glob("opt_step-*"))
    if not checkpoint_dirs:
        exit("No checkpoints found, exiting")

    exp_name = checkpoints_path.name
    bucket_name = "m4-exps"

    # Check each folder in real time to allow for overlapping jobs starting at different times
    for checkpoint_dir in checkpoint_dirs:
        print(f"\n*** Checking {checkpoint_dir}")

        control_file_path = checkpoint_dir / control_file_name
        finished_uploading_file_path = checkpoint_dir / finished_uploading_file_name

        if not should_process(checkpoint_dir, args.force, control_file_path, finished_uploading_file_path, args):
            continue

        opt_step = checkpoint_dir.name
        bucket_path = f"{exp_name}/{opt_step}"

        print(f"Launching upload for {checkpoint_dir} - it could take a long time")
        # built as a list (not by splitting a string) so that paths containing spaces stay intact
        cmd = ["s5cmd", "sync", f"{checkpoint_dir}/", f"s3://{bucket_name}/{bucket_path}/"]

        # we could use flock here, to avoid a race condition, but it'd be pointless since each
        # cronjob is likely to run on a different node and flock only works within a single node
        control_file_path.touch()

        # if a non-checksum error propagates from here the control file is left behind on purpose:
        # should_process treats it as an aborted upload once it is old enough
        uploaded = _upload_with_retries(cmd)

        # the integrity check is disabled for now as large files don't have sha256 checksums

        control_file_path.unlink()
        if not uploaded:
            # don't mark the checkpoint as uploaded - without the control file the next run retries it
            print(f"Upload of {checkpoint_dir} failed {RETRIES} times with MD5 checksum errors. Giving up for now")
            continue
        finished_uploading_file_path.touch()

    # now upload non-checkpoint files
    print("\n*** Uploading non-checkpoint files")
    upload_paths = [match for pat in include_patterns for match in checkpoints_path.glob(pat)]
    for upload_path in upload_paths:
        print(f"Launching upload for {upload_path}")
        cmd = ["s5cmd", "sync", str(upload_path), f"s3://{bucket_name}/{exp_name}/"]
        print(f"running {cmd}")
        response = run_cmd(cmd)
        print(response)
# Script entry point: run the uploader only when this file is executed directly.
if __name__ == "__main__":
    main()
|