import torch
from .helper_functions import define_optimizer, predict, display_train, eval_test
from tqdm import tqdm
import matplotlib.pyplot as plt


def save_model(model, optimizer, valid_loss, epoch, path='model.pt'):
    """Checkpoint the model and optimizer state together with the best
    validation loss seen so far.

    The saved dict keys ('valid_loss', 'model_state_dict', 'epoch',
    'optimizer') are read back by `trainer` and must not change.
    `epoch + 1` is stored so the checkpoint records the NEXT epoch to run.
    """
    torch.save({'valid_loss': valid_loss,
                'model_state_dict': model.state_dict(),
                'epoch': epoch + 1,
                'optimizer': optimizer.state_dict()
                }, path)
    # tqdm.write keeps the message from corrupting an active progress bar.
    tqdm.write(f'Model saved to ==> {path}')


def save_metrics(train_loss_list, valid_loss_list, global_steps_list, path='metrics.pt'):
    """Persist the loss histories so they can be re-plotted later
    (see `plot_losses`)."""
    torch.save({'train_loss_list': train_loss_list,
                'valid_loss_list': valid_loss_list,
                'global_steps_list': global_steps_list,
                }, path)


def plot_losses(metrics_save_name='metrics', save_dir='./'):
    """Load a metrics checkpoint written by `save_metrics` and plot
    train/valid loss against global step.

    NOTE(review): `save_dir` is string-concatenated, so it must end with a
    path separator (the default './' does).
    """
    path = f'{save_dir}metrics_{metrics_save_name}.pt'
    state = torch.load(path)
    train_loss_list = state['train_loss_list']
    valid_loss_list = state['valid_loss_list']
    global_steps_list = state['global_steps_list']

    plt.plot(global_steps_list, train_loss_list, label='Train')
    plt.plot(global_steps_list, valid_loss_list, label='Valid')
    plt.xlabel('Global Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()


def _validate_and_checkpoint(model, optimizer, epoch, num_epochs, i,
                             correct, total, loss, running_loss, global_step,
                             train_loader, valid_loader, device,
                             train_loss_list, valid_loss_list, global_steps_list,
                             best_valid_loss, model_save_name, save_dir):
    """Run one validation pass, record the loss curves, and checkpoint the
    model if validation loss improved.

    Mutates the three history lists in place and returns
    (train_accuracy, valid_accuracy, valid_loss, best_valid_loss).
    Extracted because `trainer` previously duplicated this block verbatim
    (mid-epoch and end-of-epoch).
    """
    train_accuracy, valid_accuracy, valid_loss = display_train(
        epoch, num_epochs, i, model,
        correct, total, loss,
        train_loader, valid_loader, device)

    # Average train loss per sample over the epoch so far.
    train_loss_list.append(running_loss / total)
    valid_loss_list.append(valid_loss)
    global_steps_list.append(global_step)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        save_model(model, optimizer, best_valid_loss, epoch,
                   path=f'{save_dir}model_{model_save_name}.pt')
        save_metrics(train_loss_list, valid_loss_list, global_steps_list,
                     path=f'{save_dir}metrics_{model_save_name}.pt')

    return train_accuracy, valid_accuracy, valid_loss, best_valid_loss


def trainer(model, train_loader, test_loader, valid_loader,
            num_epochs=10, lr=0.01, alpha=0.99, eval_interval=10,
            model_save_name='', save_dir='./'):
    """Train `model` on `train_loader`, validating every `eval_interval`
    batches, checkpointing on best validation loss, and finally evaluating
    the best checkpoint on `test_loader`.

    Parameters
    ----------
    model : torch.nn.Module
    train_loader, test_loader, valid_loader : DataLoaders yielding
        (inputs, labels, notes) triples. Labels appear to be multi-hot
        float targets (BCE-with-logits loss; accuracy counts a hit when the
        arg-max logit's label entry equals 1) — confirm against the dataset.
    num_epochs, lr, alpha : training hyperparameters (`alpha` is forwarded
        to `define_optimizer`).
    eval_interval : validate every this many batches.
    model_save_name, save_dir : checkpoint naming; `save_dir` must end with
        a path separator.

    Returns
    -------
    (train_accs, valid_accs, test_acc): per-epoch accuracy histories and
    the final test accuracy of the best checkpoint.
    """
    # Use GPU if available, else CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    train_accs = []
    valid_accs = []

    global_step = 0
    train_loss_list = []
    valid_loss_list = []
    global_steps_list = []
    best_valid_loss = float("inf")

    optimizer = define_optimizer(model, lr, alpha)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0
        correct = 0
        total = 0

        for i, (inputs, labels, notes) in enumerate(train_loader):
            # Dataloader yields (batch, seq, feat)? transposed to
            # (batch, feat, seq) for the model — TODO confirm layout.
            inputs = inputs.transpose(1, 2).float().to(device)
            labels = labels.float().to(device)
            notes = notes.to(device)

            # Forward pass (`predicted` from the helper is unused here;
            # accuracy is derived from the raw logits below).
            outputs, predicted = predict(model, inputs, notes, device)

            # A prediction counts as correct when the arg-max logit points
            # at a positive entry of the multi-hot label vector.
            total += labels.size(0)
            indices = outputs.argmax(dim=1)
            batch_rows = torch.arange(labels.size(0), device=labels.device)
            correct += (labels[batch_rows, indices] == 1).sum().item()

            # BCE-with-logits: raw (pre-sigmoid) outputs go to the loss.
            loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, labels)
            running_loss += loss.item() * len(labels)
            global_step += len(inputs)

            # Backward and optimize.
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Periodic validation + checkpoint.
            if (i + 1) % eval_interval == 0:
                (train_accuracy, valid_accuracy, valid_loss,
                 best_valid_loss) = _validate_and_checkpoint(
                    model, optimizer, epoch, num_epochs, i,
                    correct, total, loss, running_loss, global_step,
                    train_loader, valid_loader, device,
                    train_loss_list, valid_loss_list, global_steps_list,
                    best_valid_loss, model_save_name, save_dir)
                # NOTE(review): if display_train's validation pass switches
                # the model to eval mode, the remainder of the epoch would
                # train with dropout/batchnorm frozen; restoring train mode
                # is idempotent and safe either way.
                model.train()

        # End-of-epoch validation, unless the last batch already triggered
        # one (i.e. the batch count divides eval_interval exactly).
        if len(train_loader) % eval_interval != 0:
            (train_accuracy, valid_accuracy, valid_loss,
             best_valid_loss) = _validate_and_checkpoint(
                model, optimizer, epoch, num_epochs, i,
                correct, total, loss, running_loss, global_step,
                train_loader, valid_loader, device,
                train_loss_list, valid_loss_list, global_steps_list,
                best_valid_loss, model_save_name, save_dir)

        # Record the most recent accuracies for this epoch.
        train_accs.append(train_accuracy)
        valid_accs.append(valid_accuracy)

    save_metrics(train_loss_list, valid_loss_list, global_steps_list,
                 path=f'{save_dir}metrics_{model_save_name}.pt')

    # Reload the best checkpoint before the final test evaluation.
    checkpoint = torch.load(f'{save_dir}model_{model_save_name}.pt')
    model.load_state_dict(checkpoint['model_state_dict'])

    test_acc = eval_test(model, test_loader, device)

    return train_accs, valid_accs, test_acc