|
import torch |
|
from model import GLiNER |
|
|
|
|
|
def save_model(current_model, path):
    """Write a checkpoint holding a model's weights and its config.

    The checkpoint is a plain dict with two entries:

    * ``"model_weights"`` -- the result of ``current_model.state_dict()``
    * ``"config"``        -- the ``current_model.config`` object itself,
      serialized by ``torch.save`` along with the weights

    Args:
        current_model: model exposing ``state_dict()`` and a ``config``
            attribute.
        path: destination handed directly to ``torch.save``.
    """
    model_config = current_model.config
    checkpoint = dict(
        model_weights=current_model.state_dict(),
        config=model_config,
    )
    torch.save(checkpoint, path)
|
|
|
|
|
def load_model(path, model_name=None, device=None):
    """Rebuild a ``GLiNER`` model from a checkpoint dict.

    The checkpoint is expected to contain a ``"config"`` entry (used to
    construct the model) and a ``"model_weights"`` entry (loaded with
    ``load_state_dict``).

    Args:
        path: checkpoint location passed to ``torch.load``.
        model_name: if not None, overwrites ``config.model_name`` before the
            model is constructed.
        device: if not None, the model is moved with ``.to(device)`` before
            being returned; otherwise ``.to`` is not called.

    Returns:
        The ``GLiNER`` instance with the saved weights loaded.
    """
    import inspect

    # Tensors are materialized on the CPU first, whatever device they were
    # saved from.
    load_kwargs = {"map_location": torch.device("cpu")}
    # The checkpoint holds a pickled config object, not only tensors.
    # PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects
    # such objects, so opt out explicitly. Older releases that do not know
    # the keyword reject it, so only pass it when the signature accepts it.
    # SECURITY: full unpickling can execute arbitrary code -- only load
    # checkpoints from trusted sources.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    dict_load = torch.load(path, **load_kwargs)

    config = dict_load["config"]
    if model_name is not None:
        config.model_name = model_name

    loaded_model = GLiNER(config)
    loaded_model.load_state_dict(dict_load["model_weights"])
    return loaded_model if device is None else loaded_model.to(device)
|
|