"""Data transforms for the loaders """ import random import traceback from pathlib import Path import numpy as np import torch import torch.nn.functional as F from skimage.color import rgba2rgb from skimage.io import imread from torchvision import transforms as trsfs from torchvision.transforms.functional import ( adjust_brightness, adjust_contrast, adjust_saturation, ) from climategan.tutils import normalize def interpolation(task): if task in ["d", "m", "s"]: return {"mode": "nearest"} else: return {"mode": "bilinear", "align_corners": True} class Resize: def __init__(self, target_size, keep_aspect_ratio=False): """ Resize transform. Target_size can be an int or a tuple of ints, depending on whether both height and width should have the same final size or not. If keep_aspect_ratio is specified then target_size must be an int: the smallest dimension of x will be set to target_size and the largest dimension will be computed to the closest int keeping the original aspect ratio. e.g. >>> x = torch.rand(1, 3, 1200, 1800) >>> m = torch.rand(1, 1, 600, 600) >>> d = {"x": x, "m": m} >>> {k: v.shape for k, v in Resize(640, True)(d).items()} {"x": (1, 3, 640, 960), "m": (1, 1, 640, 960)} Args: target_size (int | tuple(int)): New size for the tensor keep_aspect_ratio (bool, optional): Whether or not to keep aspect ratio when resizing. Requires target_size to be an int. If keeping aspect ratio, smallest dim will be set to target_size. Defaults to False. """ if isinstance(target_size, (int, tuple, list)): if not isinstance(target_size, int) and not keep_aspect_ratio: assert len(target_size) == 2 self.h, self.w = target_size else: if keep_aspect_ratio: assert isinstance(target_size, int) self.h = self.w = target_size self.default_h = int(self.h) self.default_w = int(self.w) self.sizes = {} elif isinstance(target_size, dict): assert ( not keep_aspect_ratio ), "dict target_size not compatible with keep_aspect_ratio" self.sizes = { k: {"h": v, "w": v} for k, v in target_size.items() if k != "default" } self.default_h = int(target_size["default"]) self.default_w = int(target_size["default"]) self.keep_aspect_ratio = keep_aspect_ratio def compute_new_default_size(self, tensor): """ compute the new size for a tensor depending on target size and keep_aspect_rato Args: tensor (torch.Tensor): 4D tensor N x C x H x W. Returns: tuple(int): (new_height, new_width) """ if self.keep_aspect_ratio: h, w = tensor.shape[-2:] if h < w: return (self.h, int(self.default_h * w / h)) else: return (int(self.default_h * h / w), self.default_w) return (self.default_h, self.default_w) def compute_new_size_for_task(self, task): assert ( not self.keep_aspect_ratio ), "compute_new_size_for_task is not compatible with keep aspect ratio" if task not in self.sizes: return (self.default_h, self.default_w) return (self.sizes[task]["h"], self.sizes[task]["w"]) def __call__(self, data): """ Resize a dict of tensors to the "x" key's new_size Args: data (dict[str:torch.Tensor]): The data dict to transform Returns: dict[str: torch.Tensor]: dict with all tensors resized to the new size of the data["x"] tensor """ task = tensor = new_size = None try: if not self.sizes: d = {} new_size = self.compute_new_default_size( data["x"] if "x" in data else list(data.values())[0] ) for task, tensor in data.items(): d[task] = F.interpolate( tensor, size=new_size, **interpolation(task) ) return d d = {} for task, tensor in data.items(): new_size = self.compute_new_size_for_task(task) d[task] = F.interpolate(tensor, size=new_size, **interpolation(task)) return d except Exception as e: tb = traceback.format_exc() print("Debug: task, shape, interpolation, h, w, new_size") print(task) print(tensor.shape) print(interpolation(task)) print(self.h, self.w) print(new_size) print(tb) raise Exception(e) class RandomCrop: def __init__(self, size, center=False): assert isinstance(size, (int, tuple, list)) if not isinstance(size, int): assert len(size) == 2 self.h, self.w = size else: self.h = self.w = size self.h = int(self.h) self.w = int(self.w) self.center = center def __call__(self, data): H, W = ( data["x"].size()[-2:] if "x" in data else list(data.values())[0].size()[-2:] ) if not self.center: top = np.random.randint(0, H - self.h) left = np.random.randint(0, W - self.w) else: top = (H - self.h) // 2 left = (W - self.w) // 2 return { task: tensor[:, :, top : top + self.h, left : left + self.w] for task, tensor in data.items() } class RandomHorizontalFlip: def __init__(self, p=0.5): # self.flip = TF.hflip self.p = p def __call__(self, data): if np.random.rand() > self.p: return data return {task: torch.flip(tensor, [3]) for task, tensor in data.items()} class ToTensor: def __init__(self): self.ImagetoTensor = trsfs.ToTensor() self.MaptoTensor = self.ImagetoTensor def __call__(self, data): new_data = {} for task, im in data.items(): if task in {"x", "a"}: new_data[task] = self.ImagetoTensor(im) elif task in {"m"}: new_data[task] = self.MaptoTensor(im) elif task == "s": new_data[task] = torch.squeeze(torch.from_numpy(np.array(im))).to( torch.int64 ) elif task == "d": new_data = im return new_data class Normalize: def __init__(self, opts): if opts.data.normalization == "HRNet": self.normImage = trsfs.Normalize( ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ) else: self.normImage = trsfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) self.normDepth = lambda x: x self.normMask = lambda x: x self.normSeg = lambda x: x self.normalize = { "x": self.normImage, "s": self.normSeg, "d": self.normDepth, "m": self.normMask, } def __call__(self, data): return { task: self.normalize.get(task, lambda x: x)(tensor.squeeze(0)) for task, tensor in data.items() } class RandBrightness: # Input need to be between -1 and 1 def __call__(self, data): return { task: rand_brightness(tensor) if task == "x" else tensor for task, tensor in data.items() } class RandSaturation: def __call__(self, data): return { task: rand_saturation(tensor) if task == "x" else tensor for task, tensor in data.items() } class RandContrast: def __call__(self, data): return { task: rand_contrast(tensor) if task == "x" else tensor for task, tensor in data.items() } class BucketizeDepth: def __init__(self, opts, domain): self.domain = domain if opts.gen.d.classify.enable and domain in {"s", "kitti"}: self.buckets = torch.linspace( *[ opts.gen.d.classify.linspace.min, opts.gen.d.classify.linspace.max, opts.gen.d.classify.linspace.buckets - 1, ] ) self.transforms = { "d": lambda tensor: torch.bucketize( tensor, self.buckets, out_int32=True, right=True ) } else: self.transforms = {} def __call__(self, data): return { task: self.transforms.get(task, lambda x: x)(tensor) for task, tensor in data.items() } class PrepareInference: """ Transform which: - transforms a str or an array into a tensor - resizes the image to keep the aspect ratio - crops in the center of the resized image - normalize to 0:1 - rescale to -1:1 """ def __init__(self, target_size=640, half=False, is_label=False, enforce_128=True): if enforce_128: if target_size % 2 ** 7 != 0: raise ValueError( f"Received a target_size of {target_size}, which is not a " + "multiple of 2^7 = 128. Set enforce_128 to False to disable " + "this error." ) self.resize = Resize(target_size, keep_aspect_ratio=True) self.crop = RandomCrop((target_size, target_size), center=True) self.half = half self.is_label = is_label def process(self, t): if isinstance(t, (str, Path)): t = imread(str(t)) if isinstance(t, np.ndarray): if t.shape[-1] == 4: t = rgba2rgb(t) t = torch.from_numpy(t) if t.ndim == 3: t = t.permute(2, 0, 1) if t.ndim == 3: t = t.unsqueeze(0) elif t.ndim == 2: t = t.unsqueeze(0).unsqueeze(0) if not self.is_label: t = t.to(torch.float32) t = normalize(t) t = (t - 0.5) * 2 t = {"m": t} if self.is_label else {"x": t} t = self.resize(t) t = self.crop(t) t = t["m"] if self.is_label else t["x"] if self.half and not self.is_label: t = t.half() return t def __call__(self, x): """ normalize, rescale, resize, crop in the center x can be: dict {"task": data} list [data, ..] or data data ^ can be a str, a Path, a numpy arrray or a Tensor """ if isinstance(x, dict): return {k: self.process(v) for k, v in x.items()} if isinstance(x, list): return [self.process(t) for t in x] return self.process(x) class PrepareTest: """ Transform which: - transforms a str or an array into a tensor - resizes the image to keep the aspect ratio - crops in the center of the resized image - normalize to 0:1 (optional) - rescale to -1:1 (optional) """ def __init__(self, target_size=640, half=False): self.resize = Resize(target_size, keep_aspect_ratio=True) self.crop = RandomCrop((target_size, target_size), center=True) self.half = half def process(self, t, normalize=False, rescale=False): if isinstance(t, (str, Path)): # t = img_as_float(imread(str(t))) t = imread(str(t)) if t.shape[-1] == 4: # t = rgba2rgb(t) t = t[:, :, :3] if np.ndim(t) == 2: t = np.repeat(t[:, :, np.newaxis], 3, axis=2) if isinstance(t, np.ndarray): t = torch.from_numpy(t) t = t.permute(2, 0, 1) if len(t.shape) == 3: t = t.unsqueeze(0) t = t.to(torch.float32) normalize(t) if normalize else t (t - 0.5) * 2 if rescale else t t = {"x": t} t = self.resize(t) t = self.crop(t) t = t["x"] if self.half: return t.to(torch.float16) return t def __call__(self, x, normalize=False, rescale=False): """ Call process() x can be: dict {"task": data} list [data, ..] or data data ^ can be a str, a Path, a numpy arrray or a Tensor """ if isinstance(x, dict): return {k: self.process(v, normalize, rescale) for k, v in x.items()} if isinstance(x, list): return [self.process(t, normalize, rescale) for t in x] return self.process(x, normalize, rescale) def get_transform(transform_item, mode): """Returns the torchivion transform function associated to a transform_item listed in opts.data.transforms ; transform_item is an addict.Dict """ if transform_item.name == "crop" and not ( transform_item.ignore is True or transform_item.ignore == mode ): return RandomCrop( (transform_item.height, transform_item.width), center=transform_item.center == mode, ) elif transform_item.name == "resize" and not ( transform_item.ignore is True or transform_item.ignore == mode ): return Resize( transform_item.new_size, transform_item.get("keep_aspect_ratio", False) ) elif transform_item.name == "hflip" and not ( transform_item.ignore is True or transform_item.ignore == mode ): return RandomHorizontalFlip(p=transform_item.p or 0.5) elif transform_item.name == "brightness" and not ( transform_item.ignore is True or transform_item.ignore == mode ): return RandBrightness() elif transform_item.name == "saturation" and not ( transform_item.ignore is True or transform_item.ignore == mode ): return RandSaturation() elif transform_item.name == "contrast" and not ( transform_item.ignore is True or transform_item.ignore == mode ): return RandContrast() elif transform_item.ignore is True or transform_item.ignore == mode: return None raise ValueError("Unknown transform_item {}".format(transform_item)) def get_transforms(opts, mode, domain): """Get all the transform functions listed in opts.data.transforms using get_transform(transform_item, mode) """ transforms = [] color_jittering_transforms = ["brightness", "saturation", "contrast"] for t in opts.data.transforms: if t.name not in color_jittering_transforms: transforms.append(get_transform(t, mode)) if "p" not in opts.tasks and mode == "train": for t in opts.data.transforms: if t.name in color_jittering_transforms: transforms.append(get_transform(t, mode)) transforms += [Normalize(opts), BucketizeDepth(opts, domain)] transforms = [t for t in transforms if t is not None] return transforms # ----- Adapted functions from https://github.com/mit-han-lab/data-efficient-gans -----# def rand_brightness(tensor, is_diff_augment=False): if is_diff_augment: assert len(tensor.shape) == 4 type_ = tensor.dtype device_ = tensor.device rand_tens = torch.rand(tensor.size(0), 1, 1, 1, dtype=type_, device=device_) return tensor + (rand_tens - 0.5) else: factor = random.uniform(0.5, 1.5) tensor = adjust_brightness(tensor, brightness_factor=factor) # dummy pixels to fool scaling and preserve range tensor[:, :, 0, 0] = 1.0 tensor[:, :, -1, -1] = 0.0 return tensor def rand_saturation(tensor, is_diff_augment=False): if is_diff_augment: assert len(tensor.shape) == 4 type_ = tensor.dtype device_ = tensor.device rand_tens = torch.rand(tensor.size(0), 1, 1, 1, dtype=type_, device=device_) x_mean = tensor.mean(dim=1, keepdim=True) return (tensor - x_mean) * (rand_tens * 2) + x_mean else: factor = random.uniform(0.5, 1.5) tensor = adjust_saturation(tensor, saturation_factor=factor) # dummy pixels to fool scaling and preserve range tensor[:, :, 0, 0] = 1.0 tensor[:, :, -1, -1] = 0.0 return tensor def rand_contrast(tensor, is_diff_augment=False): if is_diff_augment: assert len(tensor.shape) == 4 type_ = tensor.dtype device_ = tensor.device rand_tens = torch.rand(tensor.size(0), 1, 1, 1, dtype=type_, device=device_) x_mean = tensor.mean(dim=[1, 2, 3], keepdim=True) return (tensor - x_mean) * (rand_tens + 0.5) + x_mean else: factor = random.uniform(0.5, 1.5) tensor = adjust_contrast(tensor, contrast_factor=factor) # dummy pixels to fool scaling and preserve range tensor[:, :, 0, 0] = 1.0 tensor[:, :, -1, -1] = 0.0 return tensor def rand_cutout(tensor, ratio=0.5): assert len(tensor.shape) == 4, "For rand cutout, tensor must be 4D." type_ = tensor.dtype device_ = tensor.device cutout_size = int(tensor.size(-2) * ratio + 0.5), int(tensor.size(-1) * ratio + 0.5) grid_batch, grid_x, grid_y = torch.meshgrid( torch.arange(tensor.size(0), dtype=torch.long, device=device_), torch.arange(cutout_size[0], dtype=torch.long, device=device_), torch.arange(cutout_size[1], dtype=torch.long, device=device_), ) size_ = [tensor.size(0), 1, 1] offset_x = torch.randint( 0, tensor.size(-2) + (1 - cutout_size[0] % 2), size=size_, device=device_, ) offset_y = torch.randint( 0, tensor.size(-1) + (1 - cutout_size[1] % 2), size=size_, device=device_, ) grid_x = torch.clamp( grid_x + offset_x - cutout_size[0] // 2, min=0, max=tensor.size(-2) - 1 ) grid_y = torch.clamp( grid_y + offset_y - cutout_size[1] // 2, min=0, max=tensor.size(-1) - 1 ) mask = torch.ones( tensor.size(0), tensor.size(2), tensor.size(3), dtype=type_, device=device_ ) mask[grid_batch, grid_x, grid_y] = 0 return tensor * mask.unsqueeze(1) def rand_translation(tensor, ratio=0.125): assert len(tensor.shape) == 4, "For rand translation, tensor must be 4D." device_ = tensor.device shift_x, shift_y = ( int(tensor.size(2) * ratio + 0.5), int(tensor.size(3) * ratio + 0.5), ) translation_x = torch.randint( -shift_x, shift_x + 1, size=[tensor.size(0), 1, 1], device=device_ ) translation_y = torch.randint( -shift_y, shift_y + 1, size=[tensor.size(0), 1, 1], device=device_ ) grid_batch, grid_x, grid_y = torch.meshgrid( torch.arange(tensor.size(0), dtype=torch.long, device=device_), torch.arange(tensor.size(2), dtype=torch.long, device=device_), torch.arange(tensor.size(3), dtype=torch.long, device=device_), ) grid_x = torch.clamp(grid_x + translation_x + 1, 0, tensor.size(2) + 1) grid_y = torch.clamp(grid_y + translation_y + 1, 0, tensor.size(3) + 1) x_pad = F.pad(tensor, [1, 1, 1, 1, 0, 0, 0, 0]) tensor = ( x_pad.permute(0, 2, 3, 1) .contiguous()[grid_batch, grid_x, grid_y] .permute(0, 3, 1, 2) ) return tensor class DiffTransforms: def __init__(self, diff_aug_opts): self.do_color_jittering = diff_aug_opts.do_color_jittering self.do_cutout = diff_aug_opts.do_cutout self.do_translation = diff_aug_opts.do_translation self.cutout_ratio = diff_aug_opts.cutout_ratio self.translation_ratio = diff_aug_opts.translation_ratio def __call__(self, tensor): if self.do_color_jittering: tensor = rand_brightness(tensor, is_diff_augment=True) tensor = rand_contrast(tensor, is_diff_augment=True) tensor = rand_saturation(tensor, is_diff_augment=True) if self.do_translation: tensor = rand_translation(tensor, ratio=self.translation_ratio) if self.do_cutout: tensor = rand_cutout(tensor, ratio=self.cutout_ratio) return tensor