File size: 619 Bytes
914502f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from model import GLiNER


def save_model(current_model, path):
    """Write a model's weights and config to a single checkpoint.

    The checkpoint is a dict with two entries: ``"model_weights"`` holds
    ``current_model.state_dict()`` and ``"config"`` holds
    ``current_model.config``. It is serialized with ``torch.save``.

    Args:
        current_model: Object exposing a ``config`` attribute and a
            ``state_dict()`` method.
        path: Destination passed unchanged to ``torch.save``.
    """
    model_config = current_model.config
    torch.save(
        {"model_weights": current_model.state_dict(), "config": model_config},
        path,
    )


def load_model(path, model_name=None, device=None):
    """Rebuild a ``GLiNER`` model from a checkpoint.

    The checkpoint is read with ``torch.load`` (tensors mapped to CPU) and
    must contain a ``"config"`` entry and a ``"model_weights"`` entry.

    Args:
        path: Checkpoint location passed unchanged to ``torch.load``.
        model_name: When not ``None``, assigned to ``config.model_name``
            before the model is constructed.
        device: When not ``None``, the model is moved there via ``.to``.

    Returns:
        The constructed model with the stored weights loaded.
    """
    # NOTE(review): torch.load relies on unpickling, so only trusted
    # checkpoints should be opened here. Newer PyTorch releases may also
    # change the ``weights_only`` default, which could affect loading the
    # pickled config object -- confirm against the pinned torch version.
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    model_config = checkpoint["config"]

    # Optional override of the model name stored in the checkpoint config.
    if model_name is not None:
        model_config.model_name = model_name

    model = GLiNER(model_config)
    model.load_state_dict(checkpoint["model_weights"])

    if device is None:
        return model
    return model.to(device)