Spaces:
Runtime error
Runtime error
from importlib import import_module | |
import torch | |
from loguru import logger | |
from df_local.config import DfParams, config | |
class ModelParams(DfParams):
    """Proxy for the parameter object of the model selected in the config.

    The model name is read from the ``MODEL`` key of the ``train`` config section
    (default ``"deepfilternet"``). The module ``df_local.<MODEL>`` is imported and
    its ``ModelParams`` class is instantiated. Every attribute lookup that misses
    on this proxy is forwarded to that instance.

    Note: ``DfParams.__init__`` is not called here. All parameter values come from
    the wrapped model-specific object (presumably deliberate — confirm).
    """

    def __init__(self) -> None:
        self.__model = config("MODEL", default="deepfilternet", section="train")
        self.__params = getattr(import_module("df_local." + self.__model), "ModelParams")()

    def __getattr__(self, attr: str):
        """Forward an attribute lookup that missed on this proxy to the wrapped params.

        Raises:
            AttributeError: if the wrapped params object has not been set on this
                instance, or if it does not define ``attr`` itself.
        """
        # Read the delegate straight from the instance dict rather than through
        # ``self.__params``. That attribute may be absent, e.g. on an instance
        # created by copy/pickle without running ``__init__``. Accessing it
        # normally would then re-enter ``__getattr__`` and recurse until
        # RecursionError, which ``hasattr``/``getattr`` probes do not catch.
        try:
            params = self.__dict__["_ModelParams__params"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None
        return getattr(params, attr)
def init_model(*args, **kwargs):
    """Initialize the model specified in the config.

    The model name comes from the ``MODEL`` key of the ``train`` config section
    (default ``"deepfilternet"``). It selects the module ``df_local.<name>``.
    That module's ``init_model`` factory is called with every positional and
    keyword argument passed through unchanged.

    Returns:
        The object built by the selected factory, after ``.to(memory_format=
        torch.channels_last)`` has been called on it. This presumably is a
        ``torch.nn.Module`` — confirm against the model modules.
    """
    model_name = config("MODEL", default="deepfilternet", section="train")
    logger.info(f"Initializing model `{model_name}`")

    model_module = import_module("df_local." + model_name)
    factory = model_module.init_model
    net = factory(*args, **kwargs)

    # The return value of ``.to`` is ignored. This relies on the object converting
    # itself in place, as ``torch.nn.Module.to`` does.
    net.to(memory_format=torch.channels_last)
    return net