File size: 7,689 Bytes
1bc9b9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import os
import cv2
import torch
import albumentations as A
import config as CFG


class PoemTextDataset(torch.utils.data.Dataset):
    """
    torch Dataset pairing each beyt with its text, used as input to PoemTextModel.
    ...
    Attributes:
    -----------
    dataset_dict : list of dict
        the raw records; each one carries "beyt", "text" and "id" keys.
    encoded_poems : dict
        poem-tokenizer output for every "beyt" in dataset_dict: one entry per tokenizer
        output field, each holding one sequence per record (indexed by record position).
        Built with padding=True, truncation=True and max_length=CFG.poems_max_length.
    encoded_texts : dict
        same layout as encoded_poems, but produced by the text tokenizer from every "text",
        with max_length=CFG.text_max_length.

    NOTE(review): assuming Hugging Face tokenizers, padding=True pads every sequence to the
    longest one in dataset_dict (not to max_length); max_length only caps lengths via truncation.

    Methods:
    --------
        __getitem__(idx)
            returns the encoded poem, encoded text and id of record idx.
        __len__()
            returns the number of records in the dataset.
    """
    def __init__(self, dataset_dict):
        """
        Store dataset_dict and tokenize every poem and text up front.

        Tokenizer classes and checkpoints are looked up in the config module (CFG).

            Parameters:
            -----------
                dataset_dict: list of dict
                    records with "beyt", "text" and "id" keys.
        """
        self.dataset_dict = dataset_dict
        poem_tok = CFG.tokenizers[CFG.poem_encoder_model].from_pretrained(CFG.poem_tokenizer)
        text_tok = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)

        # The whole dataset is tokenized in one call, so padding is shared across all records.
        beyts = [record['beyt'] for record in dataset_dict]
        self.encoded_poems = poem_tok(
            beyts,
            padding=True,
            truncation=True,
            max_length=CFG.poems_max_length,
        )

        texts = [record['text'] for record in dataset_dict]
        self.encoded_texts = text_tok(
            texts,
            padding=True,
            truncation=True,
            max_length=CFG.text_max_length,
        )

    def __getitem__(self, idx):
        """
        Build the PoemTextModel input for record idx.

            Parameters:
            -----------
                idx: int
                    position of the record to fetch

            Returns:
            --------
                item: dict
                    {"beyt": {field: tensor}, "text": {field: tensor}, "id": <record id>},
                    where each tensor is row idx of the corresponding tokenizer field.
        """
        def as_tensors(encoding):
            # Slice row idx out of every tokenizer field and wrap it as a tensor.
            return {field: torch.tensor(seqs[idx]) for field, seqs in encoding.items()}

        return {
            "beyt": as_tensors(self.encoded_poems),
            "text": as_tensors(self.encoded_texts),
            "id": self.dataset_dict[idx]['id'],
        }

    def __len__(self):
        """
        Number of records in the dataset.

            Returns:
            --------
                length: int
                    len(dataset_dict) as stored at construction time
        """
        return len(self.dataset_dict)


class CLIPDataset(torch.utils.data.Dataset):
    """
    torch Dataset for CLIPModel.
    ...
    Attributes:
    -----------
    dataset_dict : list of dict
        dataset containing poem-image or text-image pairs. Every record needs an "image" key
        (file name appended to CFG.image_path) plus a "beyt" key (poem-image pairs) or a
        "text" key (text-image pairs).
    encoded : dict
        tokenizer output for the beyts/texts found in dataset_dict, one entry per tokenizer
        output field, each holding one sequence per record. Built with padding=True,
        truncation=True and max_length taken from the configs.
        NOTE(review): assuming Hugging Face tokenizers, padding=True pads to the longest
        sequence in dataset_dict; max_length only caps lengths via truncation.
    transforms : albumentations.BasicTransform
        transforms to apply to the images

    Methods:
    --------
        __getitem__(idx)
            returns item with index idx.
        __len__()
            returns the number of records in the dataset.
    """
    def __init__(self, dataset_dict, transforms, is_image_poem_pair=True):
        """
        Store dataset_dict and transforms, and tokenize every beyt (or text) up front.
        The tokenizer is chosen from the configs.

            Parameters:
            -----------
                dataset_dict: list of dict
                    records with an "image" key and a "beyt" key (poem-image pairs) or a
                    "text" key (text-image pairs).
                transforms: albumentations.BasicTransform
                    transforms to apply to the images
                is_image_poem_pair: bool, optional
                    True (default): records are poem-image pairs; "beyt" is encoded with the
                    poem tokenizer. False: records are text-image pairs; "text" is encoded with
                    the text tokenizer.
        """
        self.dataset_dict = dataset_dict
        # Pick the tokenizer, source field and length cap for this pair type, then encode once.
        if is_image_poem_pair:
            tokenizer = CFG.tokenizers[CFG.poem_encoder_model].from_pretrained(CFG.poem_tokenizer)
            sentences = [item['beyt'] for item in dataset_dict]
            max_length = CFG.poems_max_length
        else:
            tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)
            sentences = [item['text'] for item in dataset_dict]
            max_length = CFG.text_max_length
        self.encoded = tokenizer(
            sentences, padding=True, truncation=True, max_length=max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        """
        Build the CLIPModel input for record idx.

            Parameters:
            -----------
                idx: int
                    index of the data to get

            Returns:
            --------
                item: dict
                    {"text": {field: tensor}, "image": float tensor}; "text" holds row idx of
                    each tokenizer field (encoded beyt or text), "image" is the transformed
                    image with channels moved first.

            Raises:
            -------
                FileNotFoundError
                    if the image file cannot be read by cv2.imread.
        """
        item = {}
        # getting the encoded beyt/text for this record
        item["text"] = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded.items()
        }

        # opening the image
        image_path = f"{CFG.image_path}{self.dataset_dict[idx]['image']}"
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread does not raise on a missing/unreadable file; it returns None, which
            # would otherwise surface as an opaque cv2.error inside cvtColor below.
            raise FileNotFoundError(f"could not read image file: {image_path}")
        # OpenCV loads BGR; convert to RGB before the transforms
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # apply transforms
        image = self.transforms(image=image)['image']
        # (H, W, C) array -> (C, H, W) float tensor
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()

        return item

    def __len__(self):
        """
        returns the length of the dataset

            Returns:
            --------
                length: int
                    number of records in dataset_dict
        """
        return len(self.dataset_dict)



def get_transforms(mode="train"):
    """
        returns the albumentations pipeline to apply to images for the given mode

            Parameters:
            -----------
                mode: str, optional
                    "train" or any other value for val/test. Both currently produce the same
                    pipeline; the parameter is kept so train-only augmentations can be added
                    later without changing callers.

            Returns:
            --------
                transforms: albumentations.Compose
                    resizes the image to CFG.size x CFG.size, then normalizes it with
                    max_pixel_value=255.0 (albumentations' default mean/std).
    """
    # Train and val/test intentionally share one pipeline, so the two identical branches
    # collapse into a single definition. p=1.0 makes each transform always run; it replaces
    # the deprecated always_apply=True flag with the same effect.
    return A.Compose(
        [
            A.Resize(CFG.size, CFG.size, p=1.0),  # resizing image to CFG.size
            A.Normalize(max_pixel_value=255.0, p=1.0),  # normalizing image values
        ]
    )