import torch
from torch.utils.data import Dataset
from transformers import ViTImageProcessor
from io import BytesIO
from base64 import b64decode
from PIL import Image, ImageFile


# ViT checkpoint whose image processor defines the preprocessing
# (resize to 224x224, normalize) expected by the encoder.
model_id = 'google/vit-base-patch16-224-in21k'


class BilingualDataset(Dataset):
    """Pairs base64-encoded images with target-language text for an
    image-to-text encoder-decoder model."""

    def __init__(self, ds, tokenizer_tgt, seq_len):
        super().__init__()
        self.seq_len = seq_len
        # Tolerate truncated image files instead of raising while decoding.
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        self.ds = ds
        self.tokenizer_tgt = tokenizer_tgt
        self.processor = ViTImageProcessor.from_pretrained(model_id)
        self.sos_token = torch.tensor([tokenizer_tgt.convert_tokens_to_ids("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.convert_tokens_to_ids("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.convert_tokens_to_ids("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        data_pair = self.ds[idx]

        # The image is stored as a base64 string (possibly split into chunks).
        src_image = data_pair['image_base64_str']
        tgt_text = data_pair['outputs']

        src_image = Image.open(BytesIO(b64decode(''.join(src_image))))
        if src_image.mode != 'RGB':
            src_image = src_image.convert('RGB')
        # Resize and normalize for ViT; pixel_values has shape (1, 3, 224, 224).
        src_image = self.processor(src_image, return_tensors='pt')

        # Tokenize the target text, then truncate so that one special token
        # ([SOS] on the decoder input, [EOS] on the label) still fits.
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text)
        dec_input_tokens = dec_input_tokens[:self.seq_len - 1]
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # The truncation above should make this unreachable; kept as a guard.
        if dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        # Decoder input: [SOS] + tokens + padding (no [EOS]).
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.pad_token.repeat(dec_num_padding_tokens),
            ],
            dim=0,
        )

        # Label: tokens + [EOS] + padding (no [SOS]), i.e. the decoder input
        # shifted left by one position.
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                self.pad_token.repeat(dec_num_padding_tokens),
            ],
            dim=0,
        )

        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            # Preprocessed image for the ViT encoder: (3, 224, 224).
            'encoder_input': src_image['pixel_values'].squeeze(0),
            'decoder_input': decoder_input,  # (seq_len,)
            # Encoder mask (currently unused by the model): all 197 ViT tokens
            # (196 patches + [CLS]) are valid, padded with zeros to width 260.
            'encoder_mask': torch.cat((torch.ones(197), torch.zeros(63))).unsqueeze(0).unsqueeze(0),  # (1, 1, 260)
            # Hide both padding and future positions from the decoder.
            'decoder_mask': (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),  # (1, seq_len) & (1, seq_len, seq_len)
            'label': label,  # (seq_len,)
            'tgt_text': tgt_text,
        }


def causal_mask(size):
    # Boolean (1, size, size) mask: True on and below the diagonal, so
    # position i may attend only to positions 0..i.
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0
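

# --- Minimal usage sketch (not part of the original pipeline) ---
# Everything below is illustrative. The BERT tokenizer is a stand-in for
# whatever target-language tokenizer is actually used (any tokenizer exposing
# encode() and convert_tokens_to_ids() with [SOS]/[EOS]/[PAD] ids works), and
# the one-item in-memory dataset mimics the expected 'image_base64_str' /
# 'outputs' columns with a synthetic image.
if __name__ == "__main__":
    import base64
    from torch.utils.data import DataLoader
    from transformers import BertTokenizer  # stand-in tokenizer (assumption)

    tokenizer_tgt = BertTokenizer.from_pretrained("bert-base-uncased")
    # BERT's vocab has [PAD] but not [SOS]/[EOS]; register them so that
    # convert_tokens_to_ids() returns real ids rather than the [UNK] id.
    tokenizer_tgt.add_tokens(["[SOS]", "[EOS]"])

    # Build a tiny synthetic example: a solid-color PNG, base64-encoded.
    buf = BytesIO()
    Image.new("RGB", (32, 32), color="red").save(buf, format="PNG")
    img_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
    raw_ds = [{"image_base64_str": [img_b64], "outputs": "a red square"}]

    ds = BilingualDataset(raw_ds, tokenizer_tgt, seq_len=64)
    loader = DataLoader(ds, batch_size=1)
    batch = next(iter(loader))
    print(batch["encoder_input"].shape)  # torch.Size([1, 3, 224, 224])
    print(batch["decoder_input"].shape)  # torch.Size([1, 64])
    print(batch["decoder_mask"].shape)   # torch.Size([1, 1, 64, 64])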