import cv2
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import CLIPTokenizer
from models.blip_override.blip import init_tokenizer


class StoryDataset(Dataset):
    """
    Story visualization dataset backed by an HDF5 file, exposing the
    train, val, and test subsets as a torch Dataset.
    """
    def __init__(self, subset, args):
        super(StoryDataset, self).__init__()
        self.args = args
        self.h5_file = args.get(args.dataset).hdf5_file
        self.subset = subset

        # Target-image pipeline: 512x512 with random crop for train and center crop
        # for val; for other subsets (test), keep 64x64 PIL images so the ground-truth
        # frames can later be stacked as raw uint8 arrays.
        self.augment = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(512),
            transforms.RandomCrop(512) if self.subset == 'train' else transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ]) if self.subset in ['train', 'val'] else transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(64),
            transforms.CenterCrop(64)
        ])
        self.dataset = args.dataset
        self.max_length = args.get(args.dataset).max_length
        self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
        self.blip_tokenizer = init_tokenizer()
        # Source-image pipeline for BLIP: 224x224 crops normalized with CLIP statistics.
        self.blip_image_processor = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(224),
            transforms.RandomCrop(224) if self.subset == 'train' else transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
        ])

    def open_h5(self):
        # Open the HDF5 file lazily so each DataLoader worker gets its own handle.
        h5 = h5py.File(self.h5_file, "r")
        self.h5 = h5[self.subset]

    def __getitem__(self, index):
        if not hasattr(self, 'h5'):
            self.open_h5()

        # Decode the five encoded frames of the story.
        images = list()
        for i in range(5):
            im = self.h5['image{}'.format(i)][index]
            im = cv2.imdecode(im, cv2.IMREAD_COLOR)
            images.append(im)

        # Source images go through the BLIP processor; for the continuation task the
        # first frame serves as context only, so it is excluded from the targets.
        source_images = torch.stack([self.blip_image_processor(im) for im in images])
        images = images[1:] if self.args.task == 'continuation' else images
        images = [self.augment(im) for im in images]
        images = torch.stack(images) if self.subset in ['train', 'val'] \
            else torch.from_numpy(np.array([np.array(im) for im in images])).permute(0, 3, 1, 2)

        texts = self.h5['sis'][index].decode('utf-8').split('|')

        # Tokenize the target captions with the CLIP tokenizer, padded to the
        # dataset-specific max_length.
        tokenized = self.clip_tokenizer(
            texts[1:] if self.args.task == 'continuation' else texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=False,
            return_tensors="pt",
        )
        captions, attention_mask = tokenized['input_ids'], tokenized['attention_mask']

        # Tokenize all captions (including the first frame's) with the BLIP tokenizer.
        tokenized = self.blip_tokenizer(
            texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=False,
            return_tensors="pt",
        )
        source_caption, source_attention_mask = tokenized['input_ids'], tokenized['attention_mask']
        return images, captions, attention_mask, source_images, source_caption, source_attention_mask

    def __len__(self):
        if not hasattr(self, 'h5'):
            self.open_h5()
        return len(self.h5['sis'])
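

# The sketch below is not part of the original file: it illustrates one way this
# dataset might be driven, assuming an OmegaConf-style config in which
# args.get(args.dataset) returns a per-dataset block with `hdf5_file` and
# `max_length`. The dataset name 'vist', the file path, and the max_length value
# are placeholders, not values taken from the original repository.
if __name__ == '__main__':
    from omegaconf import OmegaConf
    from torch.utils.data import DataLoader

    args = OmegaConf.create({
        'dataset': 'vist',
        'task': 'visualization',
        'vist': {'hdf5_file': '/path/to/stories.h5', 'max_length': 91},
    })

    dataset = StoryDataset('train', args)
    # Because open_h5() is called lazily in __getitem__/__len__, each worker
    # process opens its own HDF5 handle.
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

    images, captions, attention_mask, source_images, source_caption, source_attention_mask = next(iter(loader))
    print(images.shape, captions.shape, source_images.shape)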