import json
import os.path

from PIL import Image
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPProcessor
from torchvision import transforms
import pytorch_lightning as pl
class WikiArtDataset(Dataset):
    def __init__(self, meta_file):
        super().__init__()
        self.files = []
        with open(meta_file, 'r') as f:
            js = json.load(f)
            for img_path in js:
                # WikiArt filenames appear to follow an 'artist_picture-title-year'
                # pattern, e.g. 'claude-monet_water-lilies-1916.jpg'. Recover the
                # title as the caption: take the part after the last underscore,
                # then drop any trailing numeric tokens (years, plate numbers).
                img_name = os.path.splitext(os.path.basename(img_path))[0]
                caption = img_name.split('_')[-1]
                caption = caption.split('-')
                j = len(caption) - 1
                while j >= 0:
                    if not caption[j].isdigit():
                        break
                    j -= 1
                if j < 0:
                    # The name was purely numeric; no usable caption.
                    continue
                sentence = ' '.join(caption[:j + 1])
                self.files.append({'img_path': os.path.join('datasets/wikiart', img_path),
                                   'sentence': sentence})

        version = 'openai/clip-vit-large-patch14'
        self.processor = CLIPProcessor.from_pretrained(version)
        self.jpg_transform = transforms.Compose([
            transforms.Resize(512),
            transforms.RandomCrop(512),
            transforms.ToTensor(),
        ])
    def __getitem__(self, idx):
        file = self.files[idx]
        # Convert to RGB so grayscale/RGBA images batch with a consistent
        # three channels.
        im = Image.open(file['img_path']).convert('RGB')
        im_tensor = self.jpg_transform(im)
        # Run CLIP's own preprocessing (resize, center-crop, normalize) to get
        # the style conditioning image.
        clip_im = self.processor(images=im, return_tensors="pt")['pixel_values'][0]
        return {'jpg': im_tensor, 'style': clip_im, 'txt': file['sentence']}

    def __len__(self):
        return len(self.files)
class WikiArtDataModule(pl.LightningDataModule):
    def __init__(self, meta_file, batch_size, num_workers):
        super().__init__()
        self.train_dataset = WikiArtDataset(meta_file)
        self.batch_size = batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=True)
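
# A minimal usage sketch, assuming the metadata JSON is a list of image paths
# relative to 'datasets/wikiart'; the meta file name below is hypothetical.
if __name__ == '__main__':
    dm = WikiArtDataModule(meta_file='datasets/wikiart/meta.json',
                           batch_size=4, num_workers=2)
    batch = next(iter(dm.train_dataloader()))
    # 'jpg' is a batch of 512x512 crops, 'style' is CLIP's 224x224
    # pixel_values, and 'txt' is a list of caption strings.
    print(batch['jpg'].shape, batch['style'].shape, batch['txt'])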