import os
import gc
import numpy as np
import pandas as pd
from tqdm import tqdm
import random
import json
import torch
from torch import nn

import config as CFG
from models import CLIPModel
from utils import AvgMeter, get_lr
from utils import get_datasets, build_loaders


def _batch_to_device(batch_cpu):
    """Move a CPU batch onto CFG.device.

    Nested dicts (e.g. tokenizer outputs under "text"/"poem") are moved
    tensor-by-tensor; the optional flat "image" tensor is moved directly.
    The "id" entry, if present, is dropped (it is metadata, not model input).

    Parameters
    ----------
    batch_cpu : dict
        Batch as produced by the dataloader; values are dicts of tensors,
        except the optional "image" tensor and "id" entry.

    Returns
    -------
    dict
        Same structure with all tensors on CFG.device.
    """
    batch = {
        k: {dict_k: dict_v.to(CFG.device) for dict_k, dict_v in v.items()}
        for k, v in batch_cpu.items()
        if k not in ("id", "image")
    }
    if "image" in batch_cpu:
        batch["image"] = batch_cpu["image"].to(CFG.device)
    return batch


def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """Perform one epoch of training.

    Parameters
    ----------
    model : PoemTextModel or CLIPModel
        Model to train.
    train_loader : torch.utils.data.DataLoader
        Dataloader to get batches from.
    optimizer : torch.optim.Optimizer
        Optimizer used for training.
    lr_scheduler : torch.optim.lr_scheduler.LRScheduler
        Scheduler used for training.
    step : str
        "batch" or "epoch". If "batch", lr_scheduler steps after each batch;
        otherwise the caller is expected to step it once per epoch.

    Returns
    -------
    AvgMeter
        Tracker holding the average training loss of this epoch.
    """
    loss_meter = AvgMeter()  # tracks running average of the loss
    tqdm_object = tqdm(train_loader, total=len(train_loader))
    for batch_cpu in tqdm_object:
        batch = _batch_to_device(batch_cpu)

        # Forward pass: model returns a pair of embeddings to contrast.
        poem_or_img_embeddings, text_embeddings = model(batch)
        loss = model.calculate_loss(poem_or_img_embeddings, text_embeddings)

        # Backpropagate and step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()

        # Weight the running average by batch size.
        count = batch["text"]["input_ids"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
    return loss_meter


def valid_epoch(model, valid_loader):
    """Perform one epoch of validation.

    Parameters
    ----------
    model : PoemTextModel or CLIPModel
        Model to validate.
    valid_loader : torch.utils.data.DataLoader
        Dataloader to get batches from.

    Returns
    -------
    AvgMeter
        Tracker holding the average validation loss of this epoch.
    """
    loss_meter = AvgMeter()  # tracks running average of the loss
    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    for batch_cpu in tqdm_object:
        batch = _batch_to_device(batch_cpu)

        poem_or_img_embeddings, text_embeddings = model(batch)
        loss = model.calculate_loss(poem_or_img_embeddings, text_embeddings)

        count = batch["text"]["input_ids"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(valid_loss=loss_meter.avg)
    return loss_meter


def test(model, test_dataset):
    """Calculate accuracy on the test set.

    Used for the PoemTextModel, since CLIPModel does not have a test set
    containing pairs of image-poem.

    Parameters
    ----------
    model : PoemTextModel
        Model to test.
    test_dataset : list of dict
        Data to perform the test on (each dict must have "text" and
        "poem" keys).

    Returns
    -------
    float
        The accuracy of the model on the given test set.
    """
    test_loader = build_loaders(test_dataset, mode="test")
    accuracy = 0
    tqdm_object = tqdm(test_loader, total=len(test_loader))
    model.eval()
    with torch.no_grad():
        for batch_cpu in tqdm_object:
            batch = _batch_to_device(batch_cpu)

            # predict returns, for each text, the index of the poem the
            # model matches it with (as a numpy array after .cpu().numpy()).
            pred = model.predict(batch).cpu().numpy()
            count = batch["text"]["input_ids"].size(0)
            # Each text is paired with the poem at the same index,
            # so np.arange(count) is the ground-truth labeling.
            acc = np.sum(pred == np.arange(count))
            accuracy += acc

            tqdm_object.set_postfix(accuracy=acc / count)
    accuracy /= len(test_dataset)
    return accuracy


def train(model, train_loader, valid_loader, epochs=CFG.epochs):
    """Train and validate the model for `epochs` epochs.

    Keeps (saves) the model whenever its average validation loss improves.

    Parameters
    ----------
    model : PoemTextModel or CLIPModel
        Model to train.
    train_loader : torch.utils.data.DataLoader
        Train dataloader to get batches from.
    valid_loader : torch.utils.data.DataLoader
        Validation dataloader to get batches from.
    epochs : int, optional
        Number of epochs to train (defaults to CFG.epochs).

    Returns
    -------
    model : PoemTextModel or CLIPModel
        Trained model.
    loss_history : dict
        Train and validation average loss for each epoch, under the
        "train" and "valid" keys.
    """
    # AdamW optimizer and ReduceLROnPlateau lr-scheduler, configured from CFG.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )
    # If step == "batch", lr_scheduler would step for each batch of the
    # loader; here we step only once per epoch (on validation loss).
    step = "epoch"

    loss_history = {"train": [], "valid": []}
    best_loss = float('inf')  # best validation loss seen so far

    # BUGFIX: loop over `epochs` (the parameter), not CFG.epochs — the
    # original ignored a caller-supplied epoch count.
    for epoch in range(epochs):
        print(f"Epoch: {epoch + 1}")

        # Train for one epoch.
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
        loss_history["train"].append(train_loss.avg)

        # Validate the trained model.
        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(model, valid_loader)
            loss_history["valid"].append(valid_loss.avg)

        # Keep the checkpoint with the lowest average validation loss.
        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            model.save_current()
            print("Saved Best Model!")

        if step == "epoch":
            lr_scheduler.step(valid_loss.avg)

    return model, loss_history