|  | from functools import reduce | 
					
						
						|  | import math | 
					
						
						|  | import operator | 
					
						
						|  |  | 
					
						
						|  | import numpy as np | 
					
						
						|  | from skimage import transform | 
					
						
						|  | import torch | 
					
						
						|  | from torch import nn | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def translate2d(tx, ty): | 
					
						
						|  | mat = [[1, 0, tx], | 
					
						
						|  | [0, 1, ty], | 
					
						
						|  | [0, 0,  1]] | 
					
						
						|  | return torch.tensor(mat, dtype=torch.float32) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def scale2d(sx, sy): | 
					
						
						|  | mat = [[sx,  0, 0], | 
					
						
						|  | [ 0, sy, 0], | 
					
						
						|  | [ 0,  0, 1]] | 
					
						
						|  | return torch.tensor(mat, dtype=torch.float32) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def rotate2d(theta): | 
					
						
						|  | mat = [[torch.cos(theta), torch.sin(-theta), 0], | 
					
						
						|  | [torch.sin(theta),  torch.cos(theta), 0], | 
					
						
						|  | [               0,                 0, 1]] | 
					
						
						|  | return torch.tensor(mat, dtype=torch.float32) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class KarrasAugmentationPipeline: | 
					
						
						|  | def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8): | 
					
						
						|  | self.a_prob = a_prob | 
					
						
						|  | self.a_scale = a_scale | 
					
						
						|  | self.a_aniso = a_aniso | 
					
						
						|  | self.a_trans = a_trans | 
					
						
						|  |  | 
					
						
						|  | def __call__(self, image): | 
					
						
						|  | h, w = image.size | 
					
						
						|  | mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | a0 = torch.randint(2, []).float() | 
					
						
						|  | mats.append(scale2d(1 - 2 * a0, 1)) | 
					
						
						|  |  | 
					
						
						|  | do = (torch.rand([]) < self.a_prob).float() | 
					
						
						|  | a1 = torch.randint(2, []).float() * do | 
					
						
						|  | mats.append(scale2d(1, 1 - 2 * a1)) | 
					
						
						|  |  | 
					
						
						|  | do = (torch.rand([]) < self.a_prob).float() | 
					
						
						|  | a2 = torch.randn([]) * do | 
					
						
						|  | mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2)) | 
					
						
						|  |  | 
					
						
						|  | do = (torch.rand([]) < self.a_prob).float() | 
					
						
						|  | a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do | 
					
						
						|  | mats.append(rotate2d(-a3)) | 
					
						
						|  |  | 
					
						
						|  | do = (torch.rand([]) < self.a_prob).float() | 
					
						
						|  | a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do | 
					
						
						|  | a5 = torch.randn([]) * do | 
					
						
						|  | mats.append(rotate2d(a4)) | 
					
						
						|  | mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5)) | 
					
						
						|  | mats.append(rotate2d(-a4)) | 
					
						
						|  |  | 
					
						
						|  | do = (torch.rand([]) < self.a_prob).float() | 
					
						
						|  | a6 = torch.randn([]) * do | 
					
						
						|  | a7 = torch.randn([]) * do | 
					
						
						|  | mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5)) | 
					
						
						|  | mat = reduce(operator.matmul, mats) | 
					
						
						|  | cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | image_orig = np.array(image, dtype=np.float32) / 255 | 
					
						
						|  | if image_orig.ndim == 2: | 
					
						
						|  | image_orig = image_orig[..., None] | 
					
						
						|  | tf = transform.AffineTransform(mat.numpy()) | 
					
						
						|  | image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True) | 
					
						
						|  | image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1 | 
					
						
						|  | image = torch.as_tensor(image).movedim(2, 0) * 2 - 1 | 
					
						
						|  | return image, image_orig, cond | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class KarrasAugmentWrapper(nn.Module): | 
					
						
						|  | def __init__(self, model): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.inner_model = model | 
					
						
						|  |  | 
					
						
						|  | def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs): | 
					
						
						|  | if aug_cond is None: | 
					
						
						|  | aug_cond = input.new_zeros([input.shape[0], 9]) | 
					
						
						|  | if mapping_cond is None: | 
					
						
						|  | mapping_cond = aug_cond | 
					
						
						|  | else: | 
					
						
						|  | mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1) | 
					
						
						|  | return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs) | 
					
						
						|  |  | 
					
						
						|  | def set_skip_stages(self, skip_stages): | 
					
						
						|  | return self.inner_model.set_skip_stages(skip_stages) | 
					
						
						|  |  | 
					
						
						|  | def set_patch_size(self, patch_size): | 
					
						
						|  | return self.inner_model.set_patch_size(patch_size) | 
					
						
						|  |  |