Spaces:
Build error
Build error
import bisect | |
import numpy as np | |
import albumentations | |
from PIL import Image | |
from torch.utils.data import Dataset, ConcatDataset | |
class ConcatDatasetWithIndex(ConcatDataset):
    """Concatenation of datasets that also reports where each sample came from.

    Adapted from PyTorch's ``ConcatDataset.__getitem__``. Instead of the bare
    sample, indexing returns ``(sample, dataset_idx)``, where ``dataset_idx`` is
    the position of the source dataset in ``self.datasets``.
    """

    def __getitem__(self, idx):
        """Return ``(sample, dataset_idx)`` for the global index ``idx``.

        Negative indices count from the end of the concatenation. A negative
        index whose magnitude exceeds ``len(self)`` raises ValueError. An index
        ``>= len(self)`` fails with IndexError when ``self.datasets`` is indexed.
        """
        if idx < 0:
            total = len(self)
            if total < -idx:
                raise ValueError("absolute value of index should not exceed dataset length")
            idx += total
        # cumulative_sizes holds running totals of the sub-dataset lengths.
        # The first total strictly greater than idx marks the owning dataset.
        which = bisect.bisect_right(self.cumulative_sizes, idx)
        # Shift idx back by the samples held in all preceding datasets.
        offset = self.cumulative_sizes[which - 1] if which else 0
        return self.datasets[which][idx - offset], which
class ImagePaths(Dataset):
    """Dataset of image files with optional rescale + square-crop preprocessing.

    Each item is a dict holding the preprocessed image under ``"image"``, plus
    the i-th entry of every label sequence. This includes the file path itself,
    stored under ``"file_path_"``.

    Args:
        paths: sequence of image file paths.
        size: if not None and > 0, images are rescaled with
            ``albumentations.SmallestMaxSize(max_size=size)`` and then cropped
            to ``size x size``. Otherwise they are left at their original size.
        random_crop: use ``albumentations.RandomCrop`` instead of
            ``CenterCrop`` when ``size`` is set.
        labels: optional mapping from label name to a per-image sequence,
            indexed in parallel with ``paths``. The mapping is copied, so the
            caller's object is not modified.
    """

    def __init__(self, paths, size=None, random_crop=False, labels=None):
        self.size = size
        self.random_crop = random_crop
        # Copy the mapping so that inserting "file_path_" below does not mutate
        # the dict the caller passed in.
        self.labels = {} if labels is None else dict(labels)
        self.labels["file_path_"] = paths
        self._length = len(paths)

        if self.size is not None and self.size > 0:
            self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            if not self.random_crop:
                self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
            else:
                self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
            self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
        else:
            # Identity preprocessor: returns its keyword arguments unchanged, so
            # preprocessor(image=x)["image"] is x. This mirrors the
            # albumentations call convention used above.
            self.preprocessor = lambda **kwargs: kwargs

    def __len__(self):
        return self._length

    def preprocess_image(self, image_path):
        """Load one image file and return it as a float32 array scaled to [-1, 1].

        The image is converted to RGB if needed, passed through
        ``self.preprocessor``, and mapped from uint8 [0, 255] to [-1, 1].
        """
        # Open the image as a context manager so the underlying file handle is
        # closed once the pixels have been copied into a numpy array.
        with Image.open(image_path) as image:
            if image.mode != "RGB":
                image = image.convert("RGB")
            pixels = np.array(image).astype(np.uint8)
        pixels = self.preprocessor(image=pixels)["image"]
        return (pixels / 127.5 - 1.0).astype(np.float32)

    def __getitem__(self, i):
        """Return ``{"image": <preprocessed image>, <label>: labels[label][i], ...}``."""
        example = {"image": self.preprocess_image(self.labels["file_path_"][i])}
        for key, values in self.labels.items():
            example[key] = values[i]
        return example
class NumpyPaths(ImagePaths):
    """ImagePaths variant for images stored as ``.npy`` arrays.

    Each file is expected to hold a channel-first uint8 array with a leading
    singleton axis, i.e. shape ``(1, 3, H, W)``. The original inline comment
    gave 3 x 1024 x 1024. NOTE(review): confirm the layout against the code
    that writes these files.
    """

    def preprocess_image(self, image_path):
        """Load an ``.npy`` image and return it as float32 in [-1, 1].

        The ``(1, 3, H, W)`` array is rearranged to channel-last ``(H, W, 3)``,
        passed through ``self.preprocessor``, and mapped from [0, 255] to
        [-1, 1], the same scaling ``ImagePaths.preprocess_image`` applies.

        Raises:
            ValueError: if axis 0 is not a singleton, or the remaining array is
                not 3-dimensional with 3 channels first.
            TypeError: if the pixel data is not uint8.
        """
        image = np.load(image_path).squeeze(0)  # (1, 3, H, W) -> (3, H, W)
        if image.ndim != 3 or image.shape[0] != 3:
            raise ValueError(
                f"expected a (1, 3, H, W) array in {image_path!r}, "
                f"got shape {image.shape} after dropping axis 0"
            )
        if image.dtype != np.uint8:
            raise TypeError(
                f"expected uint8 pixel data in {image_path!r}, got {image.dtype}"
            )
        # Convert to channel-last layout and make the array C-contiguous, as the
        # previous PIL round-trip did. That round-trip used
        # Image.fromarray(..., mode="RGB"), which reinterprets raw bytes. It
        # silently garbled non-uint8 or non-3-channel data and relied on a
        # Pillow argument that is now deprecated.
        image = np.ascontiguousarray(np.transpose(image, (1, 2, 0)))
        image = self.preprocessor(image=image)["image"]
        return (image / 127.5 - 1.0).astype(np.float32)