Spaces:
Runtime error
Runtime error
import torch | |
from model import GLiNER | |
def save_model(current_model, path):
    """Write a model checkpoint (weights plus config) to ``path``.

    The file is produced by ``torch.save`` and holds a dict with two keys:
    ``"model_weights"`` -> ``current_model.state_dict()`` and
    ``"config"`` -> ``current_model.config``.

    Args:
        current_model: module exposing ``state_dict()`` and a ``config``
            attribute.
        path: destination accepted by ``torch.save`` (path or file-like).
    """
    model_config = current_model.config
    weights = current_model.state_dict()
    torch.save({"model_weights": weights, "config": model_config}, path)
def load_model(path, model_name=None, device=None):
    """Rebuild a ``GLiNER`` model from a checkpoint written by ``save_model``.

    Args:
        path: checkpoint location accepted by ``torch.load``.
        model_name: if given, overrides ``config.model_name`` before the
            model is constructed.
        device: if given, the model is moved there with ``.to(device)``;
            otherwise it stays on CPU (tensors are mapped to CPU on load).

    Returns:
        The ``GLiNER`` instance with the saved weights loaded.

    SECURITY: the checkpoint is fully unpickled, which can execute arbitrary
    code. Only load checkpoints from trusted sources.
    """
    cpu = torch.device("cpu")
    # The checkpoint stores ``config`` as whatever Python object the model
    # carried at save time. torch >= 2.6 defaults to ``weights_only=True`` and
    # refuses to unpickle such objects, so request full unpickling explicitly.
    try:
        dict_load = torch.load(path, map_location=cpu, weights_only=False)
    except TypeError:
        # torch < 1.13 has no ``weights_only`` argument (it always fully
        # unpickles); a genuine TypeError from loading re-raises on retry.
        dict_load = torch.load(path, map_location=cpu)

    config = dict_load["config"]
    if model_name is not None:
        # Assumes config supports attribute assignment (e.g. a Namespace-like
        # object) — type not visible here.
        config.model_name = model_name

    loaded_model = GLiNER(config)
    loaded_model.load_state_dict(dict_load["model_weights"])
    return loaded_model.to(device) if device is not None else loaded_model