OcTra / df_local /model.py
arcan3's picture
adding rust
35916c5
raw
history blame
795 Bytes
from importlib import import_module
import torch
from loguru import logger
from df_local.config import DfParams, config
class ModelParams(DfParams):
    """Proxy for the ``ModelParams`` class of the model selected in the config.

    The ``MODEL`` key of the ``train`` config section (``"deepfilternet"`` when
    unset) names a ``df_local.<name>`` submodule. That submodule's
    ``ModelParams`` is instantiated, and any attribute not found on this object
    through normal lookup is read from that instance instead.
    """

    def __init__(self):
        # NOTE(review): DfParams.__init__ is not invoked here; confirm this is
        # intentional (the delegate below is a separate params object).
        self.__model = config("MODEL", default="deepfilternet", section="train")
        self.__params = getattr(import_module("df_local." + self.__model), "ModelParams")()

    def __getattr__(self, attr: str):
        """Forward lookups of otherwise-missing attributes to the model's params.

        Raises:
            AttributeError: if the delegate does not define ``attr``, or if the
                delegate has not been created on this instance yet.
        """
        # Read the name-mangled attribute straight from the instance dict.
        # Accessing ``self.__params`` here would re-enter ``__getattr__`` when it
        # is absent (e.g. instances built via ``__new__`` by copy/pickle), and
        # recurse until RecursionError instead of raising AttributeError.
        try:
            params = vars(self)["_ModelParams__params"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None
        return getattr(params, attr)
def init_model(*args, **kwargs):
    """Build the network named by the ``MODEL`` key of the ``train`` config section.

    The configured name (``"deepfilternet"`` when unset) selects the
    ``df_local.<name>`` submodule. Its own ``init_model`` is called with every
    positional and keyword argument forwarded unchanged. The resulting object
    gets ``.to(memory_format=torch.channels_last)`` applied and is then returned.
    """
    model_name = config("MODEL", default="deepfilternet", section="train")
    logger.info(f"Initializing model `{model_name}`")

    submodule = import_module("df_local." + model_name)
    factory = getattr(submodule, "init_model")
    net = factory(*args, **kwargs)

    # The value returned by ``.to`` is discarded; this relies on the conversion
    # happening in place (true for ``nn.Module`` — presumably what ``net`` is).
    net.to(memory_format=torch.channels_last)
    return net