Spaces:
Sleeping
Sleeping
File size: 4,157 Bytes
6065472 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import importlib
import os
import sys
from typing import Callable, Dict, Union
import numpy as np
import yaml
import torch
def merge_a_into_b(a, b):
    """Recursively merge dict ``a`` into dict ``b`` in place.

    Values from ``a`` overwrite those in ``b``; when both sides hold a
    dict under the same key, the two are merged recursively instead of
    being replaced wholesale.
    """
    for key, value in a.items():
        if not (isinstance(value, dict) and key in b):
            # plain value, or key absent from b: a simply wins
            b[key] = value
            continue
        assert isinstance(b[key], dict), \
            "Cannot inherit key '{}' from base!".format(key)
        merge_a_into_b(value, b[key])
def load_config(config_file):
    """Load a YAML config file, resolving an optional ``inherit_from`` key.

    When the config contains ``inherit_from`` (a path relative to the
    config file's own directory), the base config is loaded recursively
    and this config's values are merged on top of it.
    """
    with open(config_file, "r") as reader:
        config = yaml.load(reader, Loader=yaml.FullLoader)
    if "inherit_from" not in config:
        return config
    base_config_file = os.path.join(
        os.path.dirname(config_file), config["inherit_from"]
    )
    # guard against a config directly inheriting from itself
    assert not os.path.samefile(config_file, base_config_file), \
        "inherit from itself"
    base_config = load_config(base_config_file)
    del config["inherit_from"]
    # child values (config) take precedence over the inherited base
    merge_a_into_b(config, base_config)
    return base_config
def get_cls_from_str(string, reload=False):
    """Resolve a dotted ``"module.ClassName"`` string to the class object.

    If ``reload`` is true, the module is re-imported first so that code
    changed on disk since the initial import is picked up.
    """
    module_name, cls_name = string.rsplit(".", 1)
    module = importlib.import_module(module_name, package=None)
    if reload:
        module = importlib.reload(module)
    return getattr(module, cls_name)
def init_obj_from_dict(config, **kwargs):
    """Instantiate the object described by ``config``.

    ``config`` holds a dotted class path under ``"type"`` and constructor
    arguments under ``"args"``.  Any other dict-valued key is recursively
    instantiated and passed as an extra keyword argument, unless the
    caller already supplied that keyword via ``kwargs`` (caller wins).

    Args:
        config: dict with "type", optional "args", and nested sub-configs.
        **kwargs: extra constructor arguments; they override ``config["args"]``.

    Returns:
        The constructed object.

    Raises:
        Whatever the target constructor raises; the failing config is
        printed first to aid debugging.
    """
    # "args" may legitimately be absent; treat that as "no arguments".
    # copy so that neither config nor a shared default gets mutated.
    obj_args = dict(config.get("args", {}))
    obj_args.update(kwargs)
    for k, v in config.items():
        # Dict-valued keys other than the reserved ones are nested
        # sub-object configs; explicit kwargs take precedence over them.
        if k in ("type", "args") or k in kwargs:
            continue
        if isinstance(v, dict):
            obj_args[k] = init_obj_from_dict(v)
    try:
        return get_cls_from_str(config["type"])(**obj_args)
    except Exception:
        print(f"Initializing {config} failed, detailed error stack: ")
        # bare raise preserves the original traceback unchanged
        raise
def init_model_from_config(config, print_fn=sys.stdout.write):
    """Build a model from a config dict, recursively creating sub-models.

    Every key other than "type", "args" and "pretrained" is treated as a
    sub-model config: the sub-model is built first (loading its
    "pretrained" weights when given) and then passed to the parent
    constructor as a keyword argument.
    """
    sub_models = {}
    for name in config:
        if name in ("type", "args", "pretrained"):
            continue
        sub_cfg = config[name]
        sub_model = init_model_from_config(sub_cfg, print_fn)
        if "pretrained" in sub_cfg:
            load_pretrained_model(sub_model, sub_cfg["pretrained"], print_fn)
        sub_models[name] = sub_model
    return init_obj_from_dict(config, **sub_models)
def merge_load_state_dict(state_dict,
                          model: torch.nn.Module,
                          output_fn: Callable = sys.stdout.write):
    """Load ``state_dict`` into ``model``, keeping only compatible entries.

    An entry is taken from ``state_dict`` only when the key exists in the
    model and the tensor shapes agree; all other keys are reported through
    ``output_fn`` and the model keeps its own values for them.

    Returns:
        The keys that were actually loaded from ``state_dict``.
    """
    model_dict = model.state_dict()
    pretrained_dict = {
        key: value
        for key, value in state_dict.items()
        if key in model_dict and model_dict[key].shape == value.shape
    }
    mismatch_keys = [key for key in state_dict if key not in pretrained_dict]
    output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}")
    # overlay the compatible subset onto the model's full state and load
    # strictly, so every model key is still covered.
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict, strict=True)
    return pretrained_dict.keys()
def load_pretrained_model(model: torch.nn.Module,
                          pretrained: Union[str, Dict],
                          output_fn: Callable = sys.stdout.write):
    """Load pre-trained weights into ``model``.

    ``pretrained`` is either an already-loaded state dict or a checkpoint
    file path.  A missing path is reported via ``output_fn`` and ignored
    (best-effort).  If the model defines its own ``load_pretrained``
    method, loading is delegated to it entirely.
    """
    # best-effort: a non-existent checkpoint path is reported, not fatal
    if not isinstance(pretrained, dict) and not os.path.exists(pretrained):
        output_fn(f"pretrained {pretrained} not exist!")
        return
    # a model may implement custom loading; that takes full precedence
    if hasattr(model, "load_pretrained"):
        model.load_pretrained(pretrained, output_fn)
        return
    if isinstance(pretrained, dict):
        state_dict = pretrained
    else:
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (consider weights_only=True).
        state_dict = torch.load(pretrained, map_location="cpu")
    # unwrap checkpoints saved as {"model": state_dict, ...}
    if "model" in state_dict:
        state_dict = state_dict["model"]
    merge_load_state_dict(state_dict, model, output_fn)
def pad_sequence(data, pad_value=0):
    """Pad a batch of variable-length sequences to a common length.

    Args:
        data: sequence of array-likes (torch.Tensor, np.ndarray, or plain
            Python lists), each of shape (length, ...).
        pad_value: scalar used to fill padded positions.

    Returns:
        padded_seq: tensor of shape (batch, max_length, ...).
        length: np.ndarray with each sequence's original length.
    """
    # Convert every element unconditionally so plain Python lists work too
    # (rnn.pad_sequence requires tensors); for inputs that are already
    # tensors, torch.as_tensor is a no-copy pass-through.
    tensors = [torch.as_tensor(x) for x in data]
    padded_seq = torch.nn.utils.rnn.pad_sequence(tensors,
                                                 batch_first=True,
                                                 padding_value=pad_value)
    length = np.array([x.shape[0] for x in tensors])
    return padded_seq, length