import copy

import torch

from ...log_service import print_log
from .utils import \
    get_total_param, get_total_param_sum, \
    get_unit

# def load_state_dict(net, model_path):
#     if isinstance(net, dict):
#         for ni, neti in net.items():
#             paras = torch.load(model_path[ni], map_location=torch.device('cpu'))
#             new_paras = neti.state_dict()
#             new_paras.update(paras)
#             neti.load_state_dict(new_paras)
#     else:
#         paras = torch.load(model_path, map_location=torch.device('cpu'))
#         new_paras = net.state_dict()
#         new_paras.update(paras)
#         net.load_state_dict(new_paras)
#     return

# def save_state_dict(net, path):
#     if isinstance(net, (torch.nn.DataParallel,
#                         torch.nn.parallel.DistributedDataParallel)):
#         torch.save(net.module.state_dict(), path)
#     else:
#         torch.save(net.state_dict(), path)

def singleton(class_):
    # Construct `class_` at most once; later calls return the cached instance.
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance
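
# Behavior sketch (`Counter` is a hypothetical example class):
#   @singleton
#   class Counter(object): ...
#   assert Counter() is Counter()  # both calls yield the same cached instance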

def preprocess_model_args(args):
    # If args has layer_units, build the corresponding unit objects.
    # If args has a backbone config, build the backbone model.
    args = copy.deepcopy(args)
    if 'layer_units' in args:
        layer_units = [
            get_unit()(i) for i in args.layer_units
        ]
        args.layer_units = layer_units
    if 'backbone' in args:
        args.backbone = get_model()(args.backbone)
    return args
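
# Sketch: with cfg.args = {'layer_units': [u0, u1], 'backbone': bcfg}
# (hypothetical sub-configs), each ui is replaced by get_unit()(ui) and
# bcfg by get_model()(bcfg) before the model constructor receives them.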

@singleton
class get_model(object):
    """
    Global model registry mapping a model name to its class and version.
    @singleton ensures every get_model() call returns the same registry.
    """
    def __init__(self):
        self.model = {}
        self.version = {}

    def register(self, model, name, version='x'):
        self.model[name] = model
        self.version[name] = version

    def __call__(self, cfg, verbose=True):
        """
        Construct a model from the config: cfg.type selects the
        registered class and cfg.args is passed to its constructor.
        """
        t = cfg.type

        # Importing a submodule runs its @register calls, which
        # populate the registry with the class named by cfg.type.
        if t.startswith('ldm'):
            from .. import ldm
        elif t == 'autoencoderkl':
            from .. import autoencoder
        elif t.startswith('clip'):
            from .. import clip
        elif t.startswith('sd'):
            from .. import sd
        elif t.startswith('vd'):
            from .. import vd
        elif t.startswith('openai_unet'):
            from .. import openaimodel
        elif t.startswith('optimus'):
            from .. import optimus

        args = preprocess_model_args(cfg.args)
        net = self.model[t](**args)

        if 'ckpt' in cfg:
            checkpoint = torch.load(cfg.ckpt, map_location='cpu')
            strict_sd = cfg.get('strict_sd', True)
            net.load_state_dict(checkpoint['state_dict'], strict=strict_sd)
            if verbose:
                print_log('Load ckpt from {}'.format(cfg.ckpt))
        elif 'pth' in cfg:
            sd = torch.load(cfg.pth, map_location='cpu')
            strict_sd = cfg.get('strict_sd', True)
            net.load_state_dict(sd, strict=strict_sd)
            if verbose:
                print_log('Load pth from {}'.format(cfg.pth))

        # display parameter count & parameter sum
        if verbose:
            print_log(
                'Load {} with total {} parameters, '
                '{:.3f} parameter sum.'.format(
                    t,
                    get_total_param(net),
                    get_total_param_sum(net)))

        return net

    def get_version(self, name):
        return self.version[name]

def register(name, version='x'):
    # Class decorator: add the decorated class to the get_model registry.
    def wrapper(class_):
        get_model().register(class_, name, version)
        return class_
    return wrapper
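
# Minimal usage sketch (hypothetical model 'toy_net'; assumes cfg is an
# attribute-style mapping such as an omegaconf DictConfig):
#
#   @register('toy_net', version='1.0')
#   class ToyNet(torch.nn.Module):
#       def __init__(self, dim):
#           super().__init__()
#           self.fc = torch.nn.Linear(dim, dim)
#
#   cfg = OmegaConf.create({'type': 'toy_net', 'args': {'dim': 8}})
#   net = get_model()(cfg)  # builds ToyNet(dim=8) and logs its param stats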