mojtaba-nafez's picture
add initial files to deploy
2fa2727
raw
history blame
7.75 kB
import config as CFG
import json
from models import PoemTextModel
import torch
import random
from datasets import PoemTextDataset, get_transforms, CLIPDataset
from tqdm import tqdm
import numpy as np
class AvgMeter:
    """
    Used to keep track of batch losses during training / validation.

    Attributes:
    -----------
    name : str
        label printed in front of the average by __repr__
    count : int
        number of data whose train/val loss has been metered
    sum: int or float
        sum of all losses metered (each value weighted by its count)
    avg: int or float
        average of metered losses (0 while nothing has been metered)

    Methods:
    --------
    reset():
        Sets count, sum and avg to 0.
    update(val, count=1):
        Updates loss sum, count and avg.
    __repr__():
        string representation of this class.
    """

    def __init__(self, name="Metric"):
        """Sets the name of the avg meter. Sets avg, sum & count to 0."""
        self.name = name
        self.reset()

    def reset(self):
        """Sets avg, sum & count to 0."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """
        Updates loss sum, count and avg.

        Parameters:
        -----------
        val: int or float
            the (mean) value to add to the meter
        count: int, optional
            number of data points that val stands for; val is weighted by it
        """
        self.count += count
        self.sum += val * count
        # An update with count=0 on a fresh meter leaves the total count at 0;
        # keep the average at 0 in that case instead of dividing by zero.
        self.avg = self.sum / self.count if self.count else 0

    def __repr__(self):
        """String representation of this class: "<name>: <avg with 4 decimals>"."""
        return f"{self.name}: {self.avg:.4f}"
def get_lr(optimizer):
    """
    Returns the learning rate of the input optimizer.

    Only the first entry of optimizer.param_groups is consulted; if there are
    no parameter groups at all, None is returned.
    """
    group_rates = (group["lr"] for group in optimizer.param_groups)
    return next(group_rates, None)
def get_datasets():
    """
    Loads the dataset json file at CFG.dataset_path and splits it into train, validation & test parts.

    The loaded list is shuffled with a random generator seeded by CFG.random_seed and then cut
    at the positions given by CFG.train_propotion and CFG.train_propotion + CFG.val_propotion
    (both as fractions of the dataset length); whatever remains is the test split.

    Returns:
    --------
    train_dataset: sequence of dict (as produced by np.split)
        Train split
    val_dataset: sequence of dict (as produced by np.split)
        Validation split
    test_dataset: sequence of dict (as produced by np.split)
        Test split
    """
    with open(CFG.dataset_path, encoding="utf-8") as dataset_file:
        samples = json.load(dataset_file)

    # a dedicated, seeded generator keeps the shuffle reproducible
    seeded_rng = random.Random(CFG.random_seed)
    seeded_rng.shuffle(samples)

    # https://stackoverflow.com/questions/38250710/how-to-split-data-into-3-sets-train-validation-and-test
    total = len(samples)
    train_end = int(CFG.train_propotion * total)
    val_end = int((CFG.train_propotion + CFG.val_propotion) * total)
    train_dataset, val_dataset, test_dataset = np.split(samples, [train_end, val_end])
    return train_dataset, val_dataset, test_dataset
def build_loaders(dataset_dict, mode):
    """
    Wraps a list of dictionaries (dataset_dict) in a torch Dataloader.

    dataset_dict is first turned into a PoemTextDataset; the Dataloader is then
    created on top of it with CFG.batch_size and CFG.num_workers.

    Parameters:
    -----------
    dataset_dict: list of dict
        the dataset to return a dataloader of.
    mode: str ("train" or any other word)
        shuffling is switched on only when mode is "train".

    Returns:
    --------
    dataloader: torch.utils.data.DataLoader
        the torch Dataloader created from dataset_dict using PoemTextDataset and configs.
    """
    poem_text_dataset = PoemTextDataset(dataset_dict)
    shuffle_batches = mode == "train"
    return torch.utils.data.DataLoader(
        poem_text_dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=shuffle_batches,
    )
def get_clip_datasets(dataset_dict):
    """
    (Used for clip model training) Splits the input into train, validation & test parts.

    dataset_dict is shuffled in place with a random generator seeded by CFG.random_seed
    (so the caller's list is reordered too) and then cut at the positions given by
    CFG.train_propotion and CFG.train_propotion + CFG.val_propotion (both as fractions
    of the dataset length); whatever remains is the test split.

    Parameters:
    -----------
    dataset_dict: list of dict
        the input dataset

    Returns:
    --------
    train_dataset: sequence of dict (as produced by np.split)
        Train split
    val_dataset: sequence of dict (as produced by np.split)
        Validation split
    test_dataset: sequence of dict (as produced by np.split)
        Test split
    """
    seeded_rng = random.Random(CFG.random_seed)
    seeded_rng.shuffle(dataset_dict)

    # https://stackoverflow.com/questions/38250710/how-to-split-data-into-3-sets-train-validation-and-test
    total = len(dataset_dict)
    cut_points = [
        int(CFG.train_propotion * total),
        int((CFG.train_propotion + CFG.val_propotion) * total),
    ]
    train_dataset, val_dataset, test_dataset = np.split(dataset_dict, cut_points)
    return train_dataset, val_dataset, test_dataset
def build_image_loaders(dataset_dict, mode):
    """
    (Used for clip model training) Wraps a list of dictionaries (dataset_dict) in a torch Dataloader.

    The transforms for the given mode are fetched with get_transforms, dataset_dict is turned
    into a CLIPDataset (with is_image_poem_pair=False) using them, and the Dataloader is then
    created on top of it with CFG.batch_size and CFG.num_workers.

    Parameters:
    -----------
    dataset_dict: list of dict
        the dataset to return a dataloader of.
    mode: str ("train" or any other word)
        passed on to get_transforms; shuffling is switched on only when mode is "train".

    Returns:
    --------
    dataloader: torch.utils.data.DataLoader
        the torch Dataloader created from dataset_dict using CLIPDataset and configs.
    """
    image_transforms = get_transforms(mode=mode)
    clip_dataset = CLIPDataset(dataset_dict, image_transforms, is_image_poem_pair=False)
    shuffle_batches = mode == "train"
    return torch.utils.data.DataLoader(
        clip_dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=shuffle_batches,
    )
def get_poem_embeddings(test_dataset, model=None):
    """
    Returns embeddings of the poems existing in test_dataset, together with the model that produced them.

    Parameters:
    -----------
    test_dataset: list of dict
        dataset to get poems from. each of its dictionaries must have a "beyt" key.
    model: PoemTextModel or CLIPModel, optional
        The model to get poem embeddings from. The class is recognised by its name:
        a "PoemTextModel" is used through its poem_encoder / poem_projection,
        a "CLIPModel" through its encoder / text_projection.
        If None is given, instantiates a new PoemTextModel (with all of its parts in pretrained settings)
        using configurations provided in config.py.

    Returns:
    --------
    model: the model used for creating poem embeddings (left in eval mode)
    poem_embeddings (torch.Tensor): the per-batch poem embeddings concatenated with torch.cat

    Raises:
    -------
    RuntimeError
        if the model is neither a PoemTextModel nor a CLIPModel.
    """
    test_loader = build_loaders(test_dataset, mode="test")  # building a dataloader (which also tokenizes the poems)

    if model is None:
        model = PoemTextModel(True, False, True, False, poem_projection_pretrained=True, text_projection_pretrained=True).to(CFG.device)

    # Pick the encoder / projection pair once, before touching any batch, so an
    # unsupported model fails immediately instead of on the first batch.
    model_kind = model.__class__.__name__
    if model_kind == "PoemTextModel":
        encoder, projection = model.poem_encoder, model.poem_projection
    elif model_kind == "CLIPModel":
        encoder, projection = model.encoder, model.text_projection
    else:
        # The previous bare `raise` (with no active exception) also surfaced as a
        # RuntimeError, so the exception type is kept and only a message is added.
        raise RuntimeError(
            f"get_poem_embeddings expects a PoemTextModel or CLIPModel, got {model_kind}"
        )

    model.eval()  # inference mode for the layers that behave differently in training

    poem_embeddings = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            # get poem embeddings by passing tokenizer output of the poems
            # to the model's poem encoder and projection
            beyts = {
                key: values.to(CFG.device)
                for key, values in batch["beyt"].items()
            }
            poem_features = encoder(input_ids=beyts["input_ids"], attention_mask=beyts["attention_mask"])
            poem_embeddings.append(projection(poem_features))
    return model, torch.cat(poem_embeddings)