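# Page residue captured from the Hugging Face file viewer, kept as comments
# so that the module parses:
#   uploader avatar: kingpreyansh's picture
#   commit: fe0ca90 ("Added Stable Diffusion Files")
#   viewer links: raw / history / blame
#   scan status: No virus
#   file size: 2.61 kB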
import bisect
import numpy as np
import albumentations
from PIL import Image
from torch.utils.data import Dataset, ConcatDataset
class ConcatDatasetWithIndex(ConcatDataset):
    """Modified from original pytorch code to return dataset idx.

    Indexing yields ``(sample, dataset_idx)``, where ``dataset_idx`` is the
    position in ``self.datasets`` of the member dataset the sample came from.
    """

    def __getitem__(self, idx):
        """Return the sample at global index *idx* and its dataset position.

        Args:
            idx: Integer index into the concatenation. Negative values count
                from the end.

        Returns:
            A ``(sample, dataset_idx)`` tuple.

        Raises:
            ValueError: If *idx* is negative and its absolute value exceeds
                the total length.
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx

        # cumulative_sizes (maintained by ConcatDataset) holds the running
        # total of member lengths, so bisect_right locates the member whose
        # index range contains idx.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        # Samples before the selected member must be subtracted to get the
        # member-local index; the first member has nothing before it.
        offset = self.cumulative_sizes[dataset_idx - 1] if dataset_idx else 0
        return self.datasets[dataset_idx][idx - offset], dataset_idx
class ImagePaths(Dataset):
    """Dataset that loads images from file paths and returns them as arrays.

    Each item is a dict holding the preprocessed image under ``"image"``
    plus the i-th entry of every sequence stored in ``self.labels`` (which
    always includes the source path under ``"file_path_"``).
    """

    def __init__(self, paths, size=None, random_crop=False, labels=None):
        """Store the paths and labels and build the preprocessing pipeline.

        Args:
            paths: Sequence of image file paths, one per example.
            size: Target side length. When positive, images are rescaled with
                ``albumentations.SmallestMaxSize(max_size=size)`` and then
                cropped to ``size x size``. When ``None`` or non-positive,
                images pass through unchanged.
            random_crop: Use a random crop instead of a center crop.
            labels: Optional mapping from key to a per-example sequence. It
                is shallow-copied, so the caller's mapping is not modified
                when ``"file_path_"`` is added.
        """
        self.size = size
        self.random_crop = random_crop
        # Copy so that adding "file_path_" never leaks into the caller's dict.
        self.labels = {} if labels is None else dict(labels)
        self.labels["file_path_"] = paths
        self._length = len(paths)

        if self.size is not None and self.size > 0:
            self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            crop_cls = (
                albumentations.RandomCrop
                if self.random_crop
                else albumentations.CenterCrop
            )
            self.cropper = crop_cls(height=self.size, width=self.size)
            self.preprocessor = albumentations.Compose(
                [self.rescaler, self.cropper]
            )
        else:
            self.preprocessor = self._identity_preprocessor

    @staticmethod
    def _identity_preprocessor(**kwargs):
        """Return the keyword arguments unchanged (no-op preprocessing).

        A named function is used instead of a lambda so that instances of
        the dataset remain picklable.
        """
        return kwargs

    def __len__(self):
        """Return the number of examples (the length of ``paths``)."""
        return self._length

    def preprocess_image(self, image_path):
        """Load *image_path* as an RGB float32 array scaled to [-1, 1].

        Args:
            image_path: Path of the image file to open.

        Returns:
            The preprocessed image as a ``float32`` array.
        """
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the array.
        with Image.open(image_path) as image:
            rgb = image if image.mode == "RGB" else image.convert("RGB")
            pixels = np.array(rgb).astype(np.uint8)
        pixels = self.preprocessor(image=pixels)["image"]
        # Map uint8 values 0..255 onto -1..1.
        return (pixels / 127.5 - 1.0).astype(np.float32)

    def __getitem__(self, i):
        """Return example *i* as a dict of the image and its label entries."""
        example = {
            "image": self.preprocess_image(self.labels["file_path_"][i])
        }
        # Labels are written after the image, so a label keyed "image" would
        # take precedence, as in the original ordering.
        for key, values in self.labels.items():
            example[key] = values[i]
        return example
class NumpyPaths(ImagePaths):
    """ImagePaths variant that reads arrays saved with NumPy instead of images."""

    def preprocess_image(self, image_path):
        """Load a saved array and return it as a float32 image in [-1, 1].

        The file is read with ``np.load``, its leading axis (which must have
        length 1) is dropped, and the remaining axes are reordered from
        channels-first to channels-last before preprocessing.

        Args:
            image_path: Path of the array file to load.

        Returns:
            The preprocessed image as a ``float32`` array.
        """
        channels_first = np.load(image_path).squeeze(0)  # 3 x 1024 x 1024
        channels_last = np.transpose(channels_first, (1, 2, 0))
        # NOTE(review): the PIL round-trip is kept as in the original; it
        # presumably assumes the stored data is already uint8 RGB - confirm.
        pil_image = Image.fromarray(channels_last, mode="RGB")
        pixels = np.array(pil_image).astype(np.uint8)
        pixels = self.preprocessor(image=pixels)["image"]
        # Map uint8 values 0..255 onto -1..1.
        return (pixels / 127.5 - 1.0).astype(np.float32)