# image-caption / dataset.py

import torch
from io import BytesIO
from base64 import b64decode

from PIL import Image, ImageFile
from torch.utils.data import Dataset
from transformers import ViTImageProcessor

# ViT checkpoint used for image preprocessing. ViTImageProcessor replaces the
# deprecated ViTFeatureExtractor, which was loaded here but never used.
model_id = 'google/vit-base-patch16-224-in21k'
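
# For reference, the processor maps a PIL image to a normalized pixel tensor:
#   processor = ViTImageProcessor.from_pretrained(model_id)
#   processor(img, return_tensors='pt')['pixel_values'].shape  # torch.Size([1, 3, 224, 224])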


class BilingualDataset(Dataset):
    """Image-captioning dataset: pairs a base64-encoded image (the encoder
    input) with its caption (the decoder target), padded to seq_len."""

    def __init__(self, ds, tokenizer_tgt, seq_len):
        super().__init__()
        self.seq_len = seq_len
        ImageFile.LOAD_TRUNCATED_IMAGES = True  # tolerate truncated image files
        self.ds = ds
        self.tokenizer_tgt = tokenizer_tgt
        self.processor = ViTImageProcessor.from_pretrained(model_id)
        self.sos_token = torch.tensor([tokenizer_tgt.convert_tokens_to_ids("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.convert_tokens_to_ids("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.convert_tokens_to_ids("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        data_pair = self.ds[idx]
        src_image = data_pair['image_base64_str']
        tgt_text = data_pair['outputs']

        # The image may arrive as one base64 string or a list of chunks;
        # ''.join handles both before decoding.
        src_image = Image.open(BytesIO(b64decode(''.join(src_image))))
        if src_image.mode != 'RGB':
            src_image = src_image.convert('RGB')
        src_image = self.processor(src_image, return_tensors='pt')

        # Tokenize the caption, truncating so there is room for one special
        # token: [SOS] on the decoder input, [EOS] on the label.
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text)
        dec_input_tokens = dec_input_tokens[:self.seq_len - 1]
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # Defensive check; the truncation above guarantees this never fires.
        if dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        # Decoder input: [SOS] + tokens + padding.
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.full((dec_num_padding_tokens,), self.pad_token.item(), dtype=torch.int64),
            ],
            dim=0,
        )

        # Label: tokens + [EOS] + padding (the decoder input shifted left by one).
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.full((dec_num_padding_tokens,), self.pad_token.item(), dtype=torch.int64),
            ],
            dim=0,
        )

        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            # A single squeeze(0) drops the processor's batch dimension.
            'encoder_input': src_image['pixel_values'].squeeze(0),  # (3, 224, 224)
            'decoder_input': decoder_input,  # (seq_len)
            # Encoder mask (currently unused by the model): ones over the 197
            # ViT tokens (196 patches + [CLS]), zero-padded to length 260.
            "encoder_mask": torch.cat((torch.ones(197), torch.zeros(63))).unsqueeze(0).unsqueeze(0),  # (1, 1, 260)
            # Padding mask combined with the causal mask.
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),  # (1, seq_len) & (1, seq_len, seq_len)
            "label": label,  # (seq_len)
            "tgt_text": tgt_text,
        }


def causal_mask(size):
    # True on and below the diagonal: position i may attend only to positions <= i.
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0
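

# Minimal smoke test, a sketch only: it assumes a Hugging Face tokenizer with
# [SOS]/[EOS]/[PAD] in its vocab and records shaped like the real dataset
# ({'image_base64_str', 'outputs'}); the checkpoint below is a stand-in, not
# the one this project actually trains with.
if __name__ == "__main__":
    import base64

    # Build one toy record: a solid-colour image, base64-encoded like the real data.
    buf = BytesIO()
    Image.new("RGB", (32, 32), color=(120, 60, 200)).save(buf, format="PNG")
    record = {
        "image_base64_str": base64.b64encode(buf.getvalue()).decode("ascii"),
        "outputs": "a purple square",
    }

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # stand-in tokenizer
    tokenizer.add_tokens(["[SOS]", "[EOS]"])  # [PAD] already exists in BERT's vocab

    sample = BilingualDataset([record], tokenizer, seq_len=260)[0]
    print(sample["encoder_input"].shape)  # torch.Size([3, 224, 224])
    print(sample["decoder_input"].shape)  # torch.Size([260])
    print(sample["decoder_mask"].shape)   # torch.Size([1, 260, 260])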