|
from email.policy import strict |
|
import torch |
|
import torchvision.models |
|
import os.path as osp |
|
import copy |
|
from .utils import \ |
|
get_total_param, get_total_param_sum, \ |
|
get_unit |
|
|
|
|
|
def singleton(class_):
    """Class decorator that turns ``class_`` into a lazily created singleton.

    The decorated name becomes a factory function. The first call constructs
    the instance with the arguments given; every later call returns that same
    instance and silently ignores any arguments passed to it.

    Args:
        class_: the class to wrap.

    Returns:
        A function that returns the single shared instance of ``class_``.
    """
    # Local import so this block does not depend on the file-level imports.
    import functools

    # One slot per decorated class; the dict lives in this closure.
    instances = {}

    # updated=() copies __name__/__doc__/__module__/__qualname__ onto the
    # factory without merging the class __dict__ into the function.
    @functools.wraps(class_, updated=())
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]

    return getinstance
|
|
|
|
|
def preprocess_model_args(args):
    """Return a deep copy of ``args`` with nested specs replaced by objects.

    The input is never modified. On the copy:

    * if a ``layer_units`` entry exists, each of its items is passed through
      ``get_unit()`` and the resulting list replaces the entry;
    * if a ``backbone`` entry exists, it is passed through ``get_model()``
      and the result replaces the entry.

    Args:
        args: config container supporting ``in`` tests and, when the keys
            above are present, attribute access (presumably an
            EasyDict/OmegaConf-style object -- confirm against callers).

    Returns:
        The processed deep copy.
    """
    model_args = copy.deepcopy(args)

    if 'layer_units' in model_args:
        built_units = []
        for unit_spec in model_args.layer_units:
            built_units.append(get_unit()(unit_spec))
        model_args.layer_units = built_units

    if 'backbone' in model_args:
        backbone_cfg = model_args.backbone
        model_args.backbone = get_model()(backbone_cfg)

    return model_args
|
|
|
@singleton
class get_model(object):
    """Registry and factory for model classes, keyed by type name.

    Wrapped by ``singleton``, so every ``get_model()`` call returns the same
    registry instance.
    """

    def __init__(self):
        # name -> registered model class (or other callable)
        self.model = {}
        # name -> version tag supplied at registration time
        self.version = {}

    def register(self, model, name, version='x'):
        """Store ``model`` under ``name`` along with its ``version`` tag.

        Registering an already-used name overwrites the previous entry.
        """
        self.model[name] = model
        self.version[name] = version

    def __call__(self, cfg, verbose=True):
        """
        Construct model based on the config.

        Args:
            cfg: config object exposing ``type`` (the registered model name)
                and ``args`` (keyword arguments for the model constructor).
            verbose: currently unused; kept so existing callers keep working.

        Returns:
            The object produced by calling the registered model with the
            preprocessed ``cfg.args`` as keyword arguments.

        Raises:
            KeyError: if ``cfg.type`` is not registered once the matching
                module (if any) has been imported.
        """
        t = cfg.type

        # The imported names are never used below: these imports run purely
        # for their side effects, selected by the prefix of the type name.
        # NOTE(review): presumably each module registers its models through
        # the `register` decorator on import -- confirm in those modules.
        # Branch order is preserved from the original chain.
        if t.startswith('audioldm'):
            from ..latent_diffusion.vae import audioldm  # noqa: F401
        elif t.startswith('autoencoderkl'):
            from ..latent_diffusion.vae import autokl  # noqa: F401
        elif t.startswith('optimus'):
            from ..latent_diffusion.vae import optimus  # noqa: F401

        elif t.startswith('clip'):
            from ..encoders import clip  # noqa: F401
        elif t.startswith('clap'):
            from ..encoders import clap  # noqa: F401

        elif t.startswith('sd'):
            from .. import sd  # noqa: F401
        elif t.startswith('codi'):
            from .. import codi  # noqa: F401
        elif t.startswith('thesis_model'):
            from .. import codi_2  # noqa: F401
        elif t.startswith(('openai_unet', 'prova')):
            from ..latent_diffusion import diffusion_unet  # noqa: F401

        # Resolve the class before preprocessing the arguments, so an unknown
        # type fails before any nested backbone is constructed.
        try:
            model_cls = self.model[t]
        except KeyError:
            raise KeyError(
                "model type {!r} is not registered; registered types: {}".format(
                    t, list(self.model)
                )
            ) from None

        args = preprocess_model_args(cfg.args)
        net = model_cls(**args)

        return net

    def get_version(self, name):
        """Return the version tag recorded for ``name``.

        Raises:
            KeyError: if ``name`` was never registered.
        """
        return self.version[name]
|
|
|
|
|
def register(name, version='x'):
    """Build a class decorator that adds the class to the model registry.

    Args:
        name: key under which the decorated class is registered.
        version: version tag stored alongside the class.

    Returns:
        A decorator that registers the class with the ``get_model()``
        registry and hands the class back unchanged.
    """
    def wrapper(model_cls):
        registry = get_model()
        registry.register(model_cls, name, version)
        return model_cls

    return wrapper
|
|