"""Normalization layers used in blocks """ import torch import torch.nn as nn import torch.nn.functional as F class AdaptiveInstanceNorm2d(nn.Module): def __init__(self, num_features, eps=1e-5, momentum=0.1): super(AdaptiveInstanceNorm2d, self).__init__() self.num_features = num_features self.eps = eps self.momentum = momentum # weight and bias are dynamically assigned self.weight = None self.bias = None # just dummy buffers, not used self.register_buffer("running_mean", torch.zeros(num_features)) self.register_buffer("running_var", torch.ones(num_features)) def forward(self, x): assert ( self.weight is not None and self.bias is not None ), "Please assign weight and bias before calling AdaIN!" b, c = x.size(0), x.size(1) running_mean = self.running_mean.repeat(b) running_var = self.running_var.repeat(b) # Apply instance norm x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:]) out = F.batch_norm( x_reshaped, running_mean, running_var, self.weight, self.bias, True, self.momentum, self.eps, ) return out.view(b, c, *x.size()[2:]) def __repr__(self): return self.__class__.__name__ + "(" + str(self.num_features) + ")" class LayerNorm(nn.Module): def __init__(self, num_features, eps=1e-5, affine=True): super(LayerNorm, self).__init__() self.num_features = num_features self.affine = affine self.eps = eps if self.affine: self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_()) self.beta = nn.Parameter(torch.zeros(num_features)) def forward(self, x): shape = [-1] + [1] * (x.dim() - 1) # print(x.size()) if x.size(0) == 1: # These two lines run much faster in pytorch 0.4 # than the two lines listed below. mean = x.view(-1).mean().view(*shape) std = x.view(-1).std().view(*shape) else: mean = x.view(x.size(0), -1).mean(1).view(*shape) std = x.view(x.size(0), -1).std(1).view(*shape) x = (x - mean) / (std + self.eps) if self.affine: shape = [1, -1] + [1] * (x.dim() - 2) x = x * self.gamma.view(*shape) + self.beta.view(*shape) return x def l2normalize(v, eps=1e-12): return v / (v.norm() + eps) class SpectralNorm(nn.Module): """ Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida and the Pytorch implementation: https://github.com/christiancosgrove/pytorch-spectral-normalization-gan """ def __init__(self, module, name="weight", power_iterations=1): super().__init__() self.module = module self.name = name self.power_iterations = power_iterations if not self._made_params(): self._make_params() def _update_u_v(self): u = getattr(self.module, self.name + "_u") v = getattr(self.module, self.name + "_v") w = getattr(self.module, self.name + "_bar") height = w.data.shape[0] for _ in range(self.power_iterations): v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data)) u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data)) # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) sigma = u.dot(w.view(height, -1).mv(v)) setattr(self.module, self.name, w / sigma.expand_as(w)) def _made_params(self): try: u = getattr(self.module, self.name + "_u") # noqa: F841 v = getattr(self.module, self.name + "_v") # noqa: F841 w = getattr(self.module, self.name + "_bar") # noqa: F841 return True except AttributeError: return False def _make_params(self): w = getattr(self.module, self.name) height = w.data.shape[0] width = w.view(height, -1).data.shape[1] u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) v = nn.Parameter(w.data.new(width).normal_(0, 1), 
requires_grad=False) u.data = l2normalize(u.data) v.data = l2normalize(v.data) w_bar = nn.Parameter(w.data) del self.module._parameters[self.name] self.module.register_parameter(self.name + "_u", u) self.module.register_parameter(self.name + "_v", v) self.module.register_parameter(self.name + "_bar", w_bar) def forward(self, *args): self._update_u_v() return self.module.forward(*args) class SPADE(nn.Module): def __init__(self, param_free_norm_type, kernel_size, norm_nc, cond_nc): super().__init__() if param_free_norm_type == "instance": self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) # elif param_free_norm_type == "syncbatch": # self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False) elif param_free_norm_type == "batch": self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False) else: raise ValueError( "%s is not a recognized param-free norm type in SPADE" % param_free_norm_type ) # The dimension of the intermediate embedding space. Yes, hardcoded. nhidden = 128 pw = kernel_size // 2 self.mlp_shared = nn.Sequential( nn.Conv2d(cond_nc, nhidden, kernel_size=kernel_size, padding=pw), nn.ReLU() ) self.mlp_gamma = nn.Conv2d( nhidden, norm_nc, kernel_size=kernel_size, padding=pw ) self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=kernel_size, padding=pw) def forward(self, x, segmap): # Part 1. generate parameter-free normalized activations normalized = self.param_free_norm(x) # Part 2. produce scaling and bias conditioned on semantic map segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest") actv = self.mlp_shared(segmap) gamma = self.mlp_gamma(actv) beta = self.mlp_beta(actv) # apply scale and bias out = normalized * (1 + gamma) + beta return out
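

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the shapes and the __main__ guard
# below are assumptions, not part of the original module). It shows how each
# layer is typically called; in practice these are wired into larger
# generator blocks.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    x = torch.randn(2, 16, 8, 8)

    # AdaptiveInstanceNorm2d expects `weight` and `bias` to be assigned
    # externally (e.g. from a style MLP) before the forward pass, with one
    # value per (batch, channel) pair.
    adain = AdaptiveInstanceNorm2d(16)
    adain.weight = torch.ones(2 * 16)
    adain.bias = torch.zeros(2 * 16)
    print(adain(x).shape)  # torch.Size([2, 16, 8, 8])

    # LayerNorm normalizes each sample over all of its features.
    print(LayerNorm(16)(x).shape)  # torch.Size([2, 16, 8, 8])

    # SpectralNorm wraps a module and divides its weight by an estimate of
    # the largest singular value (power iteration) on every forward call.
    conv = SpectralNorm(nn.Conv2d(16, 32, 3, padding=1))
    print(conv(x).shape)  # torch.Size([2, 32, 8, 8])

    # SPADE modulates the normalized activations with gamma/beta maps
    # predicted from a conditioning map resized to the spatial size of x.
    spade = SPADE("instance", 3, 16, 3)
    segmap = torch.randn(2, 3, 32, 32)
    print(spade(x, segmap).shape)  # torch.Size([2, 16, 8, 8])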