""" Utility file for trainers """

import os
import shutil
from glob import glob

import torch
import torch.distributed as dist


''' checkpoint functions '''


def save_checkpoint(model,
                    optimizer,
                    scheduler,
                    epoch,
                    checkpoint_dir,
                    name,
                    model_name):
    """ Save the model/optimizer/scheduler states together with the current epoch. """
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": epoch,
    }
    checkpoint_path = os.path.join(checkpoint_dir, '{}_{}_{}.pt'.format(name, model_name, epoch))
    torch.save(checkpoint_state, checkpoint_path)
    print("Saved checkpoint: {}".format(checkpoint_path))
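
# Illustrative usage (a sketch, kept commented out so importing this module has
# no side effects; `net`, `opt`, and `sched` are hypothetical objects supplied
# by the caller):
#
#   save_checkpoint(net, opt, sched, epoch=10,
#                   checkpoint_dir="results/run1/ckpt/10/",
#                   name="run1", model_name="simsiam")
#   # -> writes results/run1/ckpt/10/run1_simsiam_10.pt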

def reload_ckpt(args,
                network,
                optimizer,
                scheduler,
                gpu,
                model_name,
                manual_reload_name=None,
                manual_reload=False,
                manual_reload_dir=None,
                epoch=None,
                fit_sefa=False):
    """ Resume a run from the newest (or an explicitly requested) checkpoint. """
    if manual_reload:
        reload_name = manual_reload_name
    else:
        reload_name = args.name
    if manual_reload_dir:
        ckpt_dir = manual_reload_dir + reload_name + "/ckpt/"
    else:
        ckpt_dir = args.output_dir + reload_name + "/ckpt/"
    temp_ckpt_dir = f'{args.output_dir}{reload_name}/ckpt_temp/'
    reload_epoch = epoch

    if epoch is None:
        # No epoch requested: resume from the newest checkpoint found in either
        # the regular or the temporary checkpoint directory.
        reload_epoch_temp = 0
        reload_epoch_ckpt = 0
        if os.path.isdir(temp_ckpt_dir) and len(os.listdir(temp_ckpt_dir)) != 0:
            reload_epoch_temp = find_best_epoch(temp_ckpt_dir)
        if os.path.isdir(ckpt_dir) and len(os.listdir(ckpt_dir)) != 0:
            reload_epoch_ckpt = find_best_epoch(ckpt_dir)
        if reload_epoch_ckpt >= reload_epoch_temp:
            reload_epoch = reload_epoch_ckpt
        else:
            reload_epoch = reload_epoch_temp
            ckpt_dir = temp_ckpt_dir
    else:
        # A specific epoch was requested: prefer the temporary directory if it
        # holds a checkpoint for that epoch.
        if os.path.isfile(f"{temp_ckpt_dir}{reload_epoch}/{reload_name}_{model_name}_{reload_epoch}.pt"):
            ckpt_dir = temp_ckpt_dir

    if model_name is None:
        resuming_path = f"{ckpt_dir}{reload_epoch}/{reload_name}_{reload_epoch}.pt"
    else:
        resuming_path = f"{ckpt_dir}{reload_epoch}/{reload_name}_{model_name}_{reload_epoch}.pt"
    if gpu == 0:
        print("===Resume checkpoint from: {}===".format(resuming_path))
    loc = 'cuda:{}'.format(gpu)
    checkpoint = torch.load(resuming_path, map_location=loc)
    start_epoch = 0 if manual_reload and not fit_sefa else checkpoint["epoch"]

    if manual_reload_dir is not None and 'parameter_estimation' in manual_reload_dir:
        # These checkpoints were saved without the DistributedDataParallel
        # wrapper, so prepend the 'module.' prefix the wrapped model expects.
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in checkpoint["model"].items():
            new_state_dict['module.' + k] = v
        network.load_state_dict(new_state_dict)
    else:
        network.load_state_dict(checkpoint["model"])
    if not manual_reload:
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
    if gpu == 0:
        print("=> loaded checkpoint '{}' (epoch {})".format(resuming_path, reload_epoch))
    return start_epoch
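
# Illustrative usage (a sketch; `args`, `net`, `opt`, and `sched` are
# hypothetical stand-ins for the caller's objects):
#
#   start_epoch = reload_ckpt(args, net, opt, sched, gpu=0, model_name="simsiam")
#   for epoch in range(start_epoch, args.epochs):   # `args.epochs` is hypothetical
#       ...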

def find_best_epoch(input_dir):
    """ Return the largest epoch number among the checkpoint subdirectories of input_dir. """
    cur_epochs = glob("{}*".format(input_dir))
    return find_by_name(cur_epochs)


def find_by_name(epochs):
    # Checkpoint subdirectories are named by their integer epoch; pick the latest.
    int_epochs = [int(os.path.basename(e)) for e in epochs]
    return max(int_epochs)
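
# E.g., with a (hypothetical) layout results/run1/ckpt/{10,20,30}/,
# find_best_epoch("results/run1/ckpt/") returns 30.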

def remove_ckpt(cur_ckpt_path_dir, leave=2):
    """ Delete old checkpoint directories, keeping only the `leave` most recent. """
    ckpt_nums = sorted(int(i) for i in os.listdir(cur_ckpt_path_dir))
    num_to_delete = max(len(ckpt_nums) - leave, 0)
    for num in ckpt_nums[:num_to_delete]:
        shutil.rmtree("{}{}".format(cur_ckpt_path_dir, num))
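
# E.g., if (hypothetically) results/run1/ckpt/ contains 10/, 20/, 30/, and 40/,
# remove_ckpt("results/run1/ckpt/", leave=2) deletes 10/ and 20/.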


''' multi-GPU functions '''


class GatherLayer_Direct(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    unlike torch.distributed.all_gather, this implementation does not cut
    the gradients.
    """

    @staticmethod
    def forward(ctx, x):
        # Collect a copy of x from every rank.
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # Sum the incoming gradients across ranks, then return the slice that
        # corresponds to this rank's input.
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]
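
# Illustrative usage inside a distributed training step (a sketch; `embeddings`
# is a hypothetical per-rank tensor that requires grad):
#
#   gathered = GatherLayer_Direct.apply(embeddings)  # tuple of world_size tensors
#   full_batch = torch.cat(gathered, 0)              # gradients still flow back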


# classy_vision is a third-party dependency used only by gather_from_all below.
from classy_vision.generic.distributed_util import (
    convert_to_distributed_tensor,
    convert_to_normal_tensor,
    is_distributed_training_run,
)


def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    """
    Similar to classy_vision.generic.distributed_util.gather_from_all,
    except that it does not cut the gradients.
    """
    if tensor.ndim == 0:
        # Zero-dimensional tensors cannot be concatenated; add a batch axis.
        tensor = tensor.unsqueeze(0)

    if is_distributed_training_run():
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        gathered_tensors = GatherLayer_Direct.apply(tensor)
        gathered_tensors = [
            convert_to_normal_tensor(_tensor, orig_device)
            for _tensor in gathered_tensors
        ]
    else:
        gathered_tensors = [tensor]
    gathered_tensor = torch.cat(gathered_tensors, 0)
    return gathered_tensor
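
# Illustrative usage (a sketch; `z` is a hypothetical per-rank feature tensor):
#
#   z_all = gather_from_all(z)  # shape: (world_size * local_batch, dim)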