import torch
import torchvision.models
import os.path as osp
import copy
from ...log_service import print_log
from .utils import \
    get_total_param, get_total_param_sum, \
    get_unit
def singleton(class_):
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance
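# A minimal sketch of how `singleton` behaves (illustrative only, not part of
# this module): decorating a class replaces it with a factory that caches the
# first instance, so every later call returns that same object.
#
#   @singleton
#   class Counter(object):          # hypothetical example class
#       def __init__(self):
#           self.n = 0
#
#   a = Counter()
#   b = Counter()
#   assert a is b                   # both names refer to the cached instance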
def preprocess_model_args(args):
    # Resolve nested configs: instantiate any listed layer units, and build
    # the backbone recursively through the model registry if one is given.
    args = copy.deepcopy(args)
    if 'layer_units' in args:
        layer_units = [
            get_unit()(i) for i in args.layer_units
        ]
        args.layer_units = layer_units
    if 'backbone' in args:
        args.backbone = get_model()(args.backbone)
    return args
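# Illustrative sketch of the fields this helper rewrites (the names
# 'layer_units' and 'backbone' come from the code above; the concrete values
# are hypothetical). Assuming `args` is the attribute-style config used by
# this codebase:
#
#   args.layer_units  # e.g. a list of unit configs -> replaced by get_unit()(u) objects
#   args.backbone     # e.g. a nested model config  -> replaced by get_model()(cfg), built recursively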
@singleton
class get_model(object):
    def __init__(self):
        self.model = {}
        self.version = {}

    def register(self, model, name, version='x'):
        self.model[name] = model
        self.version[name] = version

    def __call__(self, cfg, verbose=True):
        """
        Construct model based on the config.
        """
        t = cfg.type

        # Importing the submodule triggers its @register calls,
        # which populate self.model / self.version before lookup.
        if t.find('ldm')==0:
            from .. import ldm
        elif t=='autoencoderkl':
            from .. import autoencoder
        elif t.find('clip')==0:
            from .. import clip
        elif t.find('sd')==0:
            from .. import sd
        elif t.find('vd')==0:
            from .. import vd
        elif t.find('openai_unet')==0:
            from .. import openaimodel
        elif t.find('optimus')==0:
            from .. import optimus

        args = preprocess_model_args(cfg.args)
        net = self.model[t](**args)

        # Optionally load pretrained weights, either from a checkpoint dict
        # with a 'state_dict' entry or from a bare state_dict file.
        map_location = cfg.get('map_location', 'cpu')
        strict_sd = cfg.get('strict_sd', True)
        if 'ckpt' in cfg:
            checkpoint = torch.load(cfg.ckpt, map_location=map_location)
            net.load_state_dict(checkpoint['state_dict'], strict=strict_sd)
            if verbose:
                print_log('Load ckpt from {}'.format(cfg.ckpt))
        elif 'pth' in cfg:
            sd = torch.load(cfg.pth, map_location=map_location)
            net.load_state_dict(sd, strict=strict_sd)
            if verbose:
                print_log('Load pth from {}'.format(cfg.pth))

        # Report parameter count and parameter sum.
        if verbose:
            print_log(
                'Load {} with total {} parameters, '
                '{:.3f} parameter sum.'.format(
                    t,
                    get_total_param(net),
                    get_total_param_sum(net)))

        return net

    def get_version(self, name):
        return self.version[name]
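# For reference, __call__ above reads the following fields from `cfg` (an
# attribute-style config that also supports `'key' in cfg` and `.get(...)`);
# the layout below is a sketch of the expected fields, not a schema defined here:
#
#   cfg.type          # registry key, i.e. a name passed to @register(...)
#   cfg.args          # constructor kwargs, preprocessed by preprocess_model_args
#   cfg.ckpt          # optional: checkpoint file containing a 'state_dict' entry
#   cfg.pth           # optional: file containing a bare state_dict
#   cfg.map_location  # optional: torch.load map_location, defaults to 'cpu'
#   cfg.strict_sd     # optional: strict flag for load_state_dict, defaults to True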
def register(name, version='x'):
    # Class decorator: register the decorated model class in the get_model
    # singleton under `name`, then return the class unchanged.
    def wrapper(class_):
        get_model().register(class_, name, version)
        return class_
    return wrapper
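# A minimal usage sketch (the class and names below are hypothetical and only
# illustrate the intended flow; real models register themselves in their own
# modules, which is why __call__ imports those modules before lookup):
#
#   @register('toy_net', version='1')
#   class ToyNet(torch.nn.Module):
#       def __init__(self, dim):
#           super().__init__()
#           self.fc = torch.nn.Linear(dim, dim)
#
#   # later, given a config with cfg.type == 'toy_net' and cfg.args == {'dim': 16}:
#   #   net = get_model()(cfg)
#   #   print(get_model().get_version('toy_net'))   # -> '1'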