Spaces:
Runtime error
Runtime error
import os | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
import random | |
import torch | |
import torch.utils.data as data | |
import torchvision.transforms as transforms | |
class BaseDataset(data.Dataset):
    """Minimal common base for datasets built on ``torch.utils.data.Dataset``.

    Subclasses are expected to override :meth:`name` and :meth:`initialize`.
    """

    def __init__(self):
        super().__init__()

    def name(self):
        """Return the identifier string of this dataset class."""
        return "BaseDataset"

    def initialize(self, opt):
        """Hook for option-driven setup; the base implementation does nothing."""
        return None
class Rescale_fixed(object):
    """Resize an image to a fixed output size using bicubic interpolation.

    Args:
        output_size (tuple or int): Target size as ``(w, h)``. A single int
            ``x`` is expanded to ``(x, x)``.
    """

    def __init__(self, output_size):
        # The stored value is handed directly to ``image.resize``, which needs
        # a (width, height) pair, so expand the documented int shorthand here.
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        self.output_size = output_size

    def __call__(self, image):
        """Return ``image`` resized to ``self.output_size``.

        ``image`` must expose a PIL-style ``resize(size, resample)`` method
        (presumably a ``PIL.Image.Image`` - confirm against callers).
        """
        return image.resize(self.output_size, Image.BICUBIC)
class Rescale_custom(object):
    """Jointly rescale an input/target image pair to a randomly drawn size.

    The shorter side of both images is set to the drawn size and the longer
    side is scaled to keep the aspect ratio.

    Args:
        min_size (int): Lower bound for the drawn size; also used as the size
            when no larger value is available.
        max_size (int): Cap applied (together with the image height and
            width) to the upper end of the random draw.
    """

    def __init__(self, min_size, max_size):
        assert isinstance(min_size, (int, float))
        self.max_size = max_size
        self.min_size = min_size

    def __call__(self, sample):
        """Resize ``sample["input_image"]`` and ``sample["target_image"]``.

        Both images must report the same ``size``. Returns a new dict holding
        the two resized images under the same keys.
        """
        input_image = sample["input_image"]
        target_image = sample["target_image"]
        assert input_image.size == target_image.size
        w, h = input_image.size

        # The draw can go no higher than max_size or either image dimension.
        upper_bound = min(self.max_size, h, w)
        if upper_bound > self.min_size:
            # np.random.randint excludes the upper bound itself.
            self.output_size = np.random.randint(self.min_size, upper_bound)
        else:
            self.output_size = self.min_size

        # Shorter side takes the drawn size; the other side keeps the ratio.
        if h > w:
            scaled_w = self.output_size
            scaled_h = self.output_size * h / w
        else:
            scaled_h = self.output_size
            scaled_w = self.output_size * w / h
        target_dims = (int(scaled_w), int(scaled_h))

        return {
            "input_image": input_image.resize(target_dims, Image.BICUBIC),
            "target_image": target_image.resize(target_dims, Image.BICUBIC),
        }
class ToTensor(object):
    """Convert both images of a sample dict to tensors.

    Conversion is delegated to ``torchvision.transforms.ToTensor``.
    """

    def __init__(self):
        self.totensor = transforms.ToTensor()

    def __call__(self, sample):
        """Return a new dict with the input and target images converted."""
        source = sample["input_image"]
        target = sample["target_image"]
        converted_source = self.totensor(source)
        converted_target = self.totensor(target)
        return dict(input_image=converted_source, target_image=converted_target)
class RandomCrop_custom(object):
    """Apply one random crop to the input and target tensors of a sample.

    The two tensors are concatenated along dimension 0, cropped with a single
    ``transforms.RandomCrop`` call so both share the same window, and then
    split back apart.

    Args:
        output_size (tuple or int): Desired output size. If int, a square
            crop is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size
        self.randomcrop = transforms.RandomCrop(self.output_size)

    def __call__(self, sample):
        """Return a dict with the identically cropped input and target.

        Assumes both tensors are laid out channels-first, i.e. ``(C, H, W)``,
        with matching spatial dimensions - confirm against the pipeline.
        """
        input_image, target_image = sample["input_image"], sample["target_image"]
        # Split at the input's own channel count rather than a fixed 3, so
        # inputs that are not 3-channel are separated from the target correctly.
        split_at = input_image.shape[0]
        cropped_imgs = self.randomcrop(torch.cat((input_image, target_image)))
        return {
            "input_image": cropped_imgs[:split_at],
            "target_image": cropped_imgs[split_at:],
        }
class Normalize_custom(object):
    """Normalize both tensors of a sample dict with a shared mean and std.

    Args:
        mean (float or tuple): Mean to subtract. A float is repeated for three
            channels; otherwise a length-3 sequence is required.
        std (float or tuple): Std to divide by, expanded the same way.
    """

    @staticmethod
    def _per_channel(value):
        # A bare float is repeated three times; anything else must already
        # contain exactly three entries and is kept as given.
        if isinstance(value, float):
            return (value,) * 3
        assert len(value) == 3
        return value

    def __init__(self, mean, std):
        assert isinstance(mean, (float, tuple))
        self.mean = self._per_channel(mean)
        self.std = self._per_channel(std)
        self.normalize = transforms.Normalize(self.mean, self.std)

    def __call__(self, sample):
        """Return a new dict holding the normalized input and target."""
        source = sample["input_image"]
        target = sample["target_image"]
        result = {}
        result["input_image"] = self.normalize(source)
        result["target_image"] = self.normalize(target)
        return result
class Normalize_image(object):
    """Normalize a tensor with a scalar mean and std applied to every channel.

    Normalizers are pre-built for 1, 3 and 18 channels; the one matching the
    size of the tensor's first dimension is used.

    Args:
        mean (float): Mean to subtract from the tensor.
        std (float): Std to divide the tensor by.
    """

    def __init__(self, mean, std):
        assert isinstance(mean, (float))
        # Validate std like mean; without this a non-float std left
        # ``self.std`` unset and failed below with an AttributeError.
        assert isinstance(std, (float))
        self.mean = mean
        self.std = std

        self.normalize_1 = transforms.Normalize(self.mean, self.std)
        self.normalize_3 = transforms.Normalize([self.mean] * 3, [self.std] * 3)
        self.normalize_18 = transforms.Normalize([self.mean] * 18, [self.std] * 18)

    def __call__(self, image_tensor):
        """Return the normalized tensor.

        Raises:
            ValueError: If the first dimension of ``image_tensor`` is not
                1, 3 or 18.
        """
        channels = image_tensor.shape[0]
        if channels == 1:
            return self.normalize_1(image_tensor)
        if channels == 3:
            return self.normalize_3(image_tensor)
        if channels == 18:
            return self.normalize_18(image_tensor)
        # The original ``assert "<message>"`` was always true (a non-empty
        # string is truthy), so unsupported inputs silently returned None.
        raise ValueError(
            "Please set proper channels! Normlization implemented only for 1, 3 and 18"
        )