File size: 4,157 Bytes
6065472
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import importlib
import os
import sys
from typing import Callable, Dict, Union

import numpy as np
import yaml
import torch


def merge_a_into_b(a, b):
    """Recursively merge dict ``a`` into dict ``b`` in place.

    Values from ``a`` overwrite those in ``b``.  When both sides hold a
    dict under the same key, the two are merged key by key instead of
    being replaced wholesale.

    Raises AssertionError when a key maps to a dict in ``a`` but to a
    non-dict in ``b``.
    """
    for key, value in a.items():
        if isinstance(value, dict) and key in b:
            # both sides are (expected to be) dicts: descend and merge
            assert isinstance(b[key], dict), \
                "Cannot inherit key '{}' from base!".format(key)
            merge_a_into_b(value, b[key])
        else:
            b[key] = value


def load_config(config_file):
    """Load a YAML config file, resolving an optional ``inherit_from`` chain.

    When the config contains an ``inherit_from`` key (a path relative to
    the config file's own directory), the base config is loaded recursively
    and this config's values are merged on top of it via ``merge_a_into_b``.
    """
    with open(config_file, "r") as reader:
        config = yaml.load(reader, Loader=yaml.FullLoader)

    if "inherit_from" not in config:
        return config

    # resolve the base config path relative to this config's directory
    base_config_file = os.path.join(
        os.path.dirname(config_file), config["inherit_from"]
    )
    assert not os.path.samefile(config_file, base_config_file), \
        "inherit from itself"

    base_config = load_config(base_config_file)
    del config["inherit_from"]
    # child values win over inherited ones
    merge_a_into_b(config, base_config)
    return base_config

def get_cls_from_str(string, reload=False):
    """Resolve a dotted path like ``"pkg.module.Name"`` to the named object.

    When ``reload`` is true, the containing module is re-imported first so
    on-disk changes are picked up before the attribute lookup.
    """
    module_name, cls_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    module = importlib.import_module(module_name, package=None)
    return getattr(module, cls_name)

def init_obj_from_dict(config, **kwargs):
    """Instantiate the object described by ``config``.

    ``config`` must contain ``type`` (a dotted class path for
    ``get_cls_from_str``) and may contain ``args`` (a dict of constructor
    arguments).  Every other key whose value is a dict is treated as a
    nested object config and instantiated recursively, unless that key is
    overridden through ``kwargs``.  Explicit ``kwargs`` take precedence
    over ``config["args"]``.

    Raises whatever the constructor raises, after printing the offending
    config for easier debugging.
    """
    # tolerate configs that omit "args" entirely (previously a KeyError);
    # copy so the caller's config dict is never mutated
    obj_args = dict(config.get("args", {}))
    obj_args.update(kwargs)
    for k in config:
        if k not in ["type", "args"] and isinstance(config[k], dict) and k not in kwargs:
            obj_args[k] = init_obj_from_dict(config[k])
    try:
        return get_cls_from_str(config["type"])(**obj_args)
    except Exception:
        print(f"Initializing {config} failed, detailed error stack: ")
        # bare raise preserves the original traceback cleanly
        raise

def init_model_from_config(config, print_fn=sys.stdout.write):
    """Build a model from ``config``, constructing sub-models recursively.

    Every key other than ``type``/``args``/``pretrained`` is taken to be a
    sub-model config.  Each sub-model is built first and, when its config
    carries a ``pretrained`` entry, its weights are loaded via
    ``load_pretrained_model`` before being handed to the parent constructor.
    """
    sub_models = {}
    for name in config:
        if name in ("type", "args", "pretrained"):
            continue
        sub_model = init_model_from_config(config[name], print_fn)
        if "pretrained" in config[name]:
            load_pretrained_model(sub_model, config[name]["pretrained"], print_fn)
        sub_models[name] = sub_model
    return init_obj_from_dict(config, **sub_models)

def merge_load_state_dict(state_dict,
                          model: torch.nn.Module,
                          output_fn: Callable = sys.stdout.write):
    """Load ``state_dict`` into ``model``, skipping incompatible entries.

    Only entries whose key exists in the model's own state dict AND whose
    tensor shape matches are loaded; every other key is reported through
    ``output_fn``.  Returns the keys that were actually loaded.
    """
    model_state = model.state_dict()
    compatible = {
        name: tensor
        for name, tensor in state_dict.items()
        if name in model_state and model_state[name].shape == tensor.shape
    }
    mismatch_keys = [name for name in state_dict if name not in compatible]
    output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}")
    # strict load is safe here: model_state already has every expected key
    model_state.update(compatible)
    model.load_state_dict(model_state, strict=True)
    return compatible.keys()


def load_pretrained_model(model: torch.nn.Module,
                          pretrained: Union[str, Dict],
                          output_fn: Callable = sys.stdout.write):
    """Load pre-trained weights into ``model``.

    ``pretrained`` is either a state dict or a checkpoint file path.  A
    non-existent path is reported via ``output_fn`` and skipped.  Models
    that implement ``load_pretrained`` handle the loading themselves;
    otherwise the weights are merged in via ``merge_load_state_dict``.
    """
    # best-effort: a missing checkpoint path is reported, not fatal
    if not isinstance(pretrained, dict) and not os.path.exists(pretrained):
        output_fn(f"pretrained {pretrained} not exist!")
        return

    # the model may know how to load its own weights
    if hasattr(model, "load_pretrained"):
        model.load_pretrained(pretrained, output_fn)
        return

    if isinstance(pretrained, dict):
        state_dict = pretrained
    else:
        state_dict = torch.load(pretrained, map_location="cpu")

    # some checkpoints wrap the weights under a "model" key
    if "model" in state_dict:
        state_dict = state_dict["model"]

    merge_load_state_dict(state_dict, model, output_fn)

def pad_sequence(data, pad_value=0):
    """Right-pad a batch of variable-length sequences to a common length.

    Parameters
    ----------
    data: sequence of array-likes (np.ndarray, torch.Tensor, or plain
        (nested) Python lists) whose trailing dimensions agree.
    pad_value: scalar used to fill padded positions (default 0).

    Returns
    -------
    tuple of (padded, length): ``padded`` is a ``(batch, max_len, ...)``
    tensor, ``length`` a numpy array of each sequence's original length.
    """
    # Convert every element up-front: torch.as_tensor accepts ndarrays,
    # tensors and plain lists alike.  The original code only converted when
    # data[0] was an ndarray/Tensor, so a batch of Python lists crashed
    # inside torch's pad_sequence; this generalizes without changing
    # behavior for the previously supported inputs.
    tensors = [torch.as_tensor(seq) for seq in data]
    padded_seq = torch.nn.utils.rnn.pad_sequence(tensors,
                                                 batch_first=True,
                                                 padding_value=pad_value)
    length = np.array([t.shape[0] for t in tensors])
    return padded_seq, length