|
|
import torch
|
|
|
import matplotlib
|
|
|
import matplotlib.pyplot as plt
|
|
|
import os
|
|
|
# Switch matplotlib to its built-in 'ggplot' style sheet. This changes the
# global rcParams, so every figure created afterwards in this process uses it.
plt.style.use('ggplot')
|
|
|
class SaveBestModel:
    """
    Class to save the best model while training. If the current epoch's
    validation loss is less than the previous least loss, then save the
    model state.

    Each improvement overwrites ``<out_dir>/best_<name>.pth`` with a
    dictionary holding ``'epoch'`` (stored as ``epoch + 1``) and
    ``'model_state_dict'`` (the result of ``model.state_dict()``).
    """

    def __init__(
        self, best_valid_loss=float('inf')
    ):
        """
        Args:
            best_valid_loss: Loss value a new checkpoint has to beat. The
                default of ``inf`` means the first comparable loss is
                always saved.
        """
        self.best_valid_loss = best_valid_loss

    def __call__(
        self, current_valid_loss, epoch, model, out_dir, name
    ):
        """
        Save ``model`` if ``current_valid_loss`` beats the best loss so far.

        Args:
            current_valid_loss: This epoch's validation loss, compared with
                ``<`` against the stored best (e.g. a float).
            epoch: Epoch index. ``epoch + 1`` is printed and stored, which
                presumes a 0-based counter; confirm against the caller.
            model: Object exposing ``state_dict()`` (e.g. ``torch.nn.Module``).
            out_dir: Directory for the checkpoint. It is not created here,
                so it must already exist.
            name: Base name; the file written is ``best_<name>.pth``.
        """
        # Strict comparison: a tie with the current best keeps the older
        # checkpoint, and a NaN loss never compares smaller so it is skipped.
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")
            torch.save({
                'epoch': epoch+1,
                'model_state_dict': model.state_dict(),
            }, os.path.join(out_dir, f'best_{name}.pth'))
|
|
|
def save_model(epochs, model, optimizer, criterion, out_dir, name):
    """
    Write a training checkpoint to ``<out_dir>/<name>.pth`` via ``torch.save``.

    The saved dictionary has, in order, the keys ``'epoch'`` (``epochs``),
    ``'model_state_dict'``, ``'optimizer_state_dict'`` and ``'loss'``.

    Args:
        epochs: Value stored under ``'epoch'``.
        model: Object exposing ``state_dict()`` (e.g. ``torch.nn.Module``).
        optimizer: Object exposing ``state_dict()`` (e.g. a torch optimizer).
        criterion: Stored verbatim under ``'loss'``. NOTE(review): the name
            suggests the loss function object rather than a loss value;
            confirm what loaders expect.
        out_dir: Existing directory to write into.
        name: Base file name; ``.pth`` is appended.
    """
    checkpoint = dict(
        epoch=epochs,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=criterion,
    )
    target_path = os.path.join(out_dir, name + '.pth')
    torch.save(checkpoint, target_path)
|
|
|
def _save_metric_plot(train_values, valid_values, ylabel, train_label,
                      valid_label, out_path):
    """Plot a train/validation curve pair against epochs and save it to ``out_path``."""
    fig, ax = plt.subplots(figsize=(10, 7))
    try:
        ax.plot(
            train_values, color='tab:blue', linestyle='-',
            label=train_label
        )
        ax.plot(
            valid_values, color='tab:red', linestyle='-',
            label=valid_label
        )
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.legend()
        fig.savefig(out_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, repeated calls leak figures (and memory).
        plt.close(fig)


def save_plots(train_acc, valid_acc, train_loss, valid_loss, out_dir):
    """
    Function to save the loss and accuracy plots to disk.

    Writes ``accuracy.png`` (train vs. validation accuracy) and ``loss.png``
    (train vs. validation loss) into ``out_dir``. Each sequence is plotted
    against its index, labelled "Epochs" (presumably one value per epoch).
    The figures are closed after saving, so they are not left open in pyplot.

    Args:
        train_acc: Training accuracy values.
        valid_acc: Validation accuracy values.
        train_loss: Training loss values.
        valid_loss: Validation loss values.
        out_dir: Existing directory to write the PNG files into.
    """
    _save_metric_plot(
        train_acc, valid_acc, 'Accuracy',
        'train accuracy', 'validation accuracy',
        os.path.join(out_dir, 'accuracy.png'),
    )
    _save_metric_plot(
        train_loss, valid_loss, 'Loss',
        'train loss', 'validation loss',
        os.path.join(out_dir, 'loss.png'),
    )