from functools import reduce
import math
import operator

import numpy as np
from skimage import transform
import torch
from torch import nn


def translate2d(tx, ty):
    """Return the 3x3 homogeneous matrix of a 2D translation by (tx, ty)."""
    mat = [[1, 0, tx],
           [0, 1, ty],
           [0, 0, 1]]
    return torch.tensor(mat, dtype=torch.float32)


def scale2d(sx, sy):
    """Return the 3x3 homogeneous matrix scaling x by sx and y by sy."""
    mat = [[sx, 0, 0],
           [0, sy, 0],
           [0, 0, 1]]
    return torch.tensor(mat, dtype=torch.float32)


def rotate2d(theta):
    """Return the 3x3 homogeneous matrix of a 2D rotation by theta radians.

    theta may be a 0-dim tensor or a plain Python number.
    """
    # torch.cos / torch.sin only accept tensors, so wrap plain numbers first.
    theta = torch.as_tensor(theta)
    mat = [[torch.cos(theta), torch.sin(-theta), 0],
           [torch.sin(theta), torch.cos(theta), 0],
           [0, 0, 1]]
    return torch.tensor(mat, dtype=torch.float32)


class KarrasAugmentationPipeline:
    """Random affine augmentation that also reports the parameters it drew.

    Calling an instance on an image returns the warped image, the unwarped
    image and a 9-element conditioning vector describing the sampled
    transform.

    Args:
        a_prob: probability with which each of y-flip, scaling, rotation,
            anisotropy and translation is enabled (sampled independently).
        a_scale: base of the isotropic scale factor ``a_scale ** N(0, 1)``.
        a_aniso: base of the anisotropic scale factor ``a_aniso ** N(0, 1)``.
        a_trans: translation standard deviation as a fraction of the image
            extent along the respective axis.
        disable_all: if true, return the unwarped image as the "augmented"
            image and an all-zero conditioning vector.
    """

    def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8, disable_all=False):
        self.a_prob = a_prob
        self.a_scale = a_scale
        self.a_aniso = a_aniso
        self.a_trans = a_trans
        self.disable_all = disable_all

    def __call__(self, image):
        """Augment one image.

        Args:
            image: object exposing a two-element ``size`` and convertible by
                ``np.array`` to an (H, W) or (H, W, C) array of 0..255 values.
                ``size`` is interpreted as ``(width, height)``, which is the
                PIL convention -- assumes a PIL-style image; confirm against
                the caller.

        Returns:
            Tuple ``(image, image_orig, cond)``: the augmented image and the
            original image, both channel-first tensors rescaled with
            ``v / 255 * 2 - 1``, and the conditioning vector (9 elements).
        """
        # The matrices below act on (x, y) = (column, row) coordinates, so
        # width pairs with x and height with y.
        w, h = image.size

        # Move the origin to the image centre so flips, scaling and rotation
        # act about the centre; undone by the last matrix in the chain.
        mats = [translate2d(w / 2 - 0.5, h / 2 - 0.5)]

        # x-flip: applied with probability 1/2, not gated by a_prob
        a0 = torch.randint(2, []).float()
        mats.append(scale2d(1 - 2 * a0, 1))

        # y-flip
        do = (torch.rand([]) < self.a_prob).float()
        a1 = torch.randint(2, []).float() * do
        mats.append(scale2d(1, 1 - 2 * a1))

        # scaling
        do = (torch.rand([]) < self.a_prob).float()
        a2 = torch.randn([]) * do
        mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2))

        # rotation, angle uniform in [-pi, pi)
        do = (torch.rand([]) < self.a_prob).float()
        a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
        mats.append(rotate2d(-a3))

        # anisotropy: stretch along a random direction a4 by a_aniso ** a5
        do = (torch.rand([]) < self.a_prob).float()
        a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
        a5 = torch.randn([]) * do
        mats.append(rotate2d(a4))
        mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5))
        mats.append(rotate2d(-a4))

        # translation: x shift is a fraction of the width, y shift a fraction
        # of the height
        do = (torch.rand([]) < self.a_prob).float()
        a6 = torch.randn([]) * do
        a7 = torch.randn([]) * do
        mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))

        # form the transformation matrix and conditioning vector
        mats.append(translate2d(-w / 2 + 0.5, -h / 2 + 0.5))
        mat = reduce(operator.matmul, mats)
        cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(),
                            a5 * a4.cos(), a5 * a4.sin(), a6, a7])

        # apply the transformation
        image_orig = np.array(image, dtype=np.float32) / 255
        if image_orig.ndim == 2:
            # single-channel input: add a trailing channel axis
            image_orig = image_orig[..., None]
        tf = transform.AffineTransform(mat.numpy())
        if not self.disable_all:
            # warp() takes the output -> input mapping, hence tf.inverse
            image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect',
                                   cval=0.5, clip=False, preserve_range=True)
        else:
            image = image_orig
            cond = torch.zeros_like(cond)

        # channel-last -> channel-first, [0, 1] -> [-1, 1]
        image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1
        image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
        return image, image_orig, cond


class KarrasAugmentWrapper(nn.Module):
    """Wrap a model so augmentation conditioning is passed as ``mapping_cond``."""

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
        """Call the inner model with ``aug_cond`` merged into ``mapping_cond``.

        When ``aug_cond`` is None it defaults to zeros of shape
        ``[input.shape[0], 9]``. When ``mapping_cond`` is given, ``aug_cond``
        is concatenated in front of it along dim 1; otherwise ``aug_cond``
        alone is used.
        """
        if aug_cond is None:
            aug_cond = input.new_zeros([input.shape[0], 9])
        if mapping_cond is None:
            mapping_cond = aug_cond
        else:
            mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1)
        return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs)

    def param_groups(self, *args, **kwargs):
        """Delegate to the inner model's ``param_groups``."""
        return self.inner_model.param_groups(*args, **kwargs)

    def set_skip_stages(self, skip_stages):
        """Delegate to the inner model's ``set_skip_stages``."""
        return self.inner_model.set_skip_stages(skip_stages)

    def set_patch_size(self, patch_size):
        """Delegate to the inner model's ``set_patch_size``."""
        return self.inner_model.set_patch_size(patch_size)