File size: 1,292 Bytes
34fb220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import yaml
import logging

import torch


def parse_configs(config: str):
    """ Parse a YAML config file and return its contents

    :param config: path to the config file
    :returns: the parsed configs (normally a dict); an empty dict if the
        file is empty or is not valid YAML (the parse error is logged)
    :raises SystemExit: with exit status 1 if ``config`` does not exist

    """
    if not os.path.exists(config):
        logging.error('Cannot find the config file: {}'.format(config))
        # Terminate with a failure status. The bare ``exit()`` builtin used
        # previously is only injected by the ``site`` module (absent under
        # ``python -S`` / frozen builds) and exited with status 0, which
        # reports success to shells and job schedulers.
        raise SystemExit(1)

    # YAML streams are Unicode (UTF-8/16/32); don't depend on the locale's
    # default text encoding.
    with open(config, 'r', encoding='utf-8') as stream:
        try:
            configs = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            logging.error(exc)
            return {}

    # safe_load returns None for an empty document; treat that the same as
    # an unparsable file so callers get an empty dict rather than None.
    if configs is None:
        return {}
    return configs


def load_model(config: str, weight: str, model_def, device):
    """ Load the model from the config file and the weight file

    Builds ``model_def(opt)`` from the parsed config, restores the state dict
    of every entry returned by ``model.get_models()`` from the checkpoint
    entry with the same key, moves each one to ``device`` and stores them
    back on the model via ``model.set_models``.

    :param config: path to the config file
    :param weight: path to the weight file; its loaded object is indexed by
        the keys of ``model.get_models()``
    :param model_def: model class definition; called with the parsed config
    :param device: pytorch device the sub-models are moved to
    :returns: the constructed model with its weights loaded
    :raises FileNotFoundError: if ``weight`` or ``config`` does not exist
    :raises KeyError: if the checkpoint has no entry for one of the sub-models

    """
    # Explicit checks instead of ``assert``: asserts are stripped when Python
    # runs with ``-O``, which would silently skip this validation.
    if not os.path.exists(weight):
        raise FileNotFoundError('Cannot find the weight file: {}'.format(weight))
    if not os.path.exists(config):
        raise FileNotFoundError('Cannot find the config file: {}'.format(config))

    opt = parse_configs(config)
    model = model_def(opt)

    # map_location='cpu' lets a checkpoint saved on a GPU be restored on a
    # CPU-only machine (or on a different GPU); every sub-model is moved to
    # ``device`` below anyway, so no extra device copy of the checkpoint is kept.
    # NOTE(review): torch.load unpickles the file; with torch < 2.6 (where
    # weights_only defaults to False) a malicious checkpoint can execute
    # arbitrary code. Only load trusted files, or pass weights_only=True where
    # the installed torch supports it and the checkpoint holds only tensors.
    cp = torch.load(weight, map_location='cpu')

    models = model.get_models()
    for k, m in models.items():
        try:
            state = cp[k]
        except KeyError:
            raise KeyError(
                'Checkpoint {} has no weights for sub-model {!r}'.format(weight, k)
            ) from None
        m.load_state_dict(state)
        m.to(device)

    model.set_models(models)
    return model