import os
import cv2
import torch
import albumentations as A

import config as CFG


class PoemTextDataset(torch.utils.data.Dataset):
    """
    torch Dataset for PoemTextModel.
    ...
    Attributes:
    -----------
    dataset_dict : list of dict
        dataset containing poem-text pairs with ids
    encoded_poems : dict
        output of the poem tokenizer for the beyts found in dataset_dict. max_length is specified in configs;
        padding and truncation are set to True so inputs are padded or truncated to max_length.
    encoded_texts : dict
        output of the text tokenizer for the texts found in dataset_dict. max_length is specified in configs;
        padding and truncation are set to True so inputs are padded or truncated to max_length.

    Methods:
    --------
    __getitem__(idx)
        returns the item with index idx.
    __len__()
        returns the length of the dataset.
    """
    def __init__(self, dataset_dict):
        """
        Init class, save dataset_dict and compute the tokenizer outputs for each text and poem
        using their corresponding tokenizers. The tokenizers are chosen based on configs.

        Parameters:
        -----------
        dataset_dict: list of dict
            a list containing dictionaries which have "beyt", "text" and "id" keys.
        """
        self.dataset_dict = dataset_dict
        poem_tokenizer = CFG.tokenizers[CFG.poem_encoder_model].from_pretrained(CFG.poem_tokenizer)
        text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)
        self.encoded_poems = poem_tokenizer(
            [item['beyt'] for item in dataset_dict],
            padding=True,
            truncation=True,
            max_length=CFG.poems_max_length
        )
        self.encoded_texts = text_tokenizer(
            [item['text'] for item in dataset_dict],
            padding=True,
            truncation=True,
            max_length=CFG.text_max_length
        )

    def __getitem__(self, idx):
        """
        Returns a dict holding the data at index idx. The dict is used as an input to the PoemTextModel.

        Parameters:
        -----------
        idx: int
            index of the data to get

        Returns:
        --------
        item: dict
            a dict with the tokenizers' output for the poem and text, and the id of the data at index idx
        """
        item = {}
        item["beyt"] = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_poems.items()
        }
        item["text"] = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_texts.items()
        }
        item['id'] = self.dataset_dict[idx]['id']
        return item

    def __len__(self):
        """
        Returns the length of the dataset.

        Returns:
        --------
        length: int
            length of the dataset_dict saved in the class
        """
        return len(self.dataset_dict)


class CLIPDataset(torch.utils.data.Dataset):
    """
    torch Dataset for CLIPModel.
    ...
    Attributes:
    -----------
    dataset_dict : list of dict
        dataset containing poem-image or text-image pairs with ids
    encoded : dict
        output of the tokenizer for the beyts/texts found in dataset_dict. max_length is specified in configs;
        padding and truncation are set to True so inputs are padded or truncated to max_length.
    transforms: albumentations.BasicTransform
        transforms to apply to the images

    Methods:
    --------
    __getitem__(idx)
        returns the item with index idx.
    __len__()
        returns the length of the dataset.
    """
    def __init__(self, dataset_dict, transforms, is_image_poem_pair=True):
        """
        Init class, save dataset_dict and transforms, and compute the tokenizer outputs for each text or poem
        using their corresponding tokenizers. The tokenizers are chosen based on configs.

        Parameters:
        -----------
        dataset_dict: list of dict
            a list containing dictionaries which have "beyt", "text" and "id" keys.
        transforms: albumentations.BasicTransform
            transforms to apply to the images
        is_image_poem_pair: bool, optional
            if set to False, the dataset has text-image pairs and must use the corresponding text tokenizer;
            otherwise it has poem-image pairs and uses the poem tokenizer.
""" self.dataset_dict = dataset_dict # using the poem tokenizer to encode poems or text tokenizer to encode text (based on configs). if is_image_poem_pair: poem_tokenizer = CFG.tokenizers[CFG.poem_encoder_model].from_pretrained(CFG.poem_tokenizer) self.encoded = poem_tokenizer( [item['beyt'] for item in dataset_dict], padding=True, truncation=True, max_length=CFG.poems_max_length ) else: text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer) self.encoded = text_tokenizer( [item['text'] for item in dataset_dict], padding=True, truncation=True, max_length=CFG.text_max_length ) self.transforms = transforms def __getitem__(self, idx): """ returns a dict having data with index idx. the dict is used as an input to the CLIPModel. Parameters: ----------- idx: int index of the data to get Returns: -------- item: dict a dict having tokenizers' output for poem and text, and id of the data with index idx """ item = {} # getting text from encoded texts item["text"] = { key: torch.tensor(values[idx]) for key, values in self.encoded.items() } # opening the image image = cv2.imread(f"{CFG.image_path}{self.dataset_dict[idx]['image']}") # converting BGR to RGB for transforms image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # apply transforms image = self.transforms(image=image)['image'] # permute dims of image item['image'] = torch.tensor(image).permute(2, 0, 1).float() return item def __len__(self): """ returns the length of the dataset Returns: -------- length: int length using the length of dataset_dict we saved in class """ return len(self.dataset_dict) def get_transforms(mode="train"): """ returns transforms to use on image based on mode Parameters: ----------- mode: str, optional to distinguish between train and val/test transforms (here they are the same!) Returns: -------- item: dict a dict having tokenizers' output for poem and text, and id of the data with index idx """ if mode == "train": return A.Compose( [ A.Resize(CFG.size, CFG.size, always_apply=True), # resizing image to CFG.size A.Normalize(max_pixel_value=255.0, always_apply=True), # normalizing image values ] ) else: return A.Compose( [ A.Resize(CFG.size, CFG.size, always_apply=True), # resizing image to CFG.size A.Normalize(max_pixel_value=255.0, always_apply=True), # normalizing image values ] )