# Source: T2I-Adapter / ldm/data/dataset_wikiart.py (2.15 kB)
# Author: LiangbinXie, commit 0177fec ("add composable adapter")
import json
import os.path
from PIL import Image
from torch.utils.data import DataLoader
from transformers import CLIPProcessor
from torchvision.transforms import transforms
import pytorch_lightning as pl
class WikiArtDataset():
    """Map-style dataset of WikiArt paintings with captions derived from filenames.

    Each sample pairs a randomly cropped image tensor with its CLIP-preprocessed
    counterpart (returned under ``'style'``) and a short caption parsed from the
    image filename.

    Args:
        meta_file: Path to a JSON file whose top-level value is iterated as a
            sequence of image paths relative to ``root`` (presumably a JSON
            array of strings; verify against whatever produces the file).
        root: Directory that every path from ``meta_file`` is joined onto.
        clip_version: Model id passed to ``CLIPProcessor.from_pretrained``.
        image_size: Side length of the square crop returned under ``'jpg'``.
    """

    def __init__(self, meta_file, *, root='datasets/wikiart',
                 clip_version='openai/clip-vit-large-patch14', image_size=512):
        super().__init__()
        self.files = []
        # JSON text is UTF-8 by spec; don't depend on the platform's locale.
        with open(meta_file, 'r', encoding='utf-8') as f:
            js = json.load(f)
        for img_path in js:
            sentence = self._caption_from_path(img_path)
            if sentence is None:
                # Every filename token is numeric, so there is no caption; skip.
                continue
            self.files.append({'img_path': os.path.join(root, img_path),
                               'sentence': sentence})

        self.processor = CLIPProcessor.from_pretrained(clip_version)
        # An int Resize matches the shorter side to image_size, so the square
        # RandomCrop of the same size always fits inside the resized image.
        self.jpg_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.RandomCrop(image_size),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _caption_from_path(img_path):
        """Derive a caption from an image path, or return ``None`` if none remains.

        Takes the file stem, keeps the part after the last ``'_'``, splits it on
        ``'-'`` and drops trailing all-digit tokens, e.g.
        ``'claude-monet_water-lilies-1916.jpg'`` -> ``'water lilies'``.
        Returns ``None`` when every token is all-digit.
        """
        stem = os.path.splitext(os.path.basename(img_path))[0]
        tokens = stem.split('_')[-1].split('-')
        # Trailing numeric tokens are presumably years / duplicate indices.
        while tokens and tokens[-1].isdigit():
            tokens.pop()
        if not tokens:
            return None
        return ' '.join(tokens)

    def __getitem__(self, idx):
        """Return ``{'jpg': tensor, 'style': tensor, 'txt': str}`` for sample ``idx``.

        ``'jpg'`` is a 3-channel float tensor in [0, 1] of size
        ``image_size x image_size``. ``'style'`` is the processor's
        ``pixel_values`` for this image (shape set by the CLIP processor config).
        ``'txt'`` is the filename-derived caption.
        """
        file = self.files[idx]
        # Force 3-channel RGB: grayscale, palette, RGBA or CMYK files would
        # otherwise give tensors with a different channel count and break
        # batch collation. The `with` closes the underlying file handle.
        with Image.open(file['img_path']) as img:
            im = img.convert('RGB')
        im_tensor = self.jpg_transform(im)
        clip_im = self.processor(images=im, return_tensors="pt")['pixel_values'][0]
        return {'jpg': im_tensor, 'style': clip_im, 'txt': file['sentence']}

    def __len__(self):
        """Number of images that produced a non-empty caption."""
        return len(self.files)
class WikiArtDataModule(pl.LightningDataModule):
    """Lightning data module that exposes only a WikiArt training loader.

    Args:
        meta_file: JSON metadata path forwarded to ``WikiArtDataset``.
        batch_size: Samples per training batch.
        num_workers: Number of DataLoader worker processes.
    """

    def __init__(self, meta_file, batch_size, num_workers):
        super().__init__()
        # The dataset is built eagerly, at construction time.
        self.train_dataset = WikiArtDataset(meta_file)
        self.batch_size = batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        """Build a shuffling, memory-pinned DataLoader over the training set."""
        loader_options = dict(
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )
        return DataLoader(self.train_dataset, **loader_options)