File size: 1,447 Bytes
154ca7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch

# Save and Load Functions

def save_checkpoint(save_path, model, valid_loss):
    """Save the model's state dict and its validation loss to ``save_path``.

    The checkpoint is a dict with two keys, ``'model_state_dict'`` (the
    result of ``model.state_dict()``) and ``'valid_loss'``, serialized with
    ``torch.save``.

    Args:
        save_path: Destination handed to ``torch.save``. When ``None``,
            nothing is written and the function returns immediately.
        model: Object exposing ``state_dict()``.
        valid_loss: Value stored under the ``'valid_loss'`` key.

    Returns:
        None.
    """
    # Identity check: ``== None`` would defer to a custom ``__eq__`` on the
    # destination object and could skip the save by accident.
    if save_path is None:
        return
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'valid_loss': valid_loss,
    }
    torch.save(checkpoint, save_path)
    print('[SAVE] Model has been saved successfully to \'{}\''.format(save_path))

def load_checkpoint(load_path, model, device):
    """Load a checkpoint into ``model`` and return its stored validation loss.

    The checkpoint is read with ``torch.load`` and must be a mapping with the
    keys ``'model_state_dict'`` and ``'valid_loss'``.

    Args:
        load_path: Source handed to ``torch.load``. When ``None``, nothing is
            loaded and ``None`` is returned.
        model: Object exposing ``load_state_dict()``; it is updated in place.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The value stored under ``'valid_loss'``, or ``None`` when
        ``load_path`` is ``None``.
    """
    if load_path is None:
        return None
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(load_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    valid_loss = checkpoint['valid_loss']
    # Report success only after the weights have actually been applied, so a
    # failing load never leaves a misleading "success" line in the output.
    print('[LOAD] Model has been loaded successfully from \'{}\''.format(load_path))
    return valid_loss

def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
    """Save the training metrics to ``save_path``.

    The three arguments are stored in a dict under the keys
    ``'train_loss_list'``, ``'valid_loss_list'`` and ``'global_steps_list'``
    and serialized with ``torch.save``.

    Args:
        save_path: Destination handed to ``torch.save``. When ``None``,
            nothing is written and the function returns immediately.
        train_loss_list: Value stored under ``'train_loss_list'``.
        valid_loss_list: Value stored under ``'valid_loss_list'``.
        global_steps_list: Value stored under ``'global_steps_list'``.

    Returns:
        None.
    """
    if save_path is None:
        return
    metrics = {
        'train_loss_list': train_loss_list,
        'valid_loss_list': valid_loss_list,
        'global_steps_list': global_steps_list,
    }
    torch.save(metrics, save_path)
    # Only metrics are written here (no model), so the message says so.
    print('[SAVE] Metrics have been saved successfully to \'{}\''.format(save_path))

def load_metrics(load_path, device):
    """Load the training metrics stored at ``load_path``.

    The file is read with ``torch.load`` and must be a mapping with the keys
    ``'train_loss_list'``, ``'valid_loss_list'`` and ``'global_steps_list'``.

    Args:
        load_path: Source handed to ``torch.load``. When ``None``, nothing is
            loaded and ``None`` is returned.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        A ``(train_loss_list, valid_loss_list, global_steps_list)`` tuple, or
        a single ``None`` when ``load_path`` is ``None``; callers that unpack
        the result must handle that case themselves.
    """
    if load_path is None:
        return None
    # NOTE: torch.load unpickles the file; only load files you trust.
    metrics = torch.load(load_path, map_location=device)
    # Only metrics are read here (no model), so the message says so.
    print('[LOAD] Metrics have been loaded successfully from \'{}\''.format(load_path))
    return (
        metrics['train_loss_list'],
        metrics['valid_loss_list'],
        metrics['global_steps_list'],
    )