turn-the-cam-anonymous's picture
adding CLIP taming
1ed7deb
import bisect
import numpy as np
import albumentations
from PIL import Image
from torch.utils.data import Dataset, ConcatDataset
class ConcatDatasetWithIndex(ConcatDataset):
    """ConcatDataset variant whose items also report their source dataset.

    Indexing returns a ``(sample, dataset_idx)`` tuple, where ``dataset_idx``
    is the position in ``self.datasets`` of the dataset the sample came from.
    """

    def __getitem__(self, idx):
        """Return ``(sample, dataset_idx)`` for the global index ``idx``.

        Negative indices count from the end of the concatenated dataset.

        Raises:
            ValueError: If ``idx`` is negative and its absolute value exceeds
                the total length.
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx += len(self)

        # cumulative_sizes holds the running end offsets of each sub-dataset,
        # so bisect_right yields the first dataset whose end lies beyond idx.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = self.cumulative_sizes[dataset_idx - 1] if dataset_idx != 0 else 0
        return self.datasets[dataset_idx][idx - offset], dataset_idx
class ImagePaths(Dataset):
    """Dataset that loads RGB images from a list of file paths.

    Each item is a dict containing the preprocessed image under ``"image"``
    plus the ``i``-th entry of every label sequence, including the image path
    itself under ``"file_path_"``.

    Args:
        paths: Sequence of image file paths.
        size: Target edge length. When positive, images are rescaled so that
            their smallest side equals ``size`` and then cropped to
            ``size x size``. When ``None`` or not positive, images keep their
            native resolution.
        random_crop: Crop at a random location instead of the center. Only
            used when ``size`` is positive.
        labels: Optional mapping from label name to a sequence indexable by
            sample index. It is shallow-copied, so the caller's dict is never
            modified.
    """

    def __init__(self, paths, size=None, random_crop=False, labels=None):
        self.size = size
        self.random_crop = random_crop

        # Copy so that adding "file_path_" does not leak into the caller's dict.
        self.labels = {} if labels is None else dict(labels)
        self.labels["file_path_"] = paths
        self._length = len(paths)

        if self.size is not None and self.size > 0:
            self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            if self.random_crop:
                self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
            else:
                self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
            self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
        else:
            # A named function (unlike a lambda) keeps the dataset picklable,
            # which spawn-based DataLoader workers require.
            self.preprocessor = self._identity_preprocessor

    @staticmethod
    def _identity_preprocessor(**kwargs):
        """Return the keyword arguments unchanged (no-op preprocessing)."""
        return kwargs

    def __len__(self):
        """Return the number of image paths."""
        return self._length

    def preprocess_image(self, image_path):
        """Load an image and return it as a float32 array scaled to [-1, 1].

        The image is converted to RGB if needed, run through the configured
        rescale/crop pipeline, and mapped from ``[0, 255]`` to ``[-1, 1]``.
        """
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the array.
        with Image.open(image_path) as source:
            rgb = source if source.mode == "RGB" else source.convert("RGB")
            pixels = np.array(rgb).astype(np.uint8)
        pixels = self.preprocessor(image=pixels)["image"]
        return (pixels / 127.5 - 1.0).astype(np.float32)

    def __getitem__(self, i):
        """Return the example dict for sample ``i``."""
        example = {"image": self.preprocess_image(self.labels["file_path_"][i])}
        # Labels are applied after the image, so a label named "image" would
        # take precedence (same as the original ordering).
        for key, values in self.labels.items():
            example[key] = values[i]
        return example
class NumpyPaths(ImagePaths):
    """ImagePaths variant that reads images stored as ``.npy`` arrays."""

    def preprocess_image(self, image_path):
        """Load a saved array and return it as a float32 image in [-1, 1].

        The stored array must have a leading axis of length 1, which is
        removed; the remaining three axes are reordered from channels-first to
        channels-last before the usual rescale/crop pipeline is applied.
        """
        # Per the original note, the squeezed array is presumably 3 x 1024 x 1024.
        channels_first = np.load(image_path).squeeze(0)
        channels_last = np.transpose(channels_first, (1, 2, 0))
        pil_image = Image.fromarray(channels_last, mode="RGB")
        pixels = np.array(pil_image).astype(np.uint8)
        processed = self.preprocessor(image=pixels)["image"]
        return (processed / 127.5 - 1.0).astype(np.float32)