Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import pandas as pd | |
import numpy as np | |
from PIL import Image | |
import torch | |
from torch.utils.data import Dataset, DataLoader | |
import json | |
import random | |
def c_crop(image):
    """Return the largest centered square crop of *image*.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method taking ``(left, top, right, bottom)``,
            such as a ``PIL.Image.Image``.

    Returns:
        Whatever ``image.crop`` returns for the centered square box whose
        side equals the shorter dimension of *image*.
    """
    width, height = image.size
    side = min(width, height)
    # Use integer offsets and derive the far edges from them. Float
    # midpoints such as 1.5 / 4.5 are rounded half-to-even by Pillow, which
    # can shrink the box by one pixel and make the crop non-square.
    left = (width - side) // 2
    top = (height - side) // 2
    return image.crop((left, top, left + side, top + side))
class CustomImageDataset(Dataset):
    """Dataset of square image tensors paired with text captions.

    Every file directly inside ``img_dir`` whose name ends in ``.jpg`` or
    ``.png`` is one sample. Its caption is the ``caption`` entry of a JSON
    file located next to the image and named like it, with the image
    extension replaced by ``.json``.
    """

    # File-name suffixes that mark a file as an image sample.
    _IMAGE_EXTENSIONS = ('.jpg', '.png')
    # Upper bound on samples tried by one ``__getitem__`` call before
    # giving up, so a fully broken dataset fails instead of looping forever.
    _MAX_ATTEMPTS = 100

    def __init__(self, img_dir, img_size=512):
        """Index the images in *img_dir*.

        Args:
            img_dir: Directory that holds the images and caption files.
            img_size: Side length, in pixels, of the returned square images.
        """
        # Match on the suffix: a substring test would also pick up names
        # such as "x.jpg.bak" that merely contain an image extension.
        self.images = sorted(
            os.path.join(img_dir, name)
            for name in os.listdir(img_dir)
            if name.endswith(self._IMAGE_EXTENSIONS)
        )
        self.img_size = img_size

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    @staticmethod
    def _caption_path(image_path):
        """Return the path of the JSON caption file for *image_path*."""
        # splitext only strips the final extension, so dots elsewhere in
        # the path (e.g. "./data/img.jpg") do not truncate it.
        return os.path.splitext(image_path)[0] + '.json'

    def _load_sample(self, idx):
        """Load sample *idx*; raises if the image or caption is unusable."""
        path = self.images[idx]
        with Image.open(path) as raw:
            # Force three channels so grayscale, palette and RGBA files all
            # produce the same tensor layout.
            img = c_crop(raw.convert('RGB'))
        img = img.resize((self.img_size, self.img_size))
        # Scale pixel values from [0, 255] to [-1, 1].
        img = torch.from_numpy((np.array(img) / 127.5) - 1)
        # Height-width-channel -> channel-height-width.
        img = img.permute(2, 0, 1)
        # JSON text is UTF-8 by specification; do not rely on the locale.
        with open(self._caption_path(path), encoding='utf-8') as f:
            prompt = json.load(f)['caption']
        return img, prompt

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for sample *idx*.

        If the sample cannot be loaded, the error is printed and a randomly
        chosen sample is tried instead.

        Returns:
            A tuple of a ``(3, img_size, img_size)`` float64 tensor with
            values in ``[-1, 1]`` and the caption read from the JSON file.

        Raises:
            RuntimeError: If the dataset is empty or no sample could be
                loaded within ``_MAX_ATTEMPTS`` attempts.
        """
        last_error = None
        for _ in range(self._MAX_ATTEMPTS):
            try:
                return self._load_sample(idx)
            except Exception as e:
                # Best effort: report the broken sample and substitute one.
                print(e)
                last_error = e
                if not self.images:
                    break
                idx = random.randint(0, len(self.images) - 1)
        raise RuntimeError('could not load any sample from the dataset') from last_error
def loader(train_batch_size, num_workers, **args):
    """Build a shuffling ``DataLoader`` over a ``CustomImageDataset``.

    Args:
        train_batch_size: Passed to ``DataLoader`` as ``batch_size``.
        num_workers: Passed to ``DataLoader`` as ``num_workers``.
        **args: Keyword arguments forwarded unchanged to
            ``CustomImageDataset``.

    Returns:
        The ``DataLoader`` wrapping the newly created dataset.
    """
    loader_options = {
        'batch_size': train_batch_size,
        'num_workers': num_workers,
        'shuffle': True,
    }
    return DataLoader(CustomImageDataset(**args), **loader_options)