|
import os |
|
import yaml |
|
import json |
|
import pickle |
|
import torch |
|
|
|
def traverse_dir(
        root_dir,
        extensions,
        amount=None,
        str_include=None,
        str_exclude=None,
        is_pure=False,
        is_sort=False,
        is_ext=True):
    """Recursively collect files under ``root_dir`` whose names end with ``.<ext>``.

    Args:
        root_dir: Directory walked recursively with ``os.walk``.
        extensions: Iterable of extensions without the leading dot, e.g. ``['wav']``.
        amount: If not None, collect at most this many paths (the first ones found
            in ``os.walk`` order; sorting, if requested, happens afterwards).
        str_include: If not None, keep only paths containing this substring.
        str_exclude: If not None, drop paths containing this substring.
        is_pure: If True, return paths relative to ``root_dir``; otherwise return
            ``os.path.join(root, file)`` paths.
        is_sort: If True, sort the returned list.
        is_ext: If False, strip the final ``.suffix`` from every returned path.

    Returns:
        A list of path strings.
    """
    # Build the suffix tuple once; str.endswith accepts a tuple of options.
    suffixes = tuple(f".{ext}" for ext in extensions)

    def _candidates():
        # Yield matching paths (after include/exclude filtering) in os.walk order.
        for root, _, files in os.walk(root_dir):
            for file in files:
                if not file.endswith(suffixes):
                    continue
                mix_path = os.path.join(root, file)
                # relpath is robust to a trailing separator on root_dir, unlike
                # slicing off len(root_dir) + 1 characters.
                pure_path = os.path.relpath(mix_path, root_dir) if is_pure else mix_path
                if str_include is not None and str_include not in pure_path:
                    continue
                if str_exclude is not None and str_exclude in pure_path:
                    continue
                yield pure_path

    file_list = []
    for path in _candidates():
        # Stop as soon as the requested number of paths has been collected.
        if amount is not None and len(file_list) == amount:
            break
        if not is_ext:
            # Every candidate contains a '.', because it matched a ".<ext>" suffix.
            path = path.rsplit('.', 1)[0]
        file_list.append(path)

    if is_sort:
        file_list.sort()
    return file_list
|
|
|
|
|
|
|
class DotDict(dict):
    """A ``dict`` whose keys can also be read, written and deleted as attributes.

    Reading a missing key as an attribute yields ``None`` (it goes through
    ``dict.get``) instead of raising ``AttributeError``. A value whose exact type
    is ``dict`` is returned wrapped in a brand-new ``DotDict``, which is a shallow
    copy. Assigning through that wrapper (``d.sub.x = 1``) therefore does not
    modify ``d``.
    """

    def __getattr__(*args):
        # Python passes (self, name); forward them untouched to dict.get.
        found = dict.get(*args)
        if type(found) is not dict:
            return found
        return DotDict(found)

    # Attribute writes and deletes map directly onto item writes and deletes.
    __delattr__ = dict.__delitem__
    __setattr__ = dict.__setitem__
|
|
|
|
|
def get_network_paras_amount(model_dict):
    """Count trainable parameters of each model.

    Args:
        model_dict: Mapping of name -> object exposing ``.parameters()``
            (e.g. ``torch.nn.Module``).

    Returns:
        A dict mapping each name to the total ``numel()`` of its parameters whose
        ``requires_grad`` is True, in the same key order as ``model_dict``.
    """
    return {
        name: sum(param.numel() for param in module.parameters() if param.requires_grad)
        for name, module in model_dict.items()
    }
|
|
|
|
|
def load_config(path_config):
    """Load a YAML file with ``yaml.safe_load`` and wrap the result in a DotDict.

    Args:
        path_config: Path to the YAML config file (read as UTF-8).

    Returns:
        A ``DotDict`` built from the loaded top-level object. An empty file yields
        an empty ``DotDict``.
    """
    with open(path_config, "r", encoding="utf-8") as config:
        args = yaml.safe_load(config)
    # safe_load returns None for an empty document; treat that as an empty
    # config instead of failing inside DotDict(None).
    if args is None:
        args = {}
    return DotDict(args)
|
|
|
def save_config(path_config, config):
    """Write ``config`` to ``path_config`` as YAML.

    All mappings are converted to plain ``dict``s recursively, including those
    nested inside lists. A ``dict`` subclass such as ``DotDict`` would otherwise
    be written by ``yaml.dump``'s default Dumper with a ``!!python/object`` tag,
    which ``yaml.safe_load`` cannot read back.

    Args:
        path_config: Destination file path (overwritten).
        config: A mapping (e.g. a ``DotDict``) to serialize.
    """
    def _plain(value):
        # Rebuild dicts and lists so no dict subclass survives at any depth.
        if isinstance(value, dict):
            return {key: _plain(item) for key, item in value.items()}
        if isinstance(value, list):
            return [_plain(item) for item in value]
        return value

    config = _plain(dict(config))
    with open(path_config, "w", encoding="utf-8") as f:
        yaml.dump(config, f)
|
|
|
def to_json(path_params, path_json):
    """Dump every tensor of a saved checkpoint to JSON as flat Python lists.

    The object loaded from ``path_params`` is expected to be a mapping of
    names to tensors (e.g. a ``state_dict``). Each tensor is flattened to 1-D.

    Args:
        path_params: Path of the file to ``torch.load`` (mapped to CPU).
        path_json: Destination JSON path (overwritten), tab-indented.
    """
    # NOTE(review): torch.load unpickles the file; only use on trusted checkpoints.
    params = torch.load(path_params, map_location=torch.device('cpu'))
    # Tensor.tolist() avoids the numpy round-trip. Tensor.numpy() raises for
    # tensors that require grad (e.g. saved nn.Parameters) and for bfloat16.
    raw_state_dict = {k: v.detach().flatten().tolist() for k, v in params.items()}

    with open(path_json, 'w', encoding='utf-8') as outfile:
        json.dump(raw_state_dict, outfile, indent="\t")
|
|
|
|
|
def convert_tensor_to_numpy(tensor, is_squeeze=True):
    """Convert a torch tensor to a numpy array.

    Args:
        tensor: Input tensor on any device.
        is_squeeze: If True, drop all size-1 dimensions first.

    Returns:
        A numpy array. For a CPU tensor it shares memory with the input
        (``detach`` and ``cpu`` do not copy in that case).
    """
    if is_squeeze:
        tensor = tensor.squeeze()
    # detach() is cheap (shares storage) and cpu() returns the same tensor when it
    # is already on CPU. Calling both unconditionally also covers non-CUDA
    # accelerators (e.g. MPS), which the previous is_cuda check missed.
    return tensor.detach().cpu().numpy()
|
|
|
|
|
def load_model(
        expdir,
        model,
        optimizer,
        name='model',
        postfix='',
        device='cpu'):
    """Restore ``model`` (and ``optimizer``) from the newest checkpoint in ``expdir``.

    Checkpoint files are named ``<name><postfix><step>.pt`` directly inside
    ``expdir``. An empty ``postfix`` becomes ``'_'``, so the default prefix is
    ``model_``. A non-empty ``postfix`` is used verbatim. The largest numeric
    ``<step>`` wins. If there is none, ``<name><postfix>best.pt`` is used when it
    exists. If nothing matches, the inputs are returned untouched with step 0.

    Args:
        expdir: Experiment directory to search.
        model: Object with ``load_state_dict`` (loaded with ``strict=False``).
        optimizer: Object with ``load_state_dict``, or None to skip optimizer state.
        name: Checkpoint name prefix.
        postfix: Extra prefix part (see above).
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        ``(global_step, model, optimizer)``, where ``global_step`` is read from the
        checkpoint's ``'global_step'`` entry, or 0 when nothing was restored.
    """
    # Default naming: an empty postfix means files look like "model_<step>.pt".
    if postfix == '':
        postfix = '_' + postfix
    prefix = name + postfix
    path = os.path.join(expdir, prefix)

    # Collect numeric step suffixes of "<prefix><digits>.pt" files. Keep the
    # original strings so leading zeros survive when the filename is rebuilt.
    # Only top-level files can match, so no recursive walk is needed.
    step_suffixes = []
    if os.path.isdir(expdir):
        with os.scandir(expdir) as entries:
            for entry in entries:
                fname = entry.name
                if not (entry.is_file() and fname.startswith(prefix) and fname.endswith('.pt')):
                    continue
                suffix = fname[len(prefix):-len('.pt')]
                if suffix.isdecimal():
                    step_suffixes.append(suffix)

    global_step = 0
    if step_suffixes:
        path_pt = path + max(step_suffixes, key=int) + '.pt'
    elif os.path.isfile(path + 'best.pt'):
        path_pt = path + 'best.pt'
    else:
        # Nothing to restore: start from scratch.
        return global_step, model, optimizer

    print(' [*] restoring model from', path_pt)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(path_pt, map_location=torch.device(device))
    global_step = ckpt['global_step']
    model.load_state_dict(ckpt['model'], strict=False)
    if optimizer is not None and ckpt.get('optimizer') is not None:
        optimizer.load_state_dict(ckpt['optimizer'])
    return global_step, model, optimizer
|
|