#@title Define Generator and Discriminator model
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
from torch.nn import functional as F


def l2_normalize(v, dim=None, eps=1e-12):
    """Scale ``v`` to unit L2 norm along ``dim``.

    Args:
        v: Input tensor.
        dim: Dimension along which the norm is taken (``None``: whole tensor).
        eps: Added to the norm to avoid division by zero.

    Returns:
        ``v / (||v|| + eps)`` with the same shape as ``v``.
    """
    return v / (v.norm(dim=dim, keepdim=True) + eps)


def unpool(value):
    """Upsample a 4-D NCHW tensor 2x by zero insertion ("unpooling").

    Every input pixel is written to the top-left corner of its 2x2 output
    cell; the other three positions are zero. Unpooling as described in
    https://www.robots.ox.ac.uk/~vgg/rg/papers/Dosovitskiy_Learning_to_Generate_2015_CVPR_paper.pdf
    adapted from https://github.com/tensorflow/tensorflow/issues/2169

    Args:
        value: Tensor of shape [b, ch, h, w].

    Returns:
        Tensor of shape [b, ch, 2*h, 2*w].
    """
    # The reference implementation is channels-last; move channels to the end.
    value = value.permute(0, 2, 3, 1)
    sh = list(value.shape)
    dim = len(sh[1:-1])
    out = torch.reshape(value, [-1] + sh[-dim:])
    # Concatenating zeros on the channel axis puts a zero pixel after every
    # pixel; concatenating on the width axis then puts a zero row after every
    # row once the result is reshaped to (2h, 2w).
    for i in range(dim, 0, -1):
        out = torch.cat([out, torch.zeros_like(out)], i)
    out_size = [-1] + [s * 2 for s in sh[1:-1]] + [sh[-1]]
    out = torch.reshape(out, out_size)
    return out.permute(0, 3, 1, 2)


class BatchNorm2d(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` that can also accumulate statistics over many batches.

    While ``accumulating`` is set, each forward pass adds the batch's
    per-channel mean and biased variance to running sums. It then normalizes
    the batch in eval mode, using the average of everything accumulated so
    far. On the first forward pass (and the first one after
    ``set_initialized(False)``), ``check_accumulation`` copies any accumulated
    averages into ``running_mean`` / ``running_var``.
    """

    def __init__(self, num_features, *args, **kwargs):
        super().__init__(num_features, *args, **kwargs)
        self.initialized = False
        self.accumulating = False
        # Non-trainable Parameters, so the sums are part of the state_dict.
        self.accumulated_mean = Parameter(torch.zeros(num_features), requires_grad=False)
        self.accumulated_var = Parameter(torch.zeros(num_features), requires_grad=False)
        # Starts just above zero so the averages below never divide by zero.
        self.accumulated_counter = Parameter(torch.zeros(1) + 1e-12, requires_grad=False)

    def forward(self, inputs, *args, **kwargs):
        if not self.initialized:
            self.check_accumulation()
            self.set_initialized(True)
        if not self.accumulating:
            return super().forward(inputs, *args, **kwargs)

        # Accumulation mode: switch to eval (use running stats), track no grads.
        self.eval()
        with torch.no_grad():
            # Reduce over every axis except channels (axis 1).
            axes = [0] + ([] if inputs.dim() == 2 else list(range(2, inputs.dim())))
            batch_mean = torch.mean(inputs, axes, keepdim=True)
            mean = torch.mean(inputs, axes, keepdim=False)
            var = torch.mean((inputs - batch_mean) ** 2, axes)
            self.accumulated_mean.copy_(self.accumulated_mean + mean)
            self.accumulated_var.copy_(self.accumulated_var + var)
            self.accumulated_counter.copy_(self.accumulated_counter + 1)
            # Normalize with the accumulated averages, then restore the
            # original running statistics.
            saved_mean = self.running_mean.clone()
            saved_var = self.running_var.clone()
            self.running_mean.copy_(self.accumulated_mean / self.accumulated_counter)
            self.running_var.copy_(self.accumulated_var / self.accumulated_counter)
            out = super().forward(inputs, *args, **kwargs)
            self.running_mean.copy_(saved_mean)
            self.running_var.copy_(saved_var)
        return out

    def check_accumulation(self):
        """Copy accumulated averages into the running statistics, if any exist.

        Returns:
            True if at least one batch had been accumulated (the running
            statistics were overwritten), otherwise False.
        """
        # .item() compares as a Python float (double precision). A float32
        # NumPy scalar would round 1 - 1e-12 up to 1.0 under NEP 50 promotion.
        if self.accumulated_counter.detach().mean().item() > 1 - 1e-12:
            self.running_mean.copy_(self.accumulated_mean / self.accumulated_counter)
            self.running_var.copy_(self.accumulated_var / self.accumulated_counter)
            return True
        return False

    def clear_accumulated(self):
        """Zero the accumulated sums and reset the counter to 1e-2."""
        self.accumulated_mean.copy_(self.accumulated_mean * 0.0)
        self.accumulated_var.copy_(self.accumulated_var * 0.0)
        self.accumulated_counter.copy_(self.accumulated_counter * 0.0 + 1e-2)

    def set_accumulating(self, status=True):
        """Turn statistic accumulation on (truthy ``status``) or off."""
        self.accumulating = bool(status)

    def set_initialized(self, status=False):
        """Set the flag that skips ``check_accumulation`` on the next forward."""
        self.initialized = bool(status)


class SpectralNorm(nn.Module):
    """Divide ``module``'s ``name`` weight by its estimated largest singular value.

    The raw weight is moved from ``module`` onto this wrapper as ``weight``.
    The power-iteration vector is kept as the non-trainable ``weight_u``.
    Before every forward pass, the normalized weight is written back onto
    ``module`` as a plain tensor attribute.
    """

    def __init__(self, module, name='weight', power_iterations=2):
        super().__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _weight_matrix(self, w):
        """Flatten ``w`` to the 2-D matrix used for power iteration.

        4-D conv kernels are permuted to (kh, kw, in, out) and flattened to
        (kh*kw*in, out). Linear/Embedding weights are transposed. Anything
        else is flattened, keeping its last dimension.
        """
        if w.dim() == 4:
            w_mat = w.permute(2, 3, 1, 0)
        elif isinstance(self.module, (nn.Linear, nn.Embedding)):
            w_mat = w.permute(1, 0)
        else:
            w_mat = w
        return torch.reshape(w_mat, [-1, w_mat.shape[-1]])

    def _update_u(self):
        """Run power iteration, refresh ``weight_u``, and set the normalized weight."""
        w = self.weight
        u = self.weight_u
        w_mat = self._weight_matrix(w)
        # u is a column vector (rows, 1) for wide/square matrices and a row
        # vector (1, cols) for tall ones; see _make_params.
        left = w_mat.shape[0] <= w_mat.shape[1]
        norm_dim = 0 if left else 1
        for _ in range(self.power_iterations):
            if left:
                v = l2_normalize(torch.matmul(w_mat.t(), u), dim=norm_dim)
                u = l2_normalize(torch.matmul(w_mat, v), dim=norm_dim)
            else:
                v = l2_normalize(torch.matmul(u, w_mat.t()), dim=norm_dim)
                u = l2_normalize(torch.matmul(v, w_mat), dim=norm_dim)
        if left:
            sigma = torch.matmul(torch.matmul(u.t(), w_mat), v)
        else:
            sigma = torch.matmul(torch.matmul(v, w_mat), u.t())
        setattr(self.module, self.name, w / sigma.detach())
        self.weight_u.copy_(u.detach())

    def _made_params(self):
        """Return True if ``weight`` and ``weight_u`` already exist on this wrapper."""
        try:
            self.weight
            self.weight_u
            return True
        except AttributeError:
            return False

    def _make_params(self):
        """Move the module's weight here and create a random unit ``weight_u``."""
        w = getattr(self.module, self.name)
        w_mat = self._weight_matrix(w)
        left = w_mat.shape[0] <= w_mat.shape[1]
        norm_dim = 0 if left else 1
        u_shape = (w_mat.shape[0], 1) if left else (1, w_mat.shape[-1])
        u = Parameter(w.data.new(*u_shape).normal_(0, 1), requires_grad=False)
        u.copy_(l2_normalize(u, dim=norm_dim).detach())
        # The module keeps no weight Parameter of its own; _update_u sets the
        # normalized tensor on it before each forward.
        del self.module._parameters[self.name]
        self.weight = w
        self.weight_u = u

    def forward(self, *args, **kwargs):
        self._update_u()
        return self.module.forward(*args, **kwargs)


class SelfAttention(nn.Module):
    """Gated self-attention over spatial positions.

    Keys (``phi``) and values (``g``) are 2x2 max-pooled, so each of the H*W
    query positions attends over the pooled positions. The output is
    ``gamma * o_conv(attended) + x``. Because ``gamma`` starts at zero, a
    freshly built block returns its input unchanged. ``activation`` is stored
    but not used by ``forward``.
    """

    def __init__(self, in_dim, activation=torch.relu):
        super().__init__()
        self.chanel_in = in_dim
        self.activation = activation
        self.theta = SpectralNorm(nn.Conv2d(in_dim, in_dim // 8, 1, bias=False))
        self.phi = SpectralNorm(nn.Conv2d(in_dim, in_dim // 8, 1, bias=False))
        self.pool = nn.MaxPool2d(2, 2)
        self.g = SpectralNorm(nn.Conv2d(in_dim, in_dim // 2, 1, bias=False))
        self.o_conv = SpectralNorm(nn.Conv2d(in_dim // 2, in_dim, 1, bias=False))
        self.gamma = Parameter(torch.zeros(1))

    def forward(self, x):
        batch, _, height, width = x.shape
        # Queries: (batch, H*W, C/8).
        theta = self.theta(x).flatten(2).permute(0, 2, 1)
        # Keys: (batch, C/8, P), where P = number of pooled positions. P is
        # taken from the pooled tensor, so odd H/W also work.
        phi = self.pool(self.phi(x)).flatten(2)
        attention = torch.softmax(torch.bmm(theta, phi), -1)
        # Values: (batch, C/2, P).
        g = self.pool(self.g(x)).flatten(2)
        attn_g = torch.bmm(g, attention.permute(0, 2, 1))
        attn_g = torch.reshape(attn_g, (batch, -1, height, width))
        return self.gamma * self.o_conv(attn_g) + x


class ConditionalBatchNorm2d(nn.Module):
    """Batch norm without its own affine params, modulated per sample.

    ``y`` is either a condition vector (mapped to ``gamma`` and ``beta`` by two
    spectrally-normalized bias-free linear layers) or a ``[gamma, beta]`` list
    supplied directly. ``gamma`` is used as-is, without a +1 offset.
    """

    def __init__(self, num_features, num_classes, eps=1e-5, momentum=0.1):
        super().__init__()
        self.bn_in_cond = BatchNorm2d(num_features, affine=False, eps=eps, momentum=momentum)
        # num_classes is the width of the condition vector y.
        self.gamma_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False))
        self.beta_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False))

    @staticmethod
    def _modulate(x, gamma, beta):
        """Return ``gamma * x + beta`` with gamma/beta broadcast per sample and channel."""
        return (torch.reshape(gamma, (gamma.shape[0], -1, 1, 1)) * x
                + torch.reshape(beta, (beta.shape[0], -1, 1, 1)))

    def forward(self, x, y):
        out = self.bn_in_cond(x)
        if isinstance(y, list):
            gamma, beta = y
        else:
            gamma = self.gamma_embed(y)
            beta = self.beta_embed(y)
        return self._modulate(out, gamma, beta)


class ResBlock(nn.Module):
    """Residual block with optional conditioning, resampling and attention.

    Main path: [cond-BN] -> act -> [unpool] -> conv0 -> [cond-BN] -> act ->
    conv1 -> [avg-pool]. If the block up- or down-samples (and ``skip_proj``
    is not True), a 1x1-conv shortcut with the same resampling is added.
    Self-attention is optionally applied at the end. ``n_class`` is accepted
    but unused; ``z_dim`` is the width of the condition vector.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=(3, 3),
        padding=1,
        stride=1,
        n_class=None,
        conditional=True,
        activation=torch.relu,
        upsample=True,
        downsample=False,
        z_dim=128,
        use_attention=False,
        skip_proj=None,
    ):
        super().__init__()
        if conditional:
            self.cond_norm1 = ConditionalBatchNorm2d(in_channel, z_dim)
        self.conv0 = SpectralNorm(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding)
        )
        if conditional:
            self.cond_norm2 = ConditionalBatchNorm2d(out_channel, z_dim)
        self.conv1 = SpectralNorm(
            nn.Conv2d(out_channel, out_channel, kernel_size, stride, padding)
        )
        self.skip_proj = False
        # NOTE(review): passing skip_proj=True *disables* the shortcut
        # projection; it is only built when the block resamples.
        if skip_proj is not True and (upsample or downsample):
            self.conv_sc = SpectralNorm(nn.Conv2d(in_channel, out_channel, 1, 1, 0))
            self.skip_proj = True
        if use_attention:
            self.attention = SelfAttention(out_channel)
        self.upsample = upsample
        self.downsample = downsample
        self.activation = activation
        self.conditional = conditional
        self.use_attention = use_attention

    def forward(self, input, condition=None):
        """Apply the block.

        ``condition`` is a single tensor shared by both norms, or a list whose
        items [0] and [1] go to ``cond_norm1`` and ``cond_norm2``.
        """
        out = input
        if self.conditional:
            out = self.cond_norm1(out, condition[0] if isinstance(condition, list) else condition)
        out = self.activation(out)
        if self.upsample:
            out = unpool(out)  # zero-insertion upsampling, not interpolation
        out = self.conv0(out)
        if self.conditional:
            out = self.cond_norm2(out, condition[1] if isinstance(condition, list) else condition)
        out = self.activation(out)
        out = self.conv1(out)
        if self.downsample:
            out = F.avg_pool2d(out, 2, 2)

        if self.skip_proj:
            skip = input
            if self.upsample:
                skip = unpool(skip)
            skip = self.conv_sc(skip)
            if self.downsample:
                skip = F.avg_pool2d(skip, 2, 2)
            out = out + skip
        # NOTE(review): without a shortcut projection no residual is added at
        # all. Confirm this is intended before relying on such blocks.

        if self.use_attention:
            out = self.attention(out)
        return out


class Generator(nn.Module):
    """Class-conditional generator made of upsampling ResBlocks.

    The latent ``input`` (width ``code_dim``) is split with ``torch.chunk``
    into ``num_split = len(blocks) + 1`` pieces. The first piece is projected
    to a 4x4 feature map. Each other piece is concatenated with a 128-d class
    embedding to condition one block. ``noise_fc`` expects
    ``code_dim // num_split`` inputs, so ``code_dim`` should be divisible by
    ``num_split``. Outputs are 3-channel images in [0, 1].

    Args:
        code_dim: Latent width.
        n_class: Width of the class input fed to ``embed_y`` (a Linear layer).
        chn: Base channel count; block widths are ``chn`` times a multiplier.
        blocks_with_attention: Comma-separated 1-based block names, e.g.
            "B4"; "" disables attention.
        resolution: Output size; one of 32, 64, 128, 256, 512, 1024.
    """

    def __init__(self, code_dim=128, n_class=1000, chn=96, blocks_with_attention="B4", resolution=512):
        super().__init__()
        self.embed_y = nn.Linear(n_class, 128, bias=False)
        self.chn = chn
        self.resolution = resolution
        self.blocks_with_attention = set(blocks_with_attention.split(","))
        self.blocks_with_attention.discard('')
        in_channels, out_channels = self.get_in_out_channels()
        self.num_split = len(in_channels) + 1
        z_dim = code_dim // self.num_split + 128
        self.noise_fc = SpectralNorm(nn.Linear(code_dim // self.num_split, 4 * 4 * in_channels[0]))
        self.sa_ids = [int(s.split('B')[-1]) for s in self.blocks_with_attention]
        self.blocks = nn.ModuleList(
            ResBlock(nc_in, nc_out, n_class=n_class, z_dim=z_dim,
                     use_attention=(i + 1) in self.sa_ids)
            for i, (nc_in, nc_out) in enumerate(zip(in_channels, out_channels))
        )
        # The last block outputs chn * (last multiplier) channels: chn for
        # 128-1024 px, but 2*chn at 64 px and 4*chn at 32 px.
        final_channels = out_channels[-1]
        self.output_layer_bn = BatchNorm2d(final_channels, eps=1e-5)
        self.output_layer_conv = SpectralNorm(nn.Conv2d(final_channels, 3, [3, 3], padding=1))
        self.z_dim = code_dim
        self.c_dim = n_class
        self.n_level = self.num_split

    def get_in_out_channels(self):
        """Return per-block (in_channels, out_channels) lists for ``self.resolution``.

        Raises:
            ValueError: if the resolution is not supported.
        """
        channel_multipliers = {
            1024: [16, 16, 8, 8, 4, 2, 1, 1, 1],
            512: [16, 16, 8, 8, 4, 2, 1, 1],
            256: [16, 16, 8, 8, 4, 2, 1],
            128: [16, 16, 8, 4, 2, 1],
            64: [16, 16, 8, 4, 2],
            32: [4, 4, 4, 4],
        }.get(self.resolution)
        if channel_multipliers is None:
            raise ValueError("Unsupported resolution: {}".format(self.resolution))
        in_channels = [self.chn * c for c in channel_multipliers[:-1]]
        out_channels = [self.chn * c for c in channel_multipliers[1:]]
        return in_channels, out_channels

    def _initial_features(self, z0):
        """Project ``z0`` to the first block's 4x4 NCHW feature map."""
        out = self.noise_fc(z0)
        # noise_fc's output is read channels-last as (4, 4, C).
        out = torch.reshape(out, (out.shape[0], 4, 4, -1))
        return out.permute(0, 3, 1, 2)

    def _to_image(self, out):
        """Final BN -> ReLU -> 3-channel conv, mapped from tanh range to [0, 1]."""
        out = self.output_layer_bn(out)
        out = torch.relu(out)
        out = self.output_layer_conv(out)
        return (torch.tanh(out) + 1) / 2

    def forward(self, input, class_id):
        """Generate images from latents ``input`` and class vectors ``class_id``."""
        codes = torch.chunk(input, self.num_split, 1)
        class_emb = self.embed_y(class_id)  # (batch, 128)
        out = self._initial_features(codes[0])
        for code, gblock in zip(codes[1:], self.blocks):
            out = gblock(out, torch.cat([code, class_emb], 1))
        return self._to_image(out)

    def forward_w(self, ws):
        """Generate from precomputed inputs.

        ``ws[0]`` feeds ``noise_fc``; ``ws[1:]`` are passed to the blocks in
        order as their ``condition`` argument.
        """
        out = self._initial_features(ws[0])
        for w, gblock in zip(ws[1:], self.blocks):
            out = gblock(out, w)
        return self._to_image(out)

    def forward_wp(self, z0, gammas, betas):
        """Generate with explicit per-block norm modulation.

        For block i, ``gammas[i][k]`` / ``betas[i][k]`` (k = 0, 1) are used
        directly as gamma/beta of that block's k-th conditional norm.
        """
        out = self._initial_features(z0)
        for gamma, beta, gblock in zip(gammas, betas, self.blocks):
            out = gblock(out, [[gamma[0], beta[0]], [gamma[1], beta[1]]])
        return self._to_image(out)