Spaces:
Runtime error
Runtime error
import torch | |
import os | |
import shutil | |
import numpy as np | |
def load_network_and_optimizer(net, opt, pretrained_dir, gpu, scaler=None):
    """Restore network, optimizer and optional scaler state from a checkpoint.

    The checkpoint is expected to be a mapping with a ``'state_dict'`` entry
    (model weights), an ``'optimizer'`` entry and, optionally, a ``'scaler'``
    entry.

    Args:
        net: Module exposing ``state_dict``/``load_state_dict``/``cuda``.
        opt: Optimizer exposing ``load_state_dict``.
        pretrained_dir: Path of the checkpoint file passed to ``torch.load``.
        gpu: CUDA device index; used for ``map_location`` and ``net.cuda``.
        scaler: Optional object exposing ``load_state_dict``; only restored
            when the checkpoint contains a ``'scaler'`` entry.

    Returns:
        Tuple ``(net.cuda(gpu), opt, unused_keys)`` where ``unused_keys``
        lists every checkpoint weight key that was not loaded into ``net``.
    """
    checkpoint = torch.load(pretrained_dir,
                            map_location=torch.device("cuda:" + str(gpu)))

    model_dict = net.state_dict()
    matched = {}
    unused_keys = []
    for key, value in checkpoint['state_dict'].items():
        if key in model_dict:
            matched[key] = value
        elif key.startswith('module.') and key[7:] in model_dict:
            # Checkpoint was saved with a leading 'module.' prefix (e.g. by
            # a wrapper such as torch.nn.DataParallel); strip it.
            matched[key[7:]] = value
        else:
            # Report every key that could not be placed, including
            # 'module.'-prefixed ones, so nothing is dropped silently.
            unused_keys.append(key)

    # Start from the network's own weights so keys that are absent from the
    # checkpoint keep their current values.
    model_dict.update(matched)
    net.load_state_dict(model_dict)

    opt.load_state_dict(checkpoint['optimizer'])
    if scaler is not None and 'scaler' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler'])

    # Drop the reference to the loaded checkpoint before returning.
    del checkpoint
    return net.cuda(gpu), opt, unused_keys
def load_network_and_optimizer_v2(net, opt, pretrained_dir, gpu, scaler=None):
    """Restore network, optimizer and optional scaler, matching groups by name.

    Unlike a plain ``opt.load_state_dict``, optimizer parameter groups are
    matched between the checkpoint and ``opt`` through their ``'name'`` entry,
    so the checkpoint may contain groups that ``opt`` does not have (they are
    ignored) and ``opt`` may contain groups the checkpoint lacks (they keep
    their current settings and state).

    Every parameter group, in both the checkpoint and ``opt``, must carry a
    ``'name'`` key; a missing key raises ``KeyError``.

    Args:
        net: Module exposing ``state_dict``/``load_state_dict``/``cuda``.
        opt: Optimizer exposing ``state_dict``/``load_state_dict``.
        pretrained_dir: Path of the checkpoint file passed to ``torch.load``.
        gpu: CUDA device index; used for ``map_location`` and ``net.cuda``.
        scaler: Optional object exposing ``load_state_dict``; only restored
            when the checkpoint contains a ``'scaler'`` entry.

    Returns:
        Tuple ``(net.cuda(gpu), opt, unused_keys)`` where ``unused_keys``
        lists every checkpoint weight key that was not loaded into ``net``.
    """
    checkpoint = torch.load(pretrained_dir,
                            map_location=torch.device("cuda:" + str(gpu)))

    # load model
    model_dict = net.state_dict()
    matched = {}
    unused_keys = []
    for key, value in checkpoint['state_dict'].items():
        if key in model_dict:
            matched[key] = value
        elif key.startswith('module.') and key[7:] in model_dict:
            # Strip a leading 'module.' prefix (e.g. added by a wrapper such
            # as torch.nn.DataParallel).
            matched[key[7:]] = value
        else:
            # Report every key that could not be placed so nothing is
            # dropped silently.
            unused_keys.append(key)
    model_dict.update(matched)
    net.load_state_dict(model_dict)

    # load optimizer
    saved_opt = checkpoint['optimizer']
    saved_state = saved_opt['state']
    saved_groups = {group['name']: group
                    for group in saved_opt['param_groups']}

    opt_dict = opt.state_dict()
    current_state = opt_dict['state']
    merged_state = {}
    merged_groups = []
    # Iterate in the CURRENT optimizer's group order: load_state_dict pairs
    # saved and current groups positionally, so the rebuilt list must have
    # the same length and order as the optimizer's own groups.
    for current_group in opt_dict['param_groups']:
        current_ids = current_group['params']
        saved_group = saved_groups.get(current_group['name'])
        if saved_group is None:
            # Group is new relative to the checkpoint: keep it untouched.
            merged_groups.append(current_group)
            for param_id in current_ids:
                if param_id in current_state:
                    merged_state[param_id] = current_state[param_id]
            continue

        # Take hyper-parameters from the checkpoint but re-index the
        # parameters to the ids used by the current optimizer.
        restored_group = dict(saved_group)
        restored_group['params'] = list(current_ids)
        merged_groups.append(restored_group)
        for saved_id, current_id in zip(saved_group['params'], current_ids):
            # A parameter may have no saved state (e.g. never stepped).
            if saved_id in saved_state:
                merged_state[current_id] = saved_state[saved_id]

    opt_dict.update({'state': merged_state, 'param_groups': merged_groups})
    opt.load_state_dict(opt_dict)

    # load scaler
    if scaler is not None and 'scaler' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler'])

    # Drop the reference to the loaded checkpoint before returning.
    del checkpoint
    return net.cuda(gpu), opt, unused_keys
def load_network(net, pretrained_dir, gpu):
    """Load model weights from a checkpoint file into ``net``.

    The weights are taken from the checkpoint's ``'state_dict'`` entry if
    present, otherwise from its ``'model'`` entry, otherwise the loaded
    object itself is treated as the weight mapping.

    Args:
        net: Module exposing ``state_dict``/``load_state_dict``/``cuda``.
        pretrained_dir: Path of the checkpoint file passed to ``torch.load``.
        gpu: CUDA device index; used for ``map_location`` and ``net.cuda``.

    Returns:
        Tuple ``(net.cuda(gpu), unused_keys)`` where ``unused_keys`` lists
        every checkpoint weight key that was not loaded into ``net``.
    """
    checkpoint = torch.load(pretrained_dir,
                            map_location=torch.device("cuda:" + str(gpu)))

    if 'state_dict' in checkpoint:
        weights = checkpoint['state_dict']
    elif 'model' in checkpoint:
        weights = checkpoint['model']
    else:
        weights = checkpoint

    model_dict = net.state_dict()
    matched = {}
    unused_keys = []
    for key, value in weights.items():
        if key in model_dict:
            matched[key] = value
        elif key.startswith('module.') and key[7:] in model_dict:
            # Strip a leading 'module.' prefix (e.g. added by a wrapper such
            # as torch.nn.DataParallel).
            matched[key[7:]] = value
        else:
            # Report every key that could not be placed so nothing is
            # dropped silently.
            unused_keys.append(key)

    # Start from the network's own weights so keys that are absent from the
    # checkpoint keep their current values.
    model_dict.update(matched)
    net.load_state_dict(model_dict)

    # Drop the references to the loaded checkpoint before returning.
    del weights, checkpoint
    return net.cuda(gpu), unused_keys
def save_network(net,
                 opt,
                 step,
                 save_path,
                 max_keep=8,
                 backup_dir='./saved_models',
                 scaler=None):
    """Save a checkpoint for ``step`` and prune the oldest checkpoints.

    The checkpoint is a dict with ``'state_dict'`` (from ``net``),
    ``'optimizer'`` (from ``opt``) and, when ``scaler`` is given,
    ``'scaler'``. It is written to ``save_path/save_step_<step>.pth``; if
    that fails, the same file is written under ``backup_dir`` instead and
    pruning is applied there.

    Args:
        net: Object exposing ``state_dict()``.
        opt: Object exposing ``state_dict()``.
        step: Training step used in the checkpoint file name.
        save_path: Preferred output directory (created if missing).
        max_keep: Number of checkpoints with the highest step numbers to
            keep. A non-positive value disables pruning.
        backup_dir: Fallback directory used when saving to ``save_path``
            fails.
        scaler: Optional object exposing ``state_dict()``.
    """
    ckpt = {'state_dict': net.state_dict(), 'optimizer': opt.state_dict()}
    if scaler is not None:
        ckpt['scaler'] = scaler.state_dict()

    prefix, suffix = 'save_step_', '.pth'
    save_file = 'save_step_%s.pth' % (step)
    try:
        os.makedirs(save_path, exist_ok=True)
        torch.save(ckpt, os.path.join(save_path, save_file))
    except (OSError, RuntimeError) as err:
        # Best effort: do not lose the checkpoint when the preferred
        # location is unavailable; fall back to the backup directory.
        print('Failed to save checkpoint to {} ({}); using {} instead.'.format(
            save_path, err, backup_dir))
        save_path = backup_dir
        os.makedirs(save_path, exist_ok=True)
        torch.save(ckpt, os.path.join(save_path, save_file))

    if max_keep <= 0:
        return

    # Collect (step, file name) only for files following the checkpoint
    # naming scheme, so unrelated files in the directory are neither counted
    # nor able to break the step parsing.
    saved = []
    for file_name in os.listdir(save_path):
        if not (file_name.startswith(prefix) and file_name.endswith(suffix)):
            continue
        try:
            saved_step = int(file_name[len(prefix):-len(suffix)])
        except ValueError:
            continue
        saved.append((saved_step, file_name))

    if len(saved) <= max_keep:
        return

    saved.sort()
    for _, file_name in saved[:-max_keep]:
        # Remove by the actual file name rather than one re-derived from
        # the parsed step, and without going through a shell.
        try:
            os.remove(os.path.join(save_path, file_name))
        except OSError as err:
            print('Failed to remove old checkpoint {} ({}).'.format(
                file_name, err))
def cp_ckpt(remote_dir="data_wd/youtube_vos_jobs/result", curr_dir="backup"):
    """Copy local ``.pth`` checkpoints into a mirrored remote directory tree.

    Walks ``curr_dir/<exp>/<stage>/{ema_ckpt,ckpt}`` and copies every file
    whose name contains ``'.pth'`` to the same relative location under
    ``remote_dir``, replacing an existing file there. Destination
    directories are not created; copying stops at the first ``OSError``
    (e.g. when the destination directory does not exist).

    Args:
        remote_dir: Root of the destination tree.
        curr_dir: Root of the local tree to copy from.
    """
    for exp in os.listdir(curr_dir):
        exp_dir = os.path.join(curr_dir, exp)
        if not os.path.isdir(exp_dir):
            # Stray files next to the experiment directories are ignored.
            continue
        for stage in os.listdir(exp_dir):
            stage_dir = os.path.join(exp_dir, stage)
            for final in ("ema_ckpt", "ckpt"):
                final_dir = os.path.join(stage_dir, final)
                if not os.path.isdir(final_dir):
                    # A stage may have only one (or neither) of the two
                    # checkpoint directories.
                    continue
                for ckpt in os.listdir(final_dir):
                    if '.pth' not in ckpt:
                        continue
                    curr_ckpt_path = os.path.join(final_dir, ckpt)
                    remote_ckpt_path = os.path.join(remote_dir, exp, stage,
                                                    final, ckpt)
                    try:
                        # Drop the stale copy first, without a shell so that
                        # paths with spaces or shell metacharacters are safe.
                        if os.path.isfile(remote_ckpt_path):
                            os.remove(remote_ckpt_path)
                        shutil.copy(curr_ckpt_path, remote_ckpt_path)
                    except OSError as err:
                        # Stop at the first failure, but say why instead of
                        # returning silently.
                        print("Failed to copy {} to {}: {}.".format(
                            curr_ckpt_path, remote_ckpt_path, err))
                        return
                    print("Copy {} to {}.".format(curr_ckpt_path,
                                                  remote_ckpt_path))