import glob
import json
import os
import random

from torch.utils.data import Dataset

from PIL import Image
from PIL import ImageFile

from data.utils import pre_caption

# Let PIL decode truncated image files and lift its maximum-pixel-count limit.
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None


class pretrain_dataset(Dataset):
    """Image-caption dataset built from JSON annotation files.

    Annotations come from two sources:

    * ``ann_file``: JSON files that are all loaded once and kept for the
      lifetime of the dataset.
    * ``laion_path`` (optional): a directory of ``*.json`` files of which only
      one is held in memory at a time; ``reload_laion`` swaps it per epoch.

    Each annotation entry is indexed with the keys ``'image'`` (passed to
    ``Image.open``) and ``'caption'`` (passed to ``pre_caption``).
    """

    def __init__(self, ann_file, laion_path, transform):
        """Load the fixed annotations and, if requested, the first LAION file.

        Args:
            ann_file: Iterable of paths to JSON files; each file's content is
                concatenated (``+=``) onto the annotation list.
            laion_path: Directory containing ``*.json`` files, or a falsy
                value to disable the rotating LAION part.
            transform: Callable applied to each RGB ``PIL.Image``.

        Raises:
            FileNotFoundError: If ``laion_path`` is given but contains no
                ``*.json`` file.
        """
        self.ann_pretrain = []
        for path in ann_file:
            self.ann_pretrain += self._load_json(path)

        self.laion_path = laion_path
        # Always defined, so reload_laion() is safe even without a LAION dir.
        self.laion_files = []
        self.ann_laion = []

        if self.laion_path:
            # glob order is filesystem-dependent; sort so that the
            # epoch -> file mapping is reproducible across runs and machines.
            self.laion_files = sorted(
                glob.glob(os.path.join(laion_path, '*.json'))
            )
            if not self.laion_files:
                raise FileNotFoundError(
                    'no *.json files found in laion_path: ' + str(laion_path)
                )
            self.ann_laion = self._load_json(self.laion_files[0])

        self.annotation = self.ann_pretrain + self.ann_laion
        self.transform = transform

    @staticmethod
    def _load_json(path):
        """Announce and parse one JSON file, closing it afterwards."""
        print('loading ' + path)
        with open(path, 'r') as f:
            return json.load(f)

    def reload_laion(self, epoch):
        """Replace the in-memory LAION annotations with the file for ``epoch``.

        The file is chosen as ``laion_files[epoch % len(laion_files)]``.
        Does nothing when the dataset was created without a ``laion_path``.
        """
        if not self.laion_files:
            return
        n = epoch % len(self.laion_files)
        self.ann_laion = self._load_json(self.laion_files[n])
        self.annotation = self.ann_pretrain + self.ann_laion

    def __len__(self):
        """Return the number of annotations currently loaded."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Return ``(transformed_image, processed_caption)`` for ``index``."""
        ann = self.annotation[index]

        # The context manager closes the underlying file as soon as the
        # pixel data has been converted into a standalone RGB image.
        with Image.open(ann['image']) as img:
            image = img.convert('RGB')
        image = self.transform(image)

        # Second argument presumably caps the caption length in words;
        # verify against data.utils.pre_caption.
        caption = pre_caption(ann['caption'], 30)

        return image, caption