""" Utility file for trainers """
import os
import shutil
from glob import glob
import torch
import torch.distributed as dist
''' checkpoint functions '''
# saves checkpoint
def save_checkpoint(model,
                    optimizer,
                    scheduler,
                    epoch,
                    checkpoint_dir,
                    name,
                    model_name):
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": epoch
    }
    checkpoint_path = os.path.join(checkpoint_dir, '{}_{}_{}.pt'.format(name, model_name, epoch))
    torch.save(checkpoint_state, checkpoint_path)
    print("Saved checkpoint: {}".format(checkpoint_path))
# reload model weights (and optionally optimizer/scheduler state) from a checkpoint file
def reload_ckpt(args,
                network,
                optimizer,
                scheduler,
                gpu,
                model_name,
                manual_reload_name=None,
                manual_reload=False,
                manual_reload_dir=None,
                epoch=None,
                fit_sefa=False):
    if manual_reload:
        reload_name = manual_reload_name
    else:
        reload_name = args.name
    if manual_reload_dir:
        ckpt_dir = manual_reload_dir + reload_name + "/ckpt/"
    else:
        ckpt_dir = args.output_dir + reload_name + "/ckpt/"
    temp_ckpt_dir = f'{args.output_dir}{reload_name}/ckpt_temp/'
    reload_epoch = epoch
    # find best or latest epoch
    if epoch is None:
        reload_epoch_temp = 0
        reload_epoch_ckpt = 0
        if len(os.listdir(temp_ckpt_dir)) != 0:
            reload_epoch_temp = find_best_epoch(temp_ckpt_dir)
        if len(os.listdir(ckpt_dir)) != 0:
            reload_epoch_ckpt = find_best_epoch(ckpt_dir)
        if reload_epoch_ckpt >= reload_epoch_temp:
            reload_epoch = reload_epoch_ckpt
        else:
            reload_epoch = reload_epoch_temp
            ckpt_dir = temp_ckpt_dir
    else:
        if os.path.isfile(f"{temp_ckpt_dir}{reload_epoch}/{reload_name}_{model_name}_{reload_epoch}.pt"):
            ckpt_dir = temp_ckpt_dir
    # reloading weight
    if model_name is None:
        resuming_path = f"{ckpt_dir}{reload_epoch}/{reload_name}_{reload_epoch}.pt"
    else:
        resuming_path = f"{ckpt_dir}{reload_epoch}/{reload_name}_{model_name}_{reload_epoch}.pt"
    if gpu == 0:
        print("===Resume checkpoint from: {}===".format(resuming_path))
    loc = 'cuda:{}'.format(gpu)
    checkpoint = torch.load(resuming_path, map_location=loc)
    start_epoch = 0 if manual_reload and not fit_sefa else checkpoint["epoch"]
    if manual_reload_dir is not None and 'parameter_estimation' in manual_reload_dir:
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in checkpoint["model"].items():
            name = 'module.' + k  # add `module.` prefix for DataParallel/DDP-wrapped models
            new_state_dict[name] = v
        network.load_state_dict(new_state_dict)
    else:
        network.load_state_dict(checkpoint["model"])
    if not manual_reload:
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
    if gpu == 0:
        print("=> loaded checkpoint '{}' (epoch {})".format(resuming_path, reload_epoch))
    return start_epoch
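
# Hedged usage sketch (assumes `args.output_dir` and `args.name` point at the
# layout written by save_checkpoint above; the model name is illustrative):
#   start_epoch = reload_ckpt(args, network, optimizer, scheduler,
#                             gpu=0, model_name="effects_encoder")
#   # with epoch=None this resumes from the newest epoch found under
#   # {args.output_dir}{args.name}/ckpt/ or .../ckpt_temp/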
# find the latest (highest-numbered) epoch available for reloading the current model
def find_best_epoch(input_dir):
    cur_epochs = glob("{}*".format(input_dir))
    return find_by_name(cur_epochs)

# convert epoch directory names to integers and return the largest one
def find_by_name(epochs):
    int_epochs = []
    for e in epochs:
        int_epochs.append(int(os.path.basename(e)))
    int_epochs.sort()
    return int_epochs[-1]
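
# Hedged example of the directory layout these helpers expect (paths illustrative):
#   results/my_run/ckpt/10/   results/my_run/ckpt/20/   results/my_run/ckpt/30/
#   find_best_epoch("results/my_run/ckpt/")  ->  30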
# remove old checkpoint directories, keeping only the `leave` most recent epochs
def remove_ckpt(cur_ckpt_path_dir, leave=2):
    ckpt_nums = [int(i) for i in os.listdir(cur_ckpt_path_dir)]
    ckpt_nums.sort()
    del_num = len(ckpt_nums) - leave
    cur_del_num = 0
    while del_num > 0:
        shutil.rmtree("{}{}".format(cur_ckpt_path_dir, ckpt_nums[cur_del_num]))
        del_num -= 1
        cur_del_num += 1
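
# Hedged usage sketch (path illustrative): keep only the two newest epoch
# directories and delete the rest:
#   remove_ckpt("results/my_run/ckpt/", leave=2)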
''' multi-GPU functions '''
# gather function implemented from DirectCLR
class GatherLayer_Direct(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    this implementation does not cut the gradients as torch.distributed.all_gather does.
    """
    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]
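
# Hedged usage sketch (assumes torch.distributed is already initialized; `feat`
# is an illustrative per-rank tensor):
#   gathered = GatherLayer_Direct.apply(feat)   # tuple of world_size tensors
#   feat_all = torch.cat(gathered, dim=0)       # gradients flow back to the local shard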
from classy_vision.generic.distributed_util import (
convert_to_distributed_tensor,
convert_to_normal_tensor,
is_distributed_training_run,
)
def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    """
    Similar to classy_vision.generic.distributed_util.gather_from_all
    except that it does not cut the gradients.
    """
    if tensor.ndim == 0:
        # 0-dim tensors cannot be gathered, so unsqueeze to 1-dim first
        tensor = tensor.unsqueeze(0)

    if is_distributed_training_run():
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        gathered_tensors = GatherLayer_Direct.apply(tensor)
        gathered_tensors = [
            convert_to_normal_tensor(_tensor, orig_device)
            for _tensor in gathered_tensors
        ]
    else:
        gathered_tensors = [tensor]
    gathered_tensor = torch.cat(gathered_tensors, 0)
    return gathered_tensor
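
# Hedged usage sketch (assumes a running DDP job; `model`, `x` and
# `contrastive_loss` are illustrative placeholders, not part of this repo):
#   z = model(x)                    # local embeddings, shape (B, D)
#   z_all = gather_from_all(z)      # shape (world_size * B, D), gradients kept
#   loss = contrastive_loss(z_all)  # loss computed over the full global batch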