File size: 3,383 Bytes
d6ee7b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
from collections import OrderedDict, abc as container_abc

import yaml

from utils.my_containers import ObjDict
from utils.utils import ordered_load


def load_yaml(config_path):
    """Read a UTF-8 YAML config file and return its contents as an ``ObjDict``.

    Parsing is delegated to ``ordered_load`` with ``yaml.SafeLoader`` as the
    loader class.

    Args:
        config_path: path of the YAML file to read.

    Returns:
        ``ObjDict`` built from the parsed document.
    """
    with open(config_path, encoding='utf-8') as stream:
        return ObjDict(ordered_load(stream, yaml.SafeLoader))


class ConfigConstructor(object):
    """Build a config from a YAML file, resolving ``_Base_Config`` inheritance.

    When a config is layered on top of its base(s), ``cfg_update`` applies
    these rules:

    * a key missing from the base is added;
    * a mapping is merged recursively into the base mapping, unless it has a
      ``name`` that differs from the base mapping's ``name``, in which case it
      replaces the base mapping; a ``name`` ending in ``*`` is stripped of the
      star and the mapping is merged instead of replacing;
    * a mapping value over a non-mapping base value replaces it;
    * an empty mapping clears the base mapping it is merged into;
    * ``key*`` whose ``key`` exists in the base is combined with the base value
      through ``update_by_type`` (lists are extended);
    * any other value overwrites the base value.
    """

    def __init__(self, config_path):
        # Loader functions keyed by the file suffix they handle.
        self.suffix2loadMethods = {'.yaml': load_yaml, '.yml': load_yaml}
        # Not read by any method of this class; kept for external users.
        self.inherit_tree = {}
        # Parsed config of this file; inheritance is resolved by get_config().
        self.config = self.load_config(config_path)

    def load_config(self, config_path):
        """Load ``config_path`` with the loader registered for its suffix.

        Args:
            config_path: ``str`` or path-like pointing at the config file.

        Returns:
            The loaded config wrapped in ``ObjDict`` and passed through
            ``transform()``.

        Raises:
            NotImplementedError: if the path does not end with any registered
                suffix.
        """
        path = os.fspath(config_path)
        for suffix, loader in self.suffix2loadMethods.items():
            # Match the ending only; a substring test would also accept paths
            # such as 'configs.yaml.d/model.json'.
            if path.endswith(suffix):
                return ObjDict(loader(path)).transform()
        raise NotImplementedError(
            'No loader registered for config file: {}'.format(path))

    def config_inherit(self, base_config_path_list):
        """Resolve and merge one or more base configs, in order.

        Args:
            base_config_path_list: a single path (``str`` or path-like) or a
                list of paths; later entries are merged on top of earlier ones.

        Returns:
            ``ObjDict`` holding the merged, fully resolved base configs.
        """
        if isinstance(base_config_path_list, (str, os.PathLike)):
            base_config_path_list = [base_config_path_list]
        base_config = ObjDict()
        for base_cfg_path in base_config_path_list:
            # get_config() resolves the base's own _Base_Config recursively.
            self.cfg_update(base_config, ConfigConstructor(base_cfg_path).get_config())
        return base_config

    def construct_config(self, config_dict, kwargs=None):
        """Resolve inheritance in ``config_dict`` and apply ``kwargs`` overrides.

        Nested mappings are resolved first (each may declare its own
        ``_Base_Config``), then ``config_dict`` is merged on top of its
        ``_Base_Config`` (if any), and finally ``kwargs`` on top of that.
        Nested mappings of ``config_dict`` are replaced in place by their
        resolved versions.

        Args:
            config_dict: mapping to resolve.
            kwargs: optional mapping of overrides merged last; ``None`` or an
                empty mapping means "no overrides".

        Returns:
            ``ObjDict`` with the merged config, without the ``_Base_Config`` key.
        """
        base_config = ObjDict()
        for key, value in config_dict.items():
            if key == '_Base_Config':
                base_config = self.config_inherit(value)
            elif isinstance(value, container_abc.Mapping):
                # Re-assigning an existing key keeps the dict size unchanged,
                # so this is safe during iteration.
                config_dict[key] = self.construct_config(value)
        self.cfg_update(base_config, config_dict)
        # An empty mapping would act as "clear everything" in cfg_update, so
        # only merge overrides that actually contain something.
        if kwargs:
            self.cfg_update(base_config, kwargs)
        # The inheritance directive itself is not part of the resulting config.
        if '_Base_Config' in base_config:
            base_config.pop('_Base_Config')
        return base_config

    def get_config(self, kwargs=None):
        """Return this file's fully resolved config, with optional overrides.

        Args:
            kwargs: optional mapping of overrides merged on top of the config.
        """
        return self.construct_config(self.config, kwargs)

    def update_by_type(self, base_value, new_value):
        """Combine two values of the same type in place and return the result.

        Mappings are updated with ``new_value``; lists are extended with it.

        Raises:
            AssertionError: (when assertions are enabled) if the two values
                are of different types.
            NotImplementedError: if the values are neither mappings nor lists.
        """
        assert type(base_value) is type(new_value), (
            'cannot combine {} with {}'.format(
                type(base_value).__name__, type(new_value).__name__))
        if isinstance(new_value, container_abc.Mapping):
            base_value.update(new_value)
            return base_value
        if isinstance(base_value, list):
            base_value.extend(new_value)
            return base_value
        raise NotImplementedError(
            'cannot combine values of type {}'.format(type(base_value).__name__))

    def cfg_update(self, base_cfg, new_cfg):
        """Merge ``new_cfg`` into ``base_cfg`` in place; ``new_cfg`` is not modified.

        See the class docstring for the merge rules.

        Args:
            base_cfg: mutable mapping that receives the merged values.
            new_cfg: mapping whose entries are layered on top of ``base_cfg``.
        """
        if not new_cfg:
            # An empty mapping resets whatever it is merged into.
            base_cfg.clear()
            return

        # Resolve 'key*' entries into a separate dict instead of renaming keys
        # of new_cfg while iterating over it (that mutated the caller's config
        # and could raise RuntimeError or skip entries).
        append_keys = set()
        resolved = {}
        for key, value in new_cfg.items():
            if isinstance(key, str) and key.endswith('*') and key[:-1] in base_cfg:
                ori_key = key[:-1]
                append_keys.add(ori_key)
                resolved[ori_key] = value
            elif key not in append_keys:
                # A starred entry for the same key takes precedence. A 'key*'
                # whose 'key' is absent from base_cfg is kept verbatim, so a
                # later merge (construct_config merges nested mappings into an
                # empty base first) can still apply it.
                resolved[key] = value

        for key, value in resolved.items():
            if key not in base_cfg:
                base_cfg[key] = value
                continue
            base_value = base_cfg[key]
            if isinstance(value, container_abc.Mapping):
                if not isinstance(base_value, container_abc.Mapping):
                    # Nothing to merge into: the mapping replaces the old value.
                    base_cfg[key] = value
                    continue
                if 'name' in value and value['name'] != base_value.get('name', None):
                    new_name = value['name']
                    if not (isinstance(new_name, str) and new_name.endswith('*')):
                        # Different 'name': replace the base mapping wholesale.
                        base_cfg[key] = value
                        continue
                    # 'name*': take the new name but keep merging the rest;
                    # strip the star on a shallow copy so new_cfg is untouched.
                    value = dict(value)
                    value['name'] = new_name[:-1]
                self.cfg_update(base_value, value)
            elif key in append_keys:
                base_cfg[key] = self.update_by_type(base_value, value)
            else:
                base_cfg[key] = value