# TSEditor / utils / io_utils.py
# Last commit by PeterYu: "update" (2875fe6)
import os
import sys
import yaml
import json
import torch
import random
import warnings
import importlib
import numpy as np
def load_yaml_config(path):
    """Parse the YAML file at ``path`` and return the resulting object.

    Parsing is done with ``yaml.full_load``. NOTE(review): FullLoader resolves
    more YAML tags than ``safe_load``; only point this at trusted config files.

    Args:
        path: Filesystem path of the YAML document.

    Returns:
        The deserialized document (presumably a dict for config files).
    """
    with open(path) as handle:
        return yaml.full_load(handle)
def save_config_to_yaml(config, path):
    """Serialize ``config`` as YAML into ``path``, overwriting the file.

    Args:
        config: A YAML-serializable object (typically a nested dict).
        path: Destination path; must end with ``.yaml``.

    Raises:
        AssertionError: If ``path`` does not end with ``.yaml`` (kept as an
            assert so existing callers see the same exception type).
    """
    assert path.endswith(".yaml"), f"config path must end with '.yaml', got {path!r}"
    # The context manager closes the file; dumping straight to the stream
    # avoids building the whole document as an intermediate string.
    with open(path, "w") as f:
        yaml.dump(config, f)
def save_dict_to_json(d, path, indent=None):
    """Write ``d`` as JSON to ``path``, overwriting any existing file.

    Args:
        d: A JSON-serializable object.
        path: Destination file path.
        indent: Forwarded to ``json.dump``; ``None`` produces compact output.
    """
    # A context manager guarantees the data is flushed and the handle closed,
    # even if serialization raises part-way through.
    with open(path, "w") as f:
        json.dump(d, f, indent=indent)
def load_dict_from_json(path):
    """Read and return the JSON document stored at ``path``.

    Args:
        path: Path of the JSON file.

    Returns:
        The decoded JSON value (a dict when the file holds a JSON object).
    """
    # Close the handle deterministically instead of leaving it to the GC.
    with open(path, "r") as f:
        return json.load(f)
def write_args(args, path):
    """Append an environment and argument summary to the text file ``path``.

    Writes the torch and cuDNN versions, the raw ``sys.argv`` list, and every
    public attribute of ``args`` (names not starting with ``_``) as
    ``name: value`` lines sorted by name.

    Args:
        args: Object whose public attributes are recorded (presumably an
            ``argparse.Namespace``; any object works since ``dir`` is used).
        path: Log file path; opened in append mode, so earlier runs are kept.
    """
    # dir() lists every public attribute, including any public methods.
    args_dict = {
        name: getattr(args, name) for name in dir(args) if not name.startswith("_")
    }
    with open(path, "a") as args_file:
        args_file.write("==> torch version: {}\n".format(torch.__version__))
        args_file.write(
            "==> cudnn version: {}\n".format(torch.backends.cudnn.version())
        )
        args_file.write("==> Cmd:\n")
        # No trailing newline after argv; the next write begins with "\n".
        args_file.write(str(sys.argv))
        args_file.write("\n==> args:\n")
        for k, v in sorted(args_dict.items()):
            args_file.write(" %s: %s\n" % (str(k), str(v)))
    # Leaving the with-block closes the file; no explicit close() is needed.
def seed_everything(seed, cudnn_deterministic=False):
    """Seed Python ``random``, NumPy and PyTorch (CPU plus all CUDA devices).

    Args:
        seed: Integer seed for the global RNG state. When ``None``, no RNG is
            seeded, but the cuDNN flag below is still applied.
        cudnn_deterministic: When truthy, turn on
            ``torch.backends.cudnn.deterministic`` and emit a warning about the
            speed cost; otherwise the flag is reset to ``False``.
    """
    if seed is not None:
        print(f"Global seed set to {seed}")
        seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
        for apply_seed in seeders:
            apply_seed(seed)

    torch.backends.cudnn.deterministic = bool(cudnn_deterministic)
    if not cudnn_deterministic:
        return
    warnings.warn(
        "You have chosen to seed training. "
        "This will turn on the CUDNN deterministic setting, "
        "which can slow down your training considerably! "
        "You may see unexpected behavior when restarting "
        "from checkpoints."
    )
def merge_opts_to_config(config, opts):
    """Override nested ``config`` entries from a flat ``[key, value, ...]`` list.

    Each key is a dot-separated path into nested dicts (``"a.b.c"``). The new
    value is converted to the type of the value it replaces, so a command-line
    string such as ``"0.5"`` becomes a float when the existing entry is a float.
    Boolean entries accept ``true/false``, ``1/0``, ``yes/no`` and ``on/off``
    (case-insensitive); entries that are currently ``None`` take the override
    unchanged.

    ``config`` is modified in place and also returned.

    Args:
        config: Nested dict to update.
        opts: ``None``, empty, or a flat sequence alternating key paths and values.

    Returns:
        The (mutated) ``config``.

    Raises:
        AssertionError: If ``opts`` has odd length.
        KeyError: If a key path does not exist in ``config``.
        ValueError: If a value cannot be converted to the existing entry's type,
            including unrecognised boolean strings.
    """
    true_strings = {"true", "1", "yes", "on"}
    false_strings = {"false", "0", "no", "off", ""}

    def convert(old, new):
        """Cast ``new`` to the type of the value ``old`` it replaces."""
        if isinstance(old, bool):
            # bool("False") is True, so string overrides of booleans need parsing.
            if isinstance(new, str):
                text = new.strip().lower()
                if text in true_strings:
                    return True
                if text in false_strings:
                    return False
                raise ValueError(f"Cannot interpret {new!r} as a boolean.")
            return bool(new)
        if old is None:
            # NoneType() accepts no arguments; keep the override as given.
            return new
        return type(old)(new)

    def modify_dict(c, nl, v):
        """Set the entry at key path ``nl`` inside ``c`` to ``v`` and return ``c``."""
        if len(nl) == 1:
            c[nl[0]] = convert(c[nl[0]], v)
        else:
            c[nl[0]] = modify_dict(c[nl[0]], nl[1:], v)
        return c

    if opts is not None and len(opts) > 0:
        assert (
            len(opts) % 2 == 0
        ), "each opts should be given by the name and values! The length shall be even number!"
        for name, value in zip(opts[0::2], opts[1::2]):
            config = modify_dict(config, name.split("."), value)
    return config
def modify_config_for_debug(config):
    """Shrink the dataloader settings of ``config`` for quick debug runs.

    Sets ``num_workers`` to 0 and ``batch_size`` to 1 under
    ``config["dataloader"]``. The mapping is updated in place and also returned.
    """
    loader_cfg = config["dataloader"]
    for key, value in (("num_workers", 0), ("batch_size", 1)):
        loader_cfg[key] = value
    return config
def _count_parameters(params):
    """Return trainable / non-trainable / total element counts for ``params``."""
    trainable = 0
    non_trainable = 0
    for p in params:
        if p.requires_grad:
            trainable += p.numel()
        else:
            non_trainable += p.numel()
    return {
        "trainable": trainable,
        "non_trainable": non_trainable,
        "total": trainable + non_trainable,
    }


def _format_param_count(num):
    """Format ``num`` with a binary G/M/K suffix, rounded to two decimals."""
    # Binary prefixes: K = 2**10, M = 2**20, G = 2**30. Values at or below
    # 1024 are printed unchanged, without a suffix.
    for unit, scale in (("G", 2**30), ("M", 2**20), ("K", 2**10)):
        if num > scale:
            return "{}{}".format(round(float(num) / scale, 2), unit)
    return "{}".format(num)


def get_model_parameters_info(model):
    """Summarize trainable and frozen parameter counts per top-level child module.

    The returned dict has an ``"overall"`` entry first, followed by one entry
    per direct child of ``model`` (from ``named_children()``). Each entry maps
    ``"trainable"``, ``"non_trainable"`` and ``"total"`` to a human-readable
    string such as ``"110"`` or ``"9.86K"``.

    ``"overall"`` is counted from ``model.parameters()``. It therefore includes
    parameters registered directly on ``model`` rather than on a child, and it
    counts a parameter shared between children only once.

    Args:
        model: A ``torch.nn.Module``.

    Returns:
        Nested dict of formatted counts.
    """
    counts = {"overall": _count_parameters(model.parameters())}
    for child_name, child_module in model.named_children():
        counts[child_name] = _count_parameters(child_module.parameters())
    return {
        name: {key: _format_param_count(value) for key, value in entry.items()}
        for name, entry in counts.items()
    }
def format_seconds(seconds):
    """Render a duration in seconds as a compact, zero-padded string.

    Leading zero-valued units are dropped. Examples: ``"05s"``, ``"01m:05s"``,
    ``"01h:01m:01s"``, or ``"1d:01h:01m:01s"`` once the day count is non-zero.
    Fractional seconds are discarded by the ``int`` conversions.
    """
    hours = int(seconds // 3600)
    minutes = int(seconds // 60 - hours * 60)
    secs = int(seconds % 60)
    days = int(hours // 24)
    hours -= days * 24

    if days != 0:
        return "{:d}d:{:02d}h:{:02d}m:{:02d}s".format(days, hours, minutes, secs)
    if hours != 0:
        return "{:02d}h:{:02d}m:{:02d}s".format(hours, minutes, secs)
    if minutes != 0:
        return "{:02d}m:{:02d}s".format(minutes, secs)
    return "{:02d}s".format(secs)
def instantiate_from_config(config):
    """Build the object described by a ``{"target": ..., "params": ...}`` mapping.

    ``target`` is a dotted path ``"package.module.Name"``. The named attribute
    is imported and called with ``params`` as keyword arguments, or with no
    arguments when ``params`` is absent. NOTE(review): this imports and calls
    arbitrary code named in the config; use it only with trusted configs.

    Returns:
        ``None`` if ``config`` is ``None``, otherwise the constructed object.

    Raises:
        KeyError: If ``config`` has no ``"target"`` key.
    """
    if config is None:
        return None
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    module_path, attr_name = config["target"].rsplit(".", 1)
    factory = getattr(importlib.import_module(module_path, package=None), attr_name)
    kwargs = config.get("params", dict())
    return factory(**kwargs)
def class_from_string(class_name):
    """Import and return the attribute named by a dotted path like ``"pkg.mod.Name"``.

    Raises:
        ValueError: If ``class_name`` contains no ``"."`` (nothing to unpack).
    """
    module_path, attr_name = class_name.rsplit(".", 1)
    module = importlib.import_module(module_path, package=None)
    return getattr(module, attr_name)
def get_all_file(dir, end_with=".h5"):
    """Recursively collect paths of files under ``dir`` whose names end with a suffix.

    Args:
        dir: Root directory to walk; results follow ``os.walk`` order.
        end_with: A suffix string, or an iterable of suffixes. A file is kept
            if its name ends with any of them. Defaults to ``".h5"``.

    Returns:
        List of ``os.path.join(root, name)`` paths; each file appears at most once.
    """
    suffixes = [end_with] if isinstance(end_with, str) else end_with
    return [
        os.path.join(root, name)
        for root, _subdirs, names in os.walk(dir)
        for name in names
        if any(name.endswith(suffix) for suffix in suffixes)
    ]
def get_sub_dirs(dir, abs=True):
    """List the entries directly inside ``dir``.

    Despite the name, every entry that ``os.listdir`` returns is included
    (files as well as directories), in ``os.listdir`` order.

    Args:
        dir: Directory to list.
        abs: If true, join each entry onto ``dir`` with ``os.path.join``;
            otherwise return the bare entry names.
    """
    entries = os.listdir(dir)
    if not abs:
        return entries
    return [os.path.join(dir, entry) for entry in entries]
def get_model_buffer(model):
    """Return the ``state_dict`` entries of ``model`` whose keys are not parameter names.

    Keys and values come from ``model.state_dict()`` and keep its order.
    NOTE(review): if weights are tied or shared, ``named_parameters()`` may
    report only one of the aliases, so the other alias could land here.
    Verify this for tied-weight models.
    """
    state = model.state_dict()
    param_names = {name for name, _ in model.named_parameters()}
    return {key: value for key, value in state.items() if key not in param_names}