Spaces:
Runtime error
Runtime error
File size: 1,447 Bytes
154ca7b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import torch
# Save and Load Functions
def save_checkpoint(save_path, model, valid_loss):
    """Save a model's ``state_dict`` together with its validation loss.

    The checkpoint is written with ``torch.save`` as a dict with the keys
    ``'model_state_dict'`` and ``'valid_loss'``.

    Args:
        save_path: Destination passed to ``torch.save``. If ``None``, nothing
            is written and the function returns immediately.
        model: Object exposing ``state_dict()`` (presumably a
            ``torch.nn.Module``).
        valid_loss: Validation loss stored alongside the weights.
    """
    # Identity check: `== None` could be hijacked by objects overriding __eq__.
    if save_path is None:
        return
    state_dict = {'model_state_dict': model.state_dict(),
                  'valid_loss': valid_loss}
    torch.save(state_dict, save_path)
    print(f"[SAVE] Model has been saved successfully to '{save_path}'")
def load_checkpoint(load_path, model, device):
    """Load a checkpoint dict (``'model_state_dict'`` + ``'valid_loss'``) into *model*.

    Args:
        load_path: Checkpoint location passed to ``torch.load``. If ``None``,
            nothing is loaded and ``None`` is returned.
        model: Object exposing ``load_state_dict()`` (presumably a
            ``torch.nn.Module``); it is updated in place.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The ``'valid_loss'`` value stored in the checkpoint, or ``None`` when
        *load_path* is ``None``.

    Raises:
        KeyError: If the checkpoint lacks ``'model_state_dict'`` or
            ``'valid_loss'``; *model* is left untouched in that case.
        Any exception raised by ``model.load_state_dict`` (e.g. on mismatched
        keys or shapes) propagates, and no success message is printed.
    """
    if load_path is None:
        return None
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(load_path, map_location=device)
    # Read both entries before mutating the model so an incomplete checkpoint
    # fails without partially applying anything.
    weights = state_dict['model_state_dict']
    valid_loss = state_dict['valid_loss']
    model.load_state_dict(weights)
    # Announce success only after the weights were actually applied.
    print(f"[LOAD] Model has been loaded successfully from '{load_path}'")
    return valid_loss
def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
    """Save training-curve metrics (no model weights) with ``torch.save``.

    The file holds a dict with the keys ``'train_loss_list'``,
    ``'valid_loss_list'`` and ``'global_steps_list'``.

    Args:
        save_path: Destination passed to ``torch.save``. If ``None``, nothing
            is written and the function returns immediately.
        train_loss_list: Training losses to store (stored as given).
        valid_loss_list: Validation losses to store (stored as given).
        global_steps_list: Step indices to store (stored as given); presumably
            aligned with the loss lists -- not checked here.
    """
    if save_path is None:
        return
    state_dict = {'train_loss_list': train_loss_list,
                  'valid_loss_list': valid_loss_list,
                  'global_steps_list': global_steps_list}
    torch.save(state_dict, save_path)
    # Only metrics are written here, so the message must not claim a model was saved.
    print(f"[SAVE] Metrics have been saved successfully to '{save_path}'")
def load_metrics(load_path, device):
    """Load metric lists written as a dict by ``torch.save``.

    Args:
        load_path: File location passed to ``torch.load``. If ``None``,
            nothing is loaded and ``None`` is returned (callers unpacking the
            result must handle that case).
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        A tuple ``(train_loss_list, valid_loss_list, global_steps_list)``, or
        ``None`` when *load_path* is ``None``.

    Raises:
        KeyError: If any of the three expected keys is missing; no success
            message is printed in that case.
    """
    if load_path is None:
        return None
    # NOTE(review): torch.load unpickles the file -- only load trusted files.
    state_dict = torch.load(load_path, map_location=device)
    # Extract everything first so success is reported only for a complete file.
    metrics = (state_dict['train_loss_list'],
               state_dict['valid_loss_list'],
               state_dict['global_steps_list'])
    print(f"[LOAD] Metrics have been loaded successfully from '{load_path}'")
    return metrics