"""Loader for the bandit ``MultiMaskMultiSourceBandSplitRNNSimple`` model."""

import sys
import os.path

import torch
import yaml
from ml_collections import ConfigDict

# Directory containing this file, with a trailing separator (kept in this
# form because it is a module-level name other code may concatenate onto).
code_path = os.path.dirname(os.path.abspath(__file__)) + '/'

# Make packages located next to this file (e.g. ``models``) importable.
# Guarded so repeated imports/reloads do not keep appending the same entry.
if code_path not in sys.path:
    sys.path.append(code_path)

torch.set_float32_matmul_precision("medium")

# Checkpoint bundled next to this file; used when no weights_path is given.
_DEFAULT_WEIGHTS_FILE = 'model_bandit_plus_dnr_sdr_11.47.chpt'


def get_model(
        config_path,
        weights_path,
        device,
):
    """Build the bandit separation model from a YAML config and load weights.

    Args:
        config_path: Path to a YAML file whose ``model`` section holds the
            keyword arguments for ``MultiMaskMultiSourceBandSplitRNNSimple``.
        weights_path: Path to a checkpoint containing the model's state dict.
            If falsy (``None`` or ``''``), the checkpoint bundled next to this
            file (``model_bandit_plus_dnr_sdr_11.47.chpt``) is used.
        device: Target device passed to ``model.to``.

    Returns:
        A ``(model, config)`` tuple: the model with weights loaded and moved to
        ``device``, and the parsed configuration as a ``ConfigDict``.

    Raises:
        FileNotFoundError: If the config file or the checkpoint does not exist.
    """
    # Imported lazily so the model package is only loaded when a model is
    # actually requested.
    from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple

    with open(config_path, encoding='utf-8') as f:
        # FullLoader (rather than SafeLoader) is kept for compatibility with
        # existing configs; config files are assumed to be trusted local files.
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

    model = MultiMaskMultiSourceBandSplitRNNSimple(
        **config.model
    )

    if not weights_path:
        weights_path = os.path.join(code_path, _DEFAULT_WEIGHTS_FILE)

    # Load onto CPU first so checkpoints saved from a GPU can be restored on
    # CPU-only machines; the model is moved to `device` right after.
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.to(device)
    return model, config