import config as CFG
import json
from models import PoemTextModel
import torch
import random
from datasets import PoemTextDataset, get_transforms, CLIPDataset
from tqdm import tqdm
import numpy as np
class AvgMeter:
"""
Used to keep track of batch losses during training / validation.
...
Attributes:
-----------
    name : str
        name of the meter, used in its string representation
    count : int
        number of samples whose train/val loss has been metered
    sum : int or float
        sum of all losses metered
    avg : int or float
        average of metered losses
Methods:
--------
reset():
Sets count, sum and avg to 0.
update(val, count=1):
Updates loss sum, count and avg.
__repr__():
string representation of this class.
"""
def __init__(self, name="Metric"):
"""Sets the name of the avg meter. sets avg, sum & count to 0."""
self.name = name
self.reset()
def reset(self):
"""Sets avg, sum & count to 0."""
        self.avg, self.sum, self.count = 0, 0, 0
def update(self, val, count=1):
"""Updates loss sum, count and avg using val and count (count of the val input)"""
self.count += count
self.sum += val * count
self.avg = self.sum / self.count
def __repr__(self):
"""String representation of this class"""
text = f"{self.name}: {self.avg:.4f}"
return text
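# A minimal usage sketch for AvgMeter (the loss values below are made up for illustration):
#   loss_meter = AvgMeter(name="train_loss")
#   loss_meter.update(0.5, count=32)  # batch of 32 samples with mean loss 0.5
#   loss_meter.update(0.3, count=32)  # another batch of 32 samples with mean loss 0.3
#   print(loss_meter)                 # -> "train_loss: 0.4000"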
def get_lr(optimizer):
"""Returns learning rate of the input optimizer"""
for param_group in optimizer.param_groups:
return param_group["lr"]
def get_datasets():
"""
    Returns train, validation & test splits from a dataset json file specified by CFG.dataset_path.
    This function first loads the file into a list of dicts and shuffles it with seed CFG.random_seed,
    then splits it using CFG.train_propotion & CFG.val_propotion.
Returns:
--------
train_dataset: list of dict
Train split
val_dataset: list of dict
Validation split
test_dataset: list of dict
Test split
"""
with open(CFG.dataset_path, encoding="utf-8") as f:
dataset = json.load(f)
random.Random(CFG.random_seed).shuffle(dataset)
# https://stackoverflow.com/questions/38250710/how-to-split-data-into-3-sets-train-validation-and-test
train_dataset, val_dataset, test_dataset = np.split(dataset,
[int(CFG.train_propotion*len(dataset)), int((CFG.train_propotion + CFG.val_propotion)*len(dataset))])
return train_dataset, val_dataset, test_dataset
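# Illustrative note: assuming, for example, CFG.train_propotion = 0.8 and CFG.val_propotion = 0.1,
# a 1000-entry dataset is split by np.split above into 800 train, 100 validation and 100 test samples:
#   train_ds, val_ds, test_ds = get_datasets()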
def build_loaders(dataset_dict, mode):
"""
Returns a torch Dataloader from a list of dictionaries (dataset_dict).
First makes a PoemTextDataset which is a torch Dataset object from dataset_dict and then instantiates a Dataloader.
Parameters:
-----------
dataset_dict: list of dict
the dataset to return a dataloader of.
mode: str ("train" or any other word)
if the mode is "train", dataloader will activate shuffling.
Returns:
--------
dataloader: torch.utils.data.DataLoader
the torch Dataloader created from dataset_dict using PoemTextDataset and configs.
"""
dataset = PoemTextDataset(
dataset_dict
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=CFG.batch_size,
num_workers=CFG.num_workers,
        shuffle=(mode == "train"),
)
return dataloader
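# A hedged usage sketch wiring get_datasets and build_loaders together
# (the batch contents depend on PoemTextDataset; only the loader setup is shown):
#   train_ds, val_ds, _ = get_datasets()
#   train_loader = build_loaders(train_ds, mode="train")  # shuffled
#   val_loader = build_loaders(val_ds, mode="valid")      # any mode other than "train" disables shuffling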
def get_clip_datasets(dataset_dict):
"""
    (Used for CLIP model training) Returns train, validation & test splits from the input dataset.
    This function shuffles the input list of dicts in place with seed CFG.random_seed,
    then splits it using CFG.train_propotion & CFG.val_propotion.
Parameters:
-----------
dataset_dict: list of dict
the input dataset
Returns:
--------
train_dataset: list of dict
Train split
val_dataset: list of dict
Validation split
test_dataset: list of dict
Test split
"""
random.Random(CFG.random_seed).shuffle(dataset_dict)
# https://stackoverflow.com/questions/38250710/how-to-split-data-into-3-sets-train-validation-and-test
train_dataset, val_dataset, test_dataset = np.split(dataset_dict,
[int(CFG.train_propotion*len(dataset_dict)), int((CFG.train_propotion + CFG.val_propotion)*len(dataset_dict))])
return train_dataset, val_dataset, test_dataset
def build_image_loaders(dataset_dict, mode):
"""
    (Used for CLIP model training) Returns a torch Dataloader from a list of dictionaries (dataset_dict).
    First makes a CLIPDataset (a torch Dataset object) from dataset_dict with transforms from get_transforms, and then instantiates a Dataloader.
Parameters:
-----------
dataset_dict: list of dict
the dataset to return a dataloader of.
mode: str ("train" or any other word)
if the mode is "train", dataloader will activate shuffling.
Returns:
--------
dataloader: torch.utils.data.DataLoader
the torch Dataloader created from dataset_dict using CLIPDataset and configs.
"""
transforms = get_transforms(mode=mode)
dataset = CLIPDataset(
dataset_dict, transforms, is_image_poem_pair=False
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=CFG.batch_size,
num_workers=CFG.num_workers,
        shuffle=(mode == "train"),
)
return dataloader
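# A hedged usage sketch for the image pipeline, assuming dataset_dict has already been
# loaded and its dicts carry the image fields CLIPDataset expects:
#   train_imgs, val_imgs, _ = get_clip_datasets(dataset_dict)
#   train_img_loader = build_image_loaders(train_imgs, mode="train")
#   val_img_loader = build_image_loaders(val_imgs, mode="valid")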
def get_poem_embeddings(test_dataset, model=None):
"""
Returns embeddings of the poems existing in test_dataset.
Parameters:
-----------
test_dataset: list of dict
dataset to get poems from. each of its dictionaries must have a "beyt" key.
    model: PoemTextModel or CLIPModel, optional
        The model to compute poem embeddings with.
        If None is given, instantiates a new PoemTextModel (with all of its parts in pretrained settings) using configurations provided in config.py.
    Returns:
    --------
    model: PoemTextModel or CLIPModel
        The model used for creating poem embeddings
    poem_embeddings: torch.Tensor
        embeddings of the poems in test_dataset, concatenated into a single tensor
"""
    test_loader = build_loaders(test_dataset, mode="test") # building a dataloader (which also tokenizes the poems)
    if model is None:
model = PoemTextModel(True, False, True, False, poem_projection_pretrained=True, text_projection_pretrained=True).to(CFG.device)
model.eval()
poem_embeddings = []
with torch.no_grad():
for batch in tqdm(test_loader):
# get poem embeddings by passing tokenizer output of the poems
# to the model's poem encoder and projection
beyts = {
key: values.to(CFG.device)
for key, values in batch["beyt"].items()
}
if model.__class__.__name__ == "PoemTextModel":
poem_features = model.poem_encoder(input_ids=beyts["input_ids"], attention_mask=beyts["attention_mask"])
poem_emb = model.poem_projection(poem_features)
poem_embeddings.append(poem_emb)
elif model.__class__.__name__ == "CLIPModel":
poem_features = model.encoder(input_ids=beyts["input_ids"], attention_mask=beyts["attention_mask"])
poem_emb = model.text_projection(poem_features)
poem_embeddings.append(poem_emb)
else:
                raise ValueError("model must be a PoemTextModel or CLIPModel to compute poem embeddings")
return model, torch.cat(poem_embeddings)
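# A minimal end-to-end sketch, assuming config.py provides a valid CFG.dataset_path,
# CFG.device and pretrained weights; guarded so importing this module stays side-effect free.
if __name__ == "__main__":
    _, _, test_split = get_datasets()
    model, poem_embeddings = get_poem_embeddings(test_split)
    print(poem_embeddings.shape)  # (number of poems, projection embedding dimension)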