HPSv2 / src /training /file_utils.py
tgxs002's picture
init
54199b6
import logging
import os
import multiprocessing
import subprocess
import time
import fsspec
import torch
import json
from tqdm import tqdm
from .train import unwrap_model
def remote_sync_s3(local_dir, remote_dir):
    """Mirror ``local_dir`` to ``remote_dir`` by running ``aws s3 sync``.

    Files matching ``*epoch_latest.pt`` are excluded because that checkpoint
    can be rewritten while the sync is in progress.

    Args:
        local_dir: Local source directory passed to the AWS CLI.
        remote_dir: Destination passed to the AWS CLI (e.g. an ``s3://`` URL).

    Returns:
        True if the AWS CLI exited with status 0; False otherwise, including
        when the ``aws`` executable cannot be found.
    """
    # skip epoch_latest which can change during sync.
    cmd = ["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt']
    try:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except FileNotFoundError as e:
        # Report a missing AWS CLI like any other sync failure instead of
        # letting the exception escape and end the periodic sync loop.
        logging.error(f"Error: Failed to sync with S3 bucket, aws CLI not found: {e}")
        return False
    if result.returncode != 0:
        # errors="replace" so a non-UTF-8 CLI message cannot raise here.
        stderr_text = result.stderr.decode('utf-8', errors='replace')
        logging.error(f"Error: Failed to sync with S3 bucket {stderr_text}")
        return False
    logging.info("Successfully synced with S3 bucket")
    return True
def remote_sync_fsspec(local_dir, remote_dir):
    """Copy every key under ``local_dir`` to ``remote_dir`` using fsspec mappers.

    A key is skipped when the remote already holds a value of the same byte
    length (a cheap, imperfect change check). Keys containing
    ``epoch_latest.pt`` are never synced because that file can change while
    the sync runs.

    Args:
        local_dir: Source location understood by ``fsspec.get_mapper``.
        remote_dir: Destination location understood by ``fsspec.get_mapper``.

    Returns:
        True if every key was synced or skipped; False on the first error.
    """
    # FIXME currently this is slow and not recommended. Look into speeding up.
    src = fsspec.get_mapper(local_dir)
    dst = fsspec.get_mapper(remote_dir)
    for k in src:
        # skip epoch_latest which can change during sync.
        if 'epoch_latest.pt' in k:
            continue
        logging.info(f'Attempting to sync {k}')
        try:
            # Read the source once and reuse it for both the size check and the write.
            data = src[k]
            if k in dst and len(data) == len(dst[k]):
                logging.debug(f'Skipping remote sync for {k}.')
                continue
            dst[k] = data
        except Exception as e:
            logging.error(f'Error during remote sync for {k}: {e}')
            return False
        # Only report success after the write actually completed.
        logging.info(f'Successful sync for {k}.')
    return True
def remote_sync(local_dir, remote_dir, protocol):
    """Run one sync of ``local_dir`` to ``remote_dir`` with the chosen backend.

    Args:
        local_dir: Local source directory.
        remote_dir: Remote destination.
        protocol: ``'s3'`` to use the AWS CLI, ``'fsspec'`` to use fsspec mappers.

    Returns:
        The backend's True/False result, or False for an unrecognized protocol.
    """
    logging.info('Starting remote sync.')
    # Reject unknown protocols up front so the dispatch below stays flat.
    if protocol not in ('s3', 'fsspec'):
        logging.error('Remote protocol not known')
        return False
    if protocol == 's3':
        return remote_sync_s3(local_dir, remote_dir)
    return remote_sync_fsspec(local_dir, remote_dir)
def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):
    """Loop forever: wait ``sync_every`` seconds, then run one remote sync.

    This function never returns. The boolean result of each ``remote_sync``
    call is ignored; failures are visible only through its logging.
    """
    sync_args = (local_dir, remote_dir, protocol)
    while True:
        time.sleep(sync_every)
        remote_sync(*sync_args)
def start_sync_process(sync_every, local_dir, remote_dir, protocol):
    """Build a process whose target is ``keep_running_remote_sync``.

    Despite the name, the returned ``multiprocessing.Process`` is NOT started
    here. The caller must call ``.start()`` on it, and eventually stop it,
    because the target loops forever.
    """
    worker_args = (sync_every, local_dir, remote_dir, protocol)
    return multiprocessing.Process(target=keep_running_remote_sync, args=worker_args)
# Note: we are not currently using this save function.
def pt_save(pt_obj, file_path):
    """Serialize ``pt_obj`` with ``torch.save`` to a local or fsspec-addressable path.

    Args:
        pt_obj: Any object ``torch.save`` can serialize.
        file_path: Destination path/URL understood by ``fsspec.open``.
    """
    # Write through the fsspec handle rather than the raw path. Otherwise a
    # remote URL would be saved as a local file named after the URL, and an
    # empty object would be committed remotely.
    with fsspec.open(file_path, "wb") as f:
        torch.save(pt_obj, f)
def pt_load(file_path, map_location=None):
    """Load an object saved with ``torch.save`` from a local or fsspec-addressable path.

    Args:
        file_path: Path/URL understood by ``fsspec.open``. An ``s3`` prefix
            only triggers an extra log line.
        map_location: Forwarded unchanged to ``torch.load``.

    Returns:
        Whatever ``torch.load`` deserializes from the file.
    """
    if file_path.startswith('s3'):
        logging.info('Loading remote checkpoint, which may take a bit.')
    # NOTE(review): torch.load unpickles its input; only load checkpoints
    # from trusted locations.
    with fsspec.open(file_path, "rb") as handle:
        return torch.load(handle, map_location=map_location)
def check_exists(file_path):
    """Report whether ``file_path`` can be opened via ``fsspec.open``.

    Only ``FileNotFoundError`` is treated as "does not exist". Any other error
    raised while opening the path propagates to the caller.
    """
    try:
        with fsspec.open(file_path):
            found = True
    except FileNotFoundError:
        found = False
    return found
def save_ckpt(args, model, scaler, optimizer):
    """Write a training checkpoint to ``args.save_path``.

    The checkpoint dict contains ``iterations`` and ``name`` (read from
    ``args``), the unwrapped model's ``state_dict``, the optimizer state, and
    the scaler state when ``scaler`` is not None.

    Args:
        args: Object exposing ``save_path``, ``iterations`` and ``name``.
        model: Model to save; passed through ``unwrap_model`` first.
        scaler: Object with ``state_dict()``, or None to omit it.
        optimizer: Optimizer whose ``state_dict()`` is saved.

    Raises:
        ValueError: If ``args.save_path`` is None.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.save_path is None:
        raise ValueError("args.save_path must be set to save a checkpoint")
    ckpt_path = args.save_path
    model = unwrap_model(model)
    checkpoint_dict = {
        "iterations": args.iterations,
        "name": args.name,
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    if scaler is not None:
        checkpoint_dict["scaler"] = scaler.state_dict()
    # Save to a sibling temp file, then rename over the target. A crash or
    # interrupt mid-save then cannot leave a truncated checkpoint at ckpt_path.
    # The temp file is in the same directory, so os.replace stays on one
    # filesystem (atomic on POSIX).
    # Assumes save_path is a str/PathLike, not a file-like object.
    tmp_path = f"{os.fspath(ckpt_path)}.tmp"
    try:
        torch.save(checkpoint_dict, tmp_path)
        os.replace(tmp_path, ckpt_path)
    except BaseException:
        # Don't leave a partial temp file behind on failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    logging.info(f"saved {ckpt_path}")