import torch
from torch.utils.data import Dataset
from transformers import ViTImageProcessor
from io import BytesIO
from base64 import b64decode
from PIL import Image, ImageFile

# ViT encoder checkpoint whose processor defines the image preprocessing
model_id = 'google/vit-base-patch16-224-in21k'


def causal_mask(size):
    # Upper-triangular ones above the diagonal mark the "future" positions;
    # the comparison flips it so that allowed positions are True.
    # Shape: (1, size, size)
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0


class BilingualDataset(Dataset):
    def __init__(self, ds, tokenizer_tgt, seq_len):
        super().__init__()
        self.seq_len = seq_len
        # Let PIL decode truncated/partially corrupt images instead of raising
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        self.ds = ds
        self.tokenizer_tgt = tokenizer_tgt
        self.processor = ViTImageProcessor.from_pretrained(model_id)

        self.sos_token = torch.tensor([tokenizer_tgt.convert_tokens_to_ids("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.convert_tokens_to_ids("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.convert_tokens_to_ids("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        data_pair = self.ds[idx]
        src_image = data_pair['image_base64_str']
        tgt_text = data_pair['outputs']

        # The image is stored as (a list of) base64 string chunks: join, decode, load
        src_image = Image.open(BytesIO(b64decode(''.join(src_image))))
        if src_image.mode != 'RGB':
            src_image = src_image.convert('RGB')
        # Resize and normalize to the ViT input format; pixel_values: (1, 3, 224, 224)
        src_image = self.processor(src_image, return_tensors='pt')

        # Tokenize the target text
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text)

        # The image feeds the encoder, so special tokens are only needed on the
        # decoder side: [SOS] on the input, [EOS] on the label. Truncate so one
        # special token plus padding always fits in seq_len.
        dec_input_tokens = dec_input_tokens[:self.seq_len - 1]
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # After truncation this can never be negative, but keep the guard
        if dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        # Decoder input: [SOS] + tokens + padding
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.full((dec_num_padding_tokens,), self.pad_token.item(), dtype=torch.int64),
            ],
            dim=0,
        )

        # Label: tokens + [EOS] + padding (the decoder input shifted left by one)
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.full((dec_num_padding_tokens,), self.pad_token.item(), dtype=torch.int64),
            ],
            dim=0,
        )

        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            'encoder_input': src_image['pixel_values'].squeeze(0),  # (3, 224, 224)
            'decoder_input': decoder_input,  # (seq_len)
            # Encoder mask is not used by the model: ViT always emits 197 tokens
            # (196 patches + [CLS]); the trailing zeros pad it to a fixed width
            "encoder_mask": torch.cat((torch.ones(197), torch.zeros(63))).unsqueeze(0).unsqueeze(0),  # (1, 1, 260)
            # Hide both padding positions and future positions from the decoder
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int()
                            & causal_mask(decoder_input.size(0)),  # (1, seq_len) & (1, seq_len, seq_len)
            "label": label,  # (seq_len)
            "tgt_text": tgt_text,
        }
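

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module. It assumes a
# BERT-style tokenizer with [SOS]/[EOS] added as extra special tokens and a
# toy in-memory dataset carrying the expected 'image_base64_str' and
# 'outputs' keys; those choices are illustrative, any tokenizer exposing
# .encode() and .convert_tokens_to_ids() would do.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from base64 import b64encode

    from torch.utils.data import DataLoader
    from transformers import BertTokenizer

    # Hypothetical tokenizer setup: BERT already has [PAD]; register the two
    # extra special tokens so convert_tokens_to_ids() resolves them.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    tokenizer.add_tokens(["[SOS]", "[EOS]"], special_tokens=True)

    # One synthetic example: a 32x32 red PNG, base64-encoded like the real data
    buf = BytesIO()
    Image.new("RGB", (32, 32), color="red").save(buf, format="PNG")
    toy_ds = [{
        "image_base64_str": [b64encode(buf.getvalue()).decode("ascii")],
        "outputs": "a red square",
    }]

    ds = BilingualDataset(toy_ds, tokenizer, seq_len=64)
    batch = next(iter(DataLoader(ds, batch_size=1)))
    print(batch["encoder_input"].shape)  # torch.Size([1, 3, 224, 224])
    print(batch["decoder_input"].shape)  # torch.Size([1, 64])
    print(batch["decoder_mask"].shape)   # torch.Size([1, 1, 64, 64])
    print(causal_mask(4)[0])             # lower-triangular True pattern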