import random

import cv2
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import CLIPTokenizer

from models.blip_override.blip import init_tokenizer


class StoryDataset(Dataset):
    """
    A custom Dataset class serving one subset (train, val, or test) of the
    story dataset stored in a single HDF5 file.
    """

    def __init__(self, subset, args):
        super(StoryDataset, self).__init__()
        self.args = args

        self.h5_file = args.get(args.dataset).hdf5_file
        self.subset = subset

        # Augmentation/normalization pipeline for the generation targets:
        # resize to 512x512 and scale pixels to [-1, 1].
        self.augment = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([512, 512]),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
        self.dataset = args.dataset
        self.max_length = args.get(args.dataset).max_length

        # Register dataset-specific tokens (e.g. character names) with both
        # the CLIP and BLIP tokenizers; add_tokens returns how many were added.
        self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
        self.blip_tokenizer = init_tokenizer()
        msg = self.clip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
        print("clip {} new tokens added".format(msg))
        msg = self.blip_tokenizer.add_tokens(list(args.get(args.dataset).new_tokens))
        print("blip {} new tokens added".format(msg))

        # BLIP image pipeline: 224x224 with CLIP's channel-wise mean/std.
        self.blip_image_processor = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
        ])

    def open_h5(self):
        # Open the file lazily, on first access, so each DataLoader worker
        # process gets its own h5py handle instead of a shared one.
        h5 = h5py.File(self.h5_file, "r")
        self.h5 = h5[self.subset]

    def __getitem__(self, index):
        if not hasattr(self, 'h5'):
            self.open_h5()

        # Each sample has 5 story frames; every stored frame holds five
        # 128px-tall crops stacked vertically, one of which is sampled at
        # random. cv2.imdecode returns the decoded image in BGR order.
        images = list()
        for i in range(5):
            im = self.h5['image{}'.format(i)][index]
            im = cv2.imdecode(im, cv2.IMREAD_COLOR)
            idx = random.randint(0, 4)
            images.append(im[idx * 128: (idx + 1) * 128])

        source_images = torch.stack([self.blip_image_processor(im) for im in images])
        # For story continuation the first frame is a source image only, not a
        # generation target.
        images = images[1:] if self.args.task == 'continuation' else images
        images = torch.stack([self.augment(im) for im in images]) \
            if self.subset in ['train', 'val'] else torch.from_numpy(np.array(images)).permute(0, 3, 1, 2)

        texts = self.h5['text'][index].decode('utf-8').split('|')

        # Target captions for the CLIP text encoder. Truncation is disabled,
        # so max_length must cover the longest caption in the dataset.
        tokenized = self.clip_tokenizer(
            texts[1:] if self.args.task == 'continuation' else texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=False,
            return_tensors="pt",
        )
        captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']

        # All captions (including the source frame's) for the BLIP encoder.
        tokenized = self.blip_tokenizer(
            texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=False,
            return_tensors="pt",
        )
        source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
        return images, captions, attention_mask, source_images, source_caption, source_attention_mask

    def __len__(self):
        if not hasattr(self, 'h5'):
            self.open_h5()
        return len(self.h5['text'])
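

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes an
    # OmegaConf-style config in which `args.dataset` names a sub-config that
    # provides `hdf5_file`, `max_length`, and `new_tokens`, and `args.task` is
    # either 'continuation' or 'visualization', as the class's accesses above
    # suggest. All keys, paths, and values below are hypothetical placeholders.
    from omegaconf import OmegaConf
    from torch.utils.data import DataLoader

    args = OmegaConf.create({
        'dataset': 'pororo',                      # hypothetical dataset key
        'task': 'visualization',
        'pororo': {
            'hdf5_file': '/path/to/pororo.h5',    # hypothetical path
            'max_length': 85,                     # hypothetical value
            'new_tokens': ['pororo', 'loopy'],    # hypothetical new tokens
        },
    })
    loader = DataLoader(StoryDataset('train', args), batch_size=1, num_workers=2)
    images, captions, attention_mask, source_images, source_caption, source_attention_mask = next(iter(loader))
    print(images.shape, captions.shape, source_images.shape)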