import cv2
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import CLIPTokenizer
from models.blip_override.blip import init_tokenizer


class StoryDataset(Dataset):
"""
A custom subset class for the LRW (includes train, val, test) subset
"""
def __init__(self, subset, args):
super(StoryDataset, self).__init__()
self.args = args
self.h5_file = args.get(args.dataset).hdf5_file
self.subset = subset
        # Train/val: resize + 512 crop (random for train, center otherwise) and
        # normalize to [-1, 1] tensors. Test: keep 64x64 images un-normalized.
        self.augment = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(512),
            transforms.RandomCrop(512) if self.subset == 'train' else transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ]) if self.subset in ['train', 'val'] else transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(64),
            transforms.CenterCrop(64)
        ])
self.dataset = args.dataset
self.max_length = args.get(args.dataset).max_length
self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
self.blip_tokenizer = init_tokenizer()
        # BLIP image branch: resize/crop to 224 and normalize with CLIP image statistics.
        self.blip_image_processor = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(224),
transforms.RandomCrop(224) if self.subset == 'train' else transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
])
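
    # Expected HDF5 layout (inferred from open_h5/__getitem__ below): one group per
    # subset ('train'/'val'/'test'); each group holds datasets 'image0'..'image4' with
    # encoded image bytes per story, and 'sis' with '|'-separated captions per story.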
    def open_h5(self):
        # Open the HDF5 file lazily on first access so that each DataLoader worker
        # gets its own file handle.
        h5 = h5py.File(self.h5_file, "r")
        self.h5 = h5[self.subset]

    def __getitem__(self, index):
        if not hasattr(self, 'h5'):
            self.open_h5()
        # Decode the story's 5 frames from their stored (encoded) image bytes.
        images = list()
        for i in range(5):
            im = self.h5['image{}'.format(i)][index]
            im = cv2.imdecode(im, cv2.IMREAD_COLOR)
            images.append(im)
        # BLIP always conditions on all 5 frames; the generation targets drop the
        # first frame for the story-continuation task.
        source_images = torch.stack([self.blip_image_processor(im) for im in images])
        images = images[1:] if self.args.task == 'continuation' else images
        images = [self.augment(im) for im in images]
        # Train/val: stack of normalized float tensors; test: raw uint8 images as (N, C, H, W).
        images = torch.stack(images) if self.subset in ['train', 'val'] \
            else torch.from_numpy(np.array([np.array(im) for im in images])).permute(0, 3, 1, 2)
        # Captions for the 5 frames are stored as a single '|'-separated string.
        texts = self.h5['sis'][index].decode('utf-8').split('|')
        # Tokenize the target captions with the Stable Diffusion CLIP tokenizer,
        # dropping the first caption for the continuation task to match the images.
        tokenized = self.clip_tokenizer(
            texts[1:] if self.args.task == 'continuation' else texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=False,
            return_tensors="pt",
        )
        captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']
        # Tokenize all 5 captions with the BLIP tokenizer for the source branch.
        tokenized = self.blip_tokenizer(
            texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=False,
            return_tensors="pt",
        )
        source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
return images, captions, attention_mask, source_images, source_caption, source_attention_mask

    def __len__(self):
if not hasattr(self, 'h5'):
self.open_h5()
return len(self.h5['sis'])
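

# --- Minimal usage sketch (an assumption, not part of the original pipeline) ---
# __init__ reads `args.get(args.dataset).hdf5_file`, `args.get(args.dataset).max_length`
# and `args.task`, which suggests an OmegaConf-style config. The dataset key, HDF5 path,
# and max_length below are hypothetical placeholders.
if __name__ == "__main__":
    from omegaconf import OmegaConf
    from torch.utils.data import DataLoader

    args = OmegaConf.create({
        "dataset": "vist",                        # hypothetical dataset key
        "task": "visualization",                  # or "continuation" to drop the first frame
        "vist": {
            "hdf5_file": "/path/to/stories.h5",   # placeholder path
            "max_length": 91,                     # placeholder caption length
        },
    })
    dataset = StoryDataset("train", args)
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

    images, captions, attn, src_images, src_captions, src_attn = next(iter(loader))
    # For the train split this yields images of shape (B, 5, 3, 512, 512), source images
    # of shape (B, 5, 3, 224, 224), and (B, 5, max_length) token-id / attention-mask
    # tensors (captions are padded, not truncated).
    print(images.shape, captions.shape, src_images.shape, src_captions.shape)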