import torch import torchvision from torchvision import transforms import torch.nn.functional as F import torchvision.transforms.functional as TF import numpy as np from PIL import Image, ImageFilter import random class BaseTransform(object): """ Resize and center crop. """ def __init__(self, res): self.res = res def __call__(self, index, image): image = TF.resize(image, self.res, Image.BILINEAR) w, h = image.size left = int(round((w - self.res) / 2.)) top = int(round((h - self.res) / 2.)) return TF.crop(image, top, left, self.res, self.res) class ComposeTransform(object): def __init__(self, tlist): self.tlist = tlist def __call__(self, index, image): for trans in self.tlist: image = trans(index, image) return image class RandomResize(object): def __init__(self, rmin, rmax, N): self.reslist = [random.randint(rmin, rmax) for _ in range(N)] def __call__(self, index, image): return TF.resize(image, self.reslist[index], Image.BILINEAR) class RandomCrop(object): def __init__(self, res, N): self.res = res self.cons = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)] def __call__(self, index, image): ws, hs = self.cons[index] w, h = image.size left = int(round((w-self.res)*ws)) top = int(round((h-self.res)*hs)) return TF.crop(image, top, left, self.res, self.res) class RandomHorizontalFlip(object): def __init__(self, N, p=0.5): self.p_ref = p self.plist = np.random.random_sample(N) def __call__(self, index, image): if self.plist[index.cpu()] < self.p_ref: return TF.hflip(image) else: return image class TensorTransform(object): def __init__(self): self.to_tensor = transforms.ToTensor() #self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) def __call__(self, image): image = self.to_tensor(image) #image = self.normalize(image) return image class RandomGaussianBlur(object): def __init__(self, sigma, p, N): self.min_x = sigma[0] self.max_x = sigma[1] self.del_p = 1 - p self.p_ref = p self.plist = np.random.random_sample(N) def __call__(self, index, image): if self.plist[index] < self.p_ref: x = self.plist[index] - self.p_ref m = (self.max_x - self.min_x) / self.del_p b = self.min_x s = m * x + b return image.filter(ImageFilter.GaussianBlur(radius=s)) else: return image class RandomGrayScale(object): def __init__(self, p, N): self.grayscale = transforms.RandomGrayscale(p=1.) # Deterministic (We still want flexible out_dim). self.p_ref = p self.plist = np.random.random_sample(N) def __call__(self, index, image): if self.plist[index] < self.p_ref: return self.grayscale(image) else: return image class RandomColorBrightness(object): def __init__(self, x, p, N): self.min_x = max(0, 1 - x) self.max_x = 1 + x self.p_ref = p self.plist = np.random.random_sample(N) self.rlist = [random.uniform(self.min_x, self.max_x) for _ in range(N)] def __call__(self, index, image): if self.plist[index] < self.p_ref: return TF.adjust_brightness(image, self.rlist[index]) else: return image class RandomColorContrast(object): def __init__(self, x, p, N): self.min_x = max(0, 1 - x) self.max_x = 1 + x self.p_ref = p self.plist = np.random.random_sample(N) self.rlist = [random.uniform(self.min_x, self.max_x) for _ in range(N)] def __call__(self, index, image): if self.plist[index] < self.p_ref: return TF.adjust_contrast(image, self.rlist[index]) else: return image class RandomColorSaturation(object): def __init__(self, x, p, N): self.min_x = max(0, 1 - x) self.max_x = 1 + x self.p_ref = p self.plist = np.random.random_sample(N) self.rlist = [random.uniform(self.min_x, self.max_x) for _ in range(N)] def __call__(self, index, image): if self.plist[index] < self.p_ref: return TF.adjust_saturation(image, self.rlist[index]) else: return image class RandomColorHue(object): def __init__(self, x, p, N): self.min_x = -x self.max_x = x self.p_ref = p self.plist = np.random.random_sample(N) self.rlist = [random.uniform(self.min_x, self.max_x) for _ in range(N)] def __call__(self, index, image): if self.plist[index] < self.p_ref: return TF.adjust_hue(image, self.rlist[index]) else: return image class RandomVerticalFlip(object): def __init__(self, N, p=0.5): self.p_ref = p self.plist = np.random.random_sample(N) def __call__(self, indice, image): I = np.nonzero(self.plist[indice] < self.p_ref)[0] if len(image.size()) == 3: image_t = image[I].flip([1]) else: image_t = image[I].flip([2]) return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))]) class RandomHorizontalTensorFlip(object): def __init__(self, N, p=0.5): self.p_ref = p self.plist = np.random.random_sample(N) def __call__(self, indice, image, is_label=False): I = np.nonzero(self.plist[indice] < self.p_ref)[0] if len(image.size()) == 3: image_t = image[I].flip([2]) else: image_t = image[I].flip([3]) return torch.stack([image_t[np.where(I==i)[0][0]] if i in I else image[i] for i in range(image.size(0))]) class RandomResizedCrop(object): def __init__(self, N, res, scale=(0.5, 1.0)): self.res = res self.scale = scale self.rscale = [np.random.uniform(*scale) for _ in range(N)] self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)] def random_crop(self, idx, img): ws, hs = self.rcrop[idx] res1 = int(img.size(-1)) res2 = int(self.rscale[idx]*res1) i1 = int(round((res1-res2)*ws)) j1 = int(round((res1-res2)*hs)) return img[:, :, i1:i1+res2, j1:j1+res2] def __call__(self, indice, image): new_image = [] res_tar = self.res // 4 if image.size(1) > 5 else self.res # View 1 or View 2? for i, idx in enumerate(indice): img = image[[i]] img = self.random_crop(idx, img) img = F.interpolate(img, res_tar, mode='bilinear', align_corners=False) new_image.append(img) new_image = torch.cat(new_image) return new_image