Spaces:
Runtime error
Runtime error
| from importlib import import_module | |
| import torch | |
| from loguru import logger | |
| from df_local.config import DfParams, config | |
class ModelParams(DfParams):
    """Proxy for the ``ModelParams`` class of the model selected in the config.

    The model name is read from the ``MODEL`` key of the ``train`` config section
    (default ``"deepfilternet"``). The module ``df_local.<model>`` is imported and its
    ``ModelParams`` class is instantiated. Any attribute not found on this object is
    looked up on that model-specific instance instead.
    """

    def __init__(self):
        # NOTE(review): DfParams.__init__ is not called here; presumably the wrapped
        # model-specific ModelParams performs that initialisation itself — confirm.
        self.__model = config("MODEL", default="deepfilternet", section="train")
        self.__params = getattr(import_module("df_local." + self.__model), "ModelParams")()

    def __getattr__(self, attr: str):
        """Forward attribute lookups that failed on ``self`` to the wrapped params.

        Args:
            attr: Name of the attribute being looked up.

        Returns:
            The value of ``attr`` on the wrapped model-specific params object.

        Raises:
            AttributeError: If the wrapped params object has not been set on this
                instance (e.g. an instance created without running ``__init__``, as
                ``copy``/``pickle`` do), or if it does not define ``attr``.
        """
        # Read the wrapped object straight from the instance dict. Going through
        # ``self.__params`` would re-enter ``__getattr__`` whenever that attribute is
        # missing, recursing until RecursionError instead of raising AttributeError.
        try:
            params = self.__dict__["_ModelParams__params"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None
        return getattr(params, attr)
def init_model(*args, **kwargs):
    """Build the model named by the ``MODEL`` key of the ``train`` config section.

    The configured name (default ``"deepfilternet"``) selects the module
    ``df_local.<name>``. That module's own ``init_model`` is called with every positional
    and keyword argument forwarded unchanged. The result then gets
    ``.to(memory_format=torch.channels_last)`` called on it.

    Returns:
        The object produced by ``df_local.<name>.init_model``.
    """
    model_name = config("MODEL", default="deepfilternet", section="train")
    logger.info(f"Initializing model `{model_name}`")

    model_module = import_module("df_local." + model_name)
    factory = getattr(model_module, "init_model")
    net = factory(*args, **kwargs)

    # The return value of ``.to`` is discarded, so this relies on the call converting
    # ``net`` in place. NOTE(review): that holds for torch.nn.Module; confirm every
    # model factory returns a Module.
    net.to(memory_format=torch.channels_last)
    return net