Spaces:
Running
Running
import json | |
import os | |
import random | |
from torch.utils.data import Dataset | |
from PIL import Image | |
from PIL import ImageFile | |
ImageFile.LOAD_TRUNCATED_IMAGES = True | |
Image.MAX_IMAGE_PIXELS = None | |
from data.utils import pre_caption | |
import os,glob | |
class pretrain_dataset(Dataset):
    """Image/caption dataset assembled from JSON annotation files.

    The sample list (``self.annotation``) is the concatenation of every file
    listed in ``ann_file`` and, when ``laion_path`` is given, the contents of
    ONE ``*.json`` shard from that directory. ``reload_laion`` swaps in a
    different shard, so only a single shard is held in memory at a time.

    Each annotation entry is read through the keys ``'image'`` (a path handed
    to ``Image.open``) and ``'caption'`` (handed to ``pre_caption``).
    """

    def __init__(self, ann_file, laion_path, transform):
        """Load the fixed annotations and, optionally, the first LAION shard.

        Args:
            ann_file: iterable of paths to JSON files, each containing a list
                of annotation entries.
            laion_path: directory holding ``*.json`` shards, or a falsy value
                to disable the LAION part entirely.
            transform: callable applied to every RGB ``PIL.Image`` in
                ``__getitem__``.

        Raises:
            FileNotFoundError: ``laion_path`` is set but contains no
                ``*.json`` file.
        """
        self.ann_pretrain = []
        for path in ann_file:
            self.ann_pretrain += self._load_json(path)

        self.laion_path = laion_path
        # Always defined, so reload_laion() is safe without a LAION directory.
        self.laion_files = []
        self.ann_laion = []

        if self.laion_path:
            # glob returns files in arbitrary order; sorting keeps the
            # epoch -> shard mapping identical across runs and processes.
            self.laion_files = sorted(
                glob.glob(os.path.join(laion_path, '*.json'))
            )
            if not self.laion_files:
                raise FileNotFoundError(
                    'no *.json files found in laion_path: ' + str(laion_path)
                )
            self.ann_laion = self._load_json(self.laion_files[0])
            self.annotation = self.ann_pretrain + self.ann_laion
        else:
            self.annotation = self.ann_pretrain

        self.transform = transform

    @staticmethod
    def _load_json(path):
        """Announce and parse one JSON file, closing it afterwards."""
        print('loading '+path)
        with open(path, 'r') as f:
            return json.load(f)

    def reload_laion(self, epoch):
        """Replace the active LAION shard with shard ``epoch % num_shards``.

        Does nothing when the dataset was built without LAION shards.
        """
        if not self.laion_files:
            return

        n = epoch % len(self.laion_files)
        self.ann_laion = self._load_json(self.laion_files[n])
        self.annotation = self.ann_pretrain + self.ann_laion

    def __len__(self):
        """Return the number of samples currently in ``self.annotation``."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Return ``(transformed_image, caption)`` for the sample at ``index``."""
        ann = self.annotation[index]

        # Image.open is lazy; convert() forces the pixel data to load and
        # returns an independent image, so the file can be closed right away.
        with Image.open(ann['image']) as raw:
            image = raw.convert('RGB')
        image = self.transform(image)

        # NOTE(review): 30 is presumably a caption length limit; confirm
        # against data.utils.pre_caption.
        caption = pre_caption(ann['caption'], 30)

        return image, caption