gliner_multi / GLiNER /save_load.py
Tom Aarsen
Add cloned GLiNER repository
914502f
raw
history blame
619 Bytes
import torch
from model import GLiNER
def save_model(current_model, path):
    """Serialize a model's weights together with its config to ``path``.

    The checkpoint is a single dictionary written with ``torch.save``:

    * ``"model_weights"`` -- the result of ``current_model.state_dict()``
    * ``"config"`` -- the object stored in ``current_model.config``

    Args:
        current_model: Object exposing a ``config`` attribute and a
            ``state_dict()`` method.
        path: Destination handed straight to ``torch.save``.
    """
    # Read the config first so a model lacking one fails before any
    # state dict is materialized.
    cfg = current_model.config
    checkpoint = dict(model_weights=current_model.state_dict(), config=cfg)
    torch.save(checkpoint, path)
def load_model(path, model_name=None, device=None):
    """Rebuild a ``GLiNER`` model from a checkpoint holding weights and config.

    The checkpoint is expected to be a dictionary with the keys
    ``"config"`` and ``"model_weights"``.

    NOTE: ``torch.load`` unpickles the file, so only load checkpoints
    from sources you trust.

    Args:
        path: Checkpoint location handed straight to ``torch.load``.
        model_name: When not ``None``, overwrites ``config.model_name``
            on the loaded config before the model is constructed.
        device: When not ``None``, the model is moved there with
            ``.to(device)`` before being returned.

    Returns:
        The ``GLiNER`` instance with the stored weights loaded.
    """
    # Tensors are always mapped to CPU on load; any move to another
    # device happens only at the very end.
    checkpoint = torch.load(path, map_location=torch.device("cpu"))

    cfg = checkpoint["config"]
    if model_name is not None:
        cfg.model_name = model_name

    model = GLiNER(cfg)
    model.load_state_dict(checkpoint["model_weights"])

    if device is None:
        return model
    return model.to(device)