File size: 734 Bytes
51e2f90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import sys
import os.path
import torch

# Absolute directory containing this file, always ending in '/', so file
# paths next to it can be built by plain string concatenation.
code_path = os.path.dirname(os.path.abspath(__file__)) + '/'
# Make modules located next to this file importable by absolute name.
# Appended (not prepended), so earlier sys.path entries win on name clashes.
sys.path.append(code_path)

import yaml
from ml_collections import ConfigDict

# Process-wide PyTorch setting (not scoped to this module): per the PyTorch
# docs, "medium" lets float32 matmuls use reduced-precision internal math
# (e.g. bfloat16) when a faster kernel is available, trading accuracy for speed.
torch.set_float32_matmul_precision("medium")


def get_model(

    config_path,

    weights_path,

    device,

):
    """Build a ``MultiMaskMultiSourceBandSplitRNNSimple`` and load its weights.

    Args:
        config_path: Path to a YAML configuration file. Its ``model`` section
            is passed as keyword arguments to the model constructor.
        weights_path: Path to a checkpoint holding the model's state dict.
            If falsy (``None`` or ``""``), the checkpoint bundled next to this
            file, ``model_bandit_plus_dnr_sdr_11.47.chpt``, is used instead
            (this was previously the only checkpoint ever loaded).
        device: Target device; passed to ``model.to``.

    Returns:
        A ``(model, config)`` tuple: the model with weights loaded and moved to
        ``device``, and the parsed configuration as an ``ml_collections``
        ``ConfigDict``.
    """
    # Imported inside the function; presumably the ``models`` package is
    # found via ``code_path``, which this module appends to ``sys.path``.
    from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple

    # ``with`` guarantees the file is closed even if YAML parsing raises.
    # NOTE(review): FullLoader can construct some Python-specific objects;
    # only load config files from trusted sources.
    with open(config_path) as f:
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

    model = MultiMaskMultiSourceBandSplitRNNSimple(
        **config.model
    )

    # Honor the caller's checkpoint; fall back to the bundled one otherwise.
    if not weights_path:
        weights_path = code_path + 'model_bandit_plus_dnr_sdr_11.47.chpt'

    # Load tensors onto the CPU first so that checkpoints saved from a GPU can
    # be restored on machines without one; the model is moved to ``device``
    # right after.
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.to(device)
    return model, config