import json
import os.path

import pytorch_lightning as pl
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import CLIPProcessor


class WikiArtDataset(Dataset):
    def __init__(self, meta_file):
        super().__init__()
        self.files = []
        with open(meta_file, 'r') as f:
            js = json.load(f)
        for img_path in js:
            # Derive a caption from the file name. WikiArt file names follow the
            # pattern '<artist>_<title-words>-<year>', e.g.
            # 'artist-name_portrait-of-a-lady-1889.jpg' -> 'portrait of a lady'.
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            caption = img_name.split('_')[-1].split('-')
            # Walk backwards to drop trailing all-digit tokens (usually the year).
            j = len(caption) - 1
            while j >= 0:
                if not caption[j].isdigit():
                    break
                j -= 1
            if j < 0:
                # Every token is numeric; there is no usable caption, skip this image.
                continue
            sentence = ' '.join(caption[:j + 1])
            self.files.append({
                'img_path': os.path.join('datasets/wikiart', img_path),
                'sentence': sentence,
            })

        version = 'openai/clip-vit-large-patch14'
        self.processor = CLIPProcessor.from_pretrained(version)
        self.jpg_transform = transforms.Compose([
            transforms.Resize(512),      # shorter edge -> 512 px
            transforms.RandomCrop(512),  # random 512x512 crop for augmentation
            transforms.ToTensor(),
        ])

    def __getitem__(self, idx):
        file = self.files[idx]
        # Convert to RGB so grayscale or RGBA scans don't break the transforms.
        im = Image.open(file['img_path']).convert('RGB')
        im_tensor = self.jpg_transform(im)
        # CLIP's own preprocessing (resize + normalize) for the style branch.
        clip_im = self.processor(images=im, return_tensors="pt")['pixel_values'][0]
        return {'jpg': im_tensor, 'style': clip_im, 'txt': file['sentence']}

    def __len__(self):
        return len(self.files)


class WikiArtDataModule(pl.LightningDataModule):
    def __init__(self, meta_file, batch_size, num_workers):
        super().__init__()
        self.train_dataset = WikiArtDataset(meta_file)
        self.batch_size = batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=True)
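
# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal smoke test, assuming a JSON meta file that lists image paths
# relative to 'datasets/wikiart'. The meta file path below is hypothetical;
# substitute whatever your project actually uses.
if __name__ == '__main__':
    dm = WikiArtDataModule(meta_file='datasets/wikiart/meta.json',  # hypothetical path
                           batch_size=4, num_workers=2)
    loader = dm.train_dataloader()
    batch = next(iter(loader))
    print(batch['jpg'].shape)    # expected: torch.Size([4, 3, 512, 512])
    print(batch['style'].shape)  # expected: torch.Size([4, 3, 224, 224]) for ViT-L/14
    print(batch['txt'][:2])      # captions collate to a plain list of strings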