import os
import sys
import yaml
import json
import torch
import random
import warnings
import importlib
import numpy as np

def load_yaml_config(path):
    with open(path) as f:
        config = yaml.full_load(f)
    return config

def save_config_to_yaml(config, path):
    assert path.endswith(".yaml")
    with open(path, "w") as f:
        f.write(yaml.dump(config))

def save_dict_to_json(d, path, indent=None):
    with open(path, "w") as f:
        json.dump(d, f, indent=indent)


def load_dict_from_json(path):
    with open(path, "r") as f:
        return json.load(f)

def write_args(args, path):
    args_dict = dict(
        (name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
    )
    with open(path, "a") as args_file:
        args_file.write("==> torch version: {}\n".format(torch.__version__))
        args_file.write(
            "==> cudnn version: {}\n".format(torch.backends.cudnn.version())
        )
        args_file.write("==> Cmd:\n")
        args_file.write(str(sys.argv))
        args_file.write("\n==> args:\n")
        for k, v in sorted(args_dict.items()):
            args_file.write("  %s: %s\n" % (str(k), str(v)))

def seed_everything(seed, cudnn_deterministic=False):
    """
    Set the seed for the pseudo-random number generators in
    pytorch, numpy and python.random.

    Args:
        seed: integer seed for the global random state
        cudnn_deterministic: if True, force cuDNN into deterministic mode
            (reproducible, but potentially much slower)
    """
    if seed is not None:
        print(f"Global seed set to {seed}")
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = False

    if cudnn_deterministic:
        torch.backends.cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed training. "
            "This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! "
            "You may see unexpected behavior when restarting "
            "from checkpoints."
        )
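
# Illustrative usage (the seed value is arbitrary):
#   seed_everything(42)                            # reproducible RNG state
#   seed_everything(42, cudnn_deterministic=True)  # also pins cuDNN kernels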

def merge_opts_to_config(config, opts):
    def modify_dict(c, nl, v):
        if len(nl) == 1:
            # cast the new value to the type of the existing value
            c[nl[0]] = type(c[nl[0]])(v)
        else:
            c[nl[0]] = modify_dict(c[nl[0]], nl[1:], v)
        return c

    if opts is not None and len(opts) > 0:
        assert (
            len(opts) % 2 == 0
        ), "opts must be given as name/value pairs, so its length must be even!"
        for i in range(len(opts) // 2):
            name = opts[2 * i]
            value = opts[2 * i + 1]
            config = modify_dict(config, name.split("."), value)
    return config
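
# Illustrative usage (hypothetical keys): opts is a flat list of
# dotted-key / value pairs, as typically collected from the command line:
#   config = merge_opts_to_config(config, ["dataloader.batch_size", "8"])
# The string "8" is cast to the type of the existing value before assignment.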

def modify_config_for_debug(config):
    config["dataloader"]["num_workers"] = 0
    config["dataloader"]["batch_size"] = 1
    return config

def get_model_parameters_info(model):
    parameters = {"overall": {"trainable": 0, "non_trainable": 0, "total": 0}}
    for child_name, child_module in model.named_children():
        parameters[child_name] = {"trainable": 0, "non_trainable": 0}
        for pn, p in child_module.named_parameters():
            if p.requires_grad:
                parameters[child_name]["trainable"] += p.numel()
            else:
                parameters[child_name]["non_trainable"] += p.numel()
        parameters[child_name]["total"] = (
            parameters[child_name]["trainable"]
            + parameters[child_name]["non_trainable"]
        )
        parameters["overall"]["trainable"] += parameters[child_name]["trainable"]
        parameters["overall"]["non_trainable"] += parameters[child_name]["non_trainable"]
        parameters["overall"]["total"] += parameters[child_name]["total"]

    # format the raw counts into human-readable strings
    def format_number(num):
        K = 2**10
        M = 2**20
        G = 2**30
        if num > G:
            unit = "G"
            num = round(float(num) / G, 2)
        elif num > M:
            unit = "M"
            num = round(float(num) / M, 2)
        elif num > K:
            unit = "K"
            num = round(float(num) / K, 2)
        else:
            unit = ""
        return "{}{}".format(num, unit)

    def format_dict(d):
        for k, v in d.items():
            if isinstance(v, dict):
                format_dict(v)
            else:
                d[k] = format_number(v)

    format_dict(parameters)
    return parameters
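
# Illustrative return shape (the module names and counts are made up):
#   {"overall": {"trainable": "1.5M", "non_trainable": "0", "total": "1.5M"},
#    "encoder": {"trainable": "1.2M", "non_trainable": "0", "total": "1.2M"},
#    ...}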

def format_seconds(seconds):
    h = int(seconds // 3600)
    m = int(seconds // 60 - h * 60)
    s = int(seconds % 60)
    d = int(h // 24)
    h = h - d * 24

    if d == 0:
        if h == 0:
            if m == 0:
                ft = "{:02d}s".format(s)
            else:
                ft = "{:02d}m:{:02d}s".format(m, s)
        else:
            ft = "{:02d}h:{:02d}m:{:02d}s".format(h, m, s)
    else:
        ft = "{:d}d:{:02d}h:{:02d}m:{:02d}s".format(d, h, m, s)
    return ft
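
# Illustrative outputs:
#   format_seconds(42)     -> "42s"
#   format_seconds(3725)   -> "01h:02m:05s"
#   format_seconds(90061)  -> "1d:01h:01m:01s"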

def instantiate_from_config(config):
    if config is None:
        return None
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    module, cls = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module, package=None), cls)
    return cls(**config.get("params", dict()))
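
# Illustrative usage (arbitrary target and params):
#   layer = instantiate_from_config(
#       {"target": "torch.nn.Linear",
#        "params": {"in_features": 4, "out_features": 2}}
#   )  # equivalent to torch.nn.Linear(in_features=4, out_features=2)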

def class_from_string(class_name):
    module, cls = class_name.rsplit(".", 1)
    cls = getattr(importlib.import_module(module, package=None), cls)
    return cls

def get_all_file(dir, end_with=".h5"):
    if isinstance(end_with, str):
        end_with = [end_with]
    filenames = []
    for root, dirs, files in os.walk(dir):
        for f in files:
            for ew in end_with:
                if f.endswith(ew):
                    filenames.append(os.path.join(root, f))
                    break
    return filenames

def get_sub_dirs(dir, abs=True):
    sub_dirs = os.listdir(dir)
    if abs:
        sub_dirs = [os.path.join(dir, s) for s in sub_dirs]
    return sub_dirs

def get_model_buffer(model):
    # buffers are the state_dict entries that are not parameters
    # (e.g. BatchNorm running statistics)
    state_dict = model.state_dict()
    buffers_ = {}
    params_ = {n: p for n, p in model.named_parameters()}
    for k in state_dict:
        if k not in params_:
            buffers_[k] = state_dict[k]
    return buffers_
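
if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the model and values below are
    # arbitrary stand-ins, not part of any fixed API).
    import torch.nn as nn

    seed_everything(42)
    model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
    print(get_model_parameters_info(model))  # per-child parameter counts
    print(get_model_buffer(model).keys())    # e.g. BatchNorm running stats
    print(format_seconds(3725))              # -> 01h:02m:05s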