import json
import random

import numpy as np
import torch
from tqdm import tqdm

import config as CFG
from datasets import PoemTextDataset, get_transforms, CLIPDataset
from models import PoemTextModel


class AvgMeter:
    """
    Used to keep track of batch losses during training / validation.
    ...

    Attributes:
    -----------
    name : str
    count : int
        number of data whose train/val loss has been metered
    sum: int or float
        sum of all losses metered
    avg: int or float
        average of metered losses

    Methods:
    --------
    reset():
        Sets count, sum and avg to 0.
    update(val, count=1):
        Updates loss sum, count and avg.
    __repr__():
        string representation of this class.
    """

    def __init__(self, name="Metric"):
        """Sets the name of the avg meter and sets avg, sum & count to 0."""
        self.name = name
        self.reset()

    def reset(self):
        """Sets avg, sum & count to 0."""
        self.avg, self.sum, self.count = 0, 0, 0

    def update(self, val, count=1):
        """
        Updates loss sum, count and avg.

        `val` is treated as the average over `count` items, so it is
        weighted by `count` when added to the running sum.
        """
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        """String representation of this class: '<name>: <avg with 4 decimals>'."""
        return f"{self.name}: {self.avg:.4f}"


def get_lr(optimizer):
    """
    Returns the learning rate of the first param group of the input optimizer.

    Returns None if the optimizer has no param groups.
    """
    for param_group in optimizer.param_groups:
        return param_group["lr"]
    return None


def _shuffle_and_split(records):
    """
    Shuffles `records` in place with CFG.random_seed and splits it into
    train / validation / test parts (shared by get_datasets and get_clip_datasets).

    The split points are int(CFG.train_propotion * n) and
    int((CFG.train_propotion + CFG.val_propotion) * n), where n = len(records);
    everything after the second point becomes the test split.

    Parameters:
    -----------
    records: list of dict
        the data to split. It is shuffled in place.

    Returns:
    --------
    train_split, val_split, test_split: numpy.ndarray
        the three pieces returned by np.split (for a list of dicts these are
        1-D object arrays whose elements are the original dicts).
    """
    # A dedicated Random instance makes the shuffle reproducible without
    # touching the global `random` state.
    random.Random(CFG.random_seed).shuffle(records)
    n = len(records)
    train_end = int(CFG.train_propotion * n)
    val_end = int((CFG.train_propotion + CFG.val_propotion) * n)
    # https://stackoverflow.com/questions/38250710/how-to-split-data-into-3-sets-train-validation-and-test
    train_split, val_split, test_split = np.split(records, [train_end, val_end])
    return train_split, val_split, test_split


def get_datasets():
    """
    Returns train, validation & test split from a dataset json file
    specified using CFG.dataset_path.

    This function first loads the file (UTF-8) with json.load, shuffles the
    result with the CFG.random_seed seed, then splits it using
    CFG.train_propotion & CFG.val_propotion.

    Returns:
    --------
    train_dataset: numpy.ndarray of dict
        Train split
    val_dataset: numpy.ndarray of dict
        Validation split
    test_dataset: numpy.ndarray of dict
        Test split
    """
    with open(CFG.dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # `dataset` is local to this function, so shuffling it in place is safe.
    return _shuffle_and_split(dataset)


def build_loaders(dataset_dict, mode):
    """
    Returns a torch Dataloader from a list of dictionaries (dataset_dict).

    First makes a PoemTextDataset (a torch Dataset) from dataset_dict and
    then instantiates a Dataloader using CFG.batch_size and CFG.num_workers.

    Parameters:
    -----------
    dataset_dict: list of dict
        the dataset to return a dataloader of.
    mode: str ("train" or any other word)
        if the mode is "train", the dataloader shuffles; otherwise it keeps
        the dataset order.

    Returns:
    --------
    dataloader: torch.utils.data.DataLoader
        the torch Dataloader created from dataset_dict using PoemTextDataset and configs.
    """
    dataset = PoemTextDataset(dataset_dict)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=(mode == "train"),
    )


def get_clip_datasets(dataset_dict):
    """
    (Used for clip model training)
    Returns train, validation & test split from input.

    This function shuffles a copy of dataset_dict with the CFG.random_seed
    seed, then splits it using CFG.train_propotion & CFG.val_propotion.
    The caller's dataset_dict is not modified.

    Parameters:
    -----------
    dataset_dict: list of dict
        the input dataset

    Returns:
    --------
    train_dataset: numpy.ndarray of dict
        Train split
    val_dataset: numpy.ndarray of dict
        Validation split
    test_dataset: numpy.ndarray of dict
        Test split
    """
    # Copy first so the caller's list is not reordered as a side effect;
    # the same seed over the same elements yields the same split as before.
    return _shuffle_and_split(list(dataset_dict))


def build_image_loaders(dataset_dict, mode):
    """
    (Used for clip model training)
    Returns a torch Dataloader from a list of dictionaries (dataset_dict).

    First builds the image transforms for `mode` with get_transforms, wraps
    dataset_dict in a CLIPDataset (with is_image_poem_pair=False) and then
    instantiates a Dataloader using CFG.batch_size and CFG.num_workers.

    Parameters:
    -----------
    dataset_dict: list of dict
        the dataset to return a dataloader of.
    mode: str ("train" or any other word)
        passed to get_transforms; additionally, if the mode is "train", the
        dataloader shuffles.

    Returns:
    --------
    dataloader: torch.utils.data.DataLoader
        the torch Dataloader created from dataset_dict using CLIPDataset and configs.
    """
    transforms = get_transforms(mode=mode)
    dataset = CLIPDataset(dataset_dict, transforms, is_image_poem_pair=False)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=(mode == "train"),
    )


def get_poem_embeddings(test_dataset, model=None):
    """
    Returns embeddings of the poems existing in test_dataset.

    Parameters:
    -----------
    test_dataset: list of dict
        dataset to get poems from. each of its dictionaries must have a "beyt" key.
    model: PoemTextModel or CLIPModel, optional
        The model to get poem embeddings from. If None is given, instantiates
        a new PoemTextModel (with pretrained settings) using configurations
        provided in config.py and moves it to CFG.device.
        The model is switched to eval mode.

    Returns:
    --------
    model: PoemTextModel or CLIPModel
        The model used for creating poem embeddings.
    poem_embeddings: torch.Tensor
        The per-batch projection outputs concatenated along dim 0
        (presumably shape (num_poems, projection_dim) -- confirm against the
        projection modules).

    Raises:
    -------
    TypeError
        if model is neither a PoemTextModel nor a CLIPModel.
    """
    # mode="test" disables shuffling, so the embedding rows follow the order of
    # the loader (assumed to match test_dataset order -- confirm PoemTextDataset
    # does not reorder). Each batch carries the tokenized poems under "beyt".
    test_loader = build_loaders(test_dataset, mode="test")
    if model is None:
        model = PoemTextModel(
            True, False, True, False,
            poem_projection_pretrained=True,
            text_projection_pretrained=True,
        ).to(CFG.device)
    model.eval()

    # Choose the poem encoder / projection pair once, based on the model class
    # (compared by name because CLIPModel is not imported in this module).
    model_type = type(model).__name__
    if model_type == "PoemTextModel":
        encoder, projection = model.poem_encoder, model.poem_projection
    elif model_type == "CLIPModel":
        encoder, projection = model.encoder, model.text_projection
    else:
        raise TypeError(
            f"get_poem_embeddings expects a PoemTextModel or CLIPModel, got {model_type}"
        )

    poem_embeddings = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            # get poem embeddings by passing tokenizer output of the poems
            # to the model's poem encoder and projection
            beyts = {
                key: values.to(CFG.device)
                for key, values in batch["beyt"].items()
            }
            poem_features = encoder(
                input_ids=beyts["input_ids"],
                attention_mask=beyts["attention_mask"],
            )
            poem_embeddings.append(projection(poem_features))
    # NOTE: torch.cat raises if the loader produced no batches (empty test_dataset).
    return model, torch.cat(poem_embeddings)