# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# All contributions by Andy Brock:
# Copyright (c) 2019 Andy Brock
#
# MIT License
#
# BigGAN V1:
# This is now-deprecated code that was used for porting the TFHub modules to
# PyTorch; it is included here for reference only.
import numpy as np
import torch
from scipy.stats import truncnorm
from torch import nn
from torch.nn import Parameter
from torch.nn import functional as F


def l2normalize(v, eps=1e-4):
    return v / (v.norm() + eps)


def truncated_z_sample(batch_size, z_dim, truncation=0.5, seed=None):
    state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim), random_state=state)
    return truncation * values


def denorm(x):
    # Map generator output from [-1, 1] to [0, 1].
    out = (x + 1) / 2
    return out.clamp_(0, 1)


class SpectralNorm(nn.Module):
    def __init__(self, module, name="weight", power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        _w = w.view(height, -1)
        for _ in range(self.power_iterations):
            v = l2normalize(torch.matmul(_w.t(), u))
            u = l2normalize(torch.matmul(_w, v))

        sigma = u.dot(_w.mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            getattr(self.module, self.name + "_u")
            getattr(self.module, self.name + "_v")
            getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        # The original allocated v with size `height`; it must have size `width`
        # to match W^T u. (v is recomputed in _update_u_v before use either way,
        # so this only fixes the shape of the stored buffer.)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)


class SelfAttention(nn.Module):
    """Self-attention layer."""

    def __init__(self, in_dim, activation=F.relu):
        super().__init__()
        self.channel_in = in_dim
        self.activation = activation

        self.theta = SpectralNorm(
            nn.Conv2d(
                in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1, bias=False
            )
        )
        self.phi = SpectralNorm(
            nn.Conv2d(
                in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1, bias=False
            )
        )
        self.pool = nn.MaxPool2d(2, 2)
        self.g = SpectralNorm(
            nn.Conv2d(
                in_channels=in_dim, out_channels=in_dim // 2, kernel_size=1, bias=False
            )
        )
        self.o_conv = SpectralNorm(
            nn.Conv2d(
                in_channels=in_dim // 2, out_channels=in_dim, kernel_size=1, bias=False
            )
        )
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # NCHW: dims 2 and 3 are height and width (the original named them in
        # the opposite order; the computation is unchanged).
        m_batchsize, C, height, width = x.size()
        N = height * width

        theta = self.theta(x)
        phi = self.phi(x)
        phi = self.pool(phi)  # keys and values are attended at half resolution
        phi = phi.view(m_batchsize, -1, N // 4)
        theta = theta.view(m_batchsize, -1, N)
        theta = theta.permute(0, 2, 1)
        attention = self.softmax(torch.bmm(theta, phi))
        g = self.pool(self.g(x)).view(m_batchsize, -1, N // 4)
        attn_g = torch.bmm(g, attention.permute(0, 2, 1)).view(
            m_batchsize, -1, height, width
        )
        out = self.o_conv(attn_g)

        return self.gamma * out + x
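

# --- Illustrative sketch (not part of the original port): a quick numerical
# check that SpectralNorm's power iteration rescales a weight to unit spectral
# norm. The layer size and iteration count are arbitrary demo choices.
def _demo_spectral_norm():
    torch.manual_seed(0)
    layer = SpectralNorm(nn.Linear(16, 32, bias=False), power_iterations=30)
    layer(torch.randn(4, 16))  # forward runs the power iteration, then rescales
    w = layer.module.weight  # the normalized weight set by _update_u_v
    # The top singular value of the normalized weight should be close to 1.
    return torch.linalg.svdvals(w.view(w.shape[0], -1))[0]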

class ConditionalBatchNorm2d(nn.Module):
    def __init__(self, num_features, num_classes, eps=1e-4, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, affine=False, eps=eps, momentum=momentum)
        self.gamma_embed = SpectralNorm(
            nn.Linear(num_classes, num_features, bias=False)
        )
        self.beta_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False))

    def forward(self, x, y):
        out = self.bn(x)
        gamma = self.gamma_embed(y) + 1
        beta = self.beta_embed(y)
        out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(
            -1, self.num_features, 1, 1
        )
        return out


class GBlock(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=[3, 3],
        padding=1,
        stride=1,
        n_class=None,
        bn=True,
        activation=F.relu,
        upsample=True,
        downsample=False,
        z_dim=148,
    ):
        super().__init__()

        self.conv0 = SpectralNorm(
            nn.Conv2d(
                in_channel,
                out_channel,
                kernel_size,
                stride,
                padding,
                bias=True,  # original read "True if bn else True", which always evaluates to True
            )
        )
        self.conv1 = SpectralNorm(
            nn.Conv2d(
                out_channel,
                out_channel,
                kernel_size,
                stride,
                padding,
                bias=True,  # see note on conv0
            )
        )

        self.skip_proj = False
        if in_channel != out_channel or upsample or downsample:
            self.conv_sc = SpectralNorm(nn.Conv2d(in_channel, out_channel, 1, 1, 0))
            self.skip_proj = True

        self.upsample = upsample
        self.downsample = downsample
        self.activation = activation
        self.bn = bn
        if bn:
            self.HyperBN = ConditionalBatchNorm2d(in_channel, z_dim)
            self.HyperBN_1 = ConditionalBatchNorm2d(out_channel, z_dim)

    def forward(self, input, condition=None):
        out = input

        if self.bn:
            out = self.HyperBN(out, condition)
        out = self.activation(out)
        if self.upsample:
            out = F.interpolate(out, scale_factor=2)
        out = self.conv0(out)
        if self.bn:
            out = self.HyperBN_1(out, condition)
        out = self.activation(out)
        out = self.conv1(out)

        if self.downsample:
            out = F.avg_pool2d(out, 2)

        if self.skip_proj:
            skip = input
            if self.upsample:
                skip = F.interpolate(skip, scale_factor=2)
            skip = self.conv_sc(skip)
            if self.downsample:
                skip = F.avg_pool2d(skip, 2)
        else:
            skip = input
        return out + skip


class Generator128(nn.Module):
    def __init__(self, code_dim=120, n_class=1000, chn=96, debug=False):
        super().__init__()

        self.linear = nn.Linear(n_class, 128, bias=False)

        if debug:
            chn = 8

        self.first_view = 16 * chn

        self.G_linear = SpectralNorm(nn.Linear(20, 4 * 4 * 16 * chn))

        z_dim = code_dim + 28

        self.GBlock = nn.ModuleList(
            [
                GBlock(16 * chn, 16 * chn, n_class=n_class, z_dim=z_dim),
                GBlock(16 * chn, 8 * chn, n_class=n_class, z_dim=z_dim),
                GBlock(8 * chn, 4 * chn, n_class=n_class, z_dim=z_dim),
                GBlock(4 * chn, 2 * chn, n_class=n_class, z_dim=z_dim),
                GBlock(2 * chn, 1 * chn, n_class=n_class, z_dim=z_dim),
            ]
        )

        self.sa_id = 4
        self.num_split = len(self.GBlock) + 1
        self.attention = SelfAttention(2 * chn)
        self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn, eps=1e-4)
        self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1))

    def forward(self, input, class_id):
        codes = torch.chunk(input, self.num_split, 1)
        class_emb = self.linear(class_id)  # (batch, 128)

        out = self.G_linear(codes[0])
        out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2)
        for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)):
            if i == self.sa_id:
                out = self.attention(out)
            condition = torch.cat([code, class_emb], 1)
            out = GBlock(out, condition)

        out = self.ScaledCrossReplicaBN(out)
        out = F.relu(out)
        out = self.colorize(out)
        return torch.tanh(out)
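

# --- Illustrative sketch (not part of the original port): sampling from an
# untrained Generator128. The 120-d latent is chunked into six 20-d pieces and
# each conditioning GBlock sees [20-d chunk, 128-d class embedding] = 148 dims.
# `debug=True` shrinks channels so this runs quickly; class 207 is an
# arbitrary demo pick.
def _demo_generator128():
    torch.manual_seed(0)
    g = Generator128(debug=True).eval()
    z = torch.from_numpy(truncated_z_sample(2, 120, truncation=0.4, seed=0)).float()
    one_hot = torch.zeros(2, 1000)
    one_hot[:, 207] = 1.0  # the generators expect one-hot class vectors, not indices
    with torch.no_grad():
        img = g(z, one_hot)  # (2, 3, 128, 128), values in [-1, 1] from tanh
    return denorm(img)  # mapped to [0, 1] for viewing or saving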

class Generator256(nn.Module):
    def __init__(self, code_dim=140, n_class=1000, chn=96, debug=False):
        super().__init__()

        self.linear = nn.Linear(n_class, 128, bias=False)

        if debug:
            chn = 8

        self.first_view = 16 * chn

        self.G_linear = SpectralNorm(nn.Linear(20, 4 * 4 * 16 * chn))

        self.GBlock = nn.ModuleList(
            [
                GBlock(16 * chn, 16 * chn, n_class=n_class),
                GBlock(16 * chn, 8 * chn, n_class=n_class),
                GBlock(8 * chn, 8 * chn, n_class=n_class),
                GBlock(8 * chn, 4 * chn, n_class=n_class),
                GBlock(4 * chn, 2 * chn, n_class=n_class),
                GBlock(2 * chn, 1 * chn, n_class=n_class),
            ]
        )

        self.sa_id = 5
        self.num_split = len(self.GBlock) + 1
        self.attention = SelfAttention(2 * chn)
        self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn, eps=1e-4)
        self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1))

    def forward(self, input, class_id):
        codes = torch.chunk(input, self.num_split, 1)
        class_emb = self.linear(class_id)  # (batch, 128)

        out = self.G_linear(codes[0])
        out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2)
        for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)):
            if i == self.sa_id:
                out = self.attention(out)
            condition = torch.cat([code, class_emb], 1)
            out = GBlock(out, condition)

        out = self.ScaledCrossReplicaBN(out)
        out = F.relu(out)
        out = self.colorize(out)
        return torch.tanh(out)


class Generator512(nn.Module):
    def __init__(self, code_dim=128, n_class=1000, chn=96, debug=False):
        super().__init__()

        self.linear = nn.Linear(n_class, 128, bias=False)

        if debug:
            chn = 8

        self.first_view = 16 * chn

        self.G_linear = SpectralNorm(nn.Linear(16, 4 * 4 * 16 * chn))

        z_dim = code_dim + 16

        self.GBlock = nn.ModuleList(
            [
                GBlock(16 * chn, 16 * chn, n_class=n_class, z_dim=z_dim),
                GBlock(16 * chn, 8 * chn, n_class=n_class, z_dim=z_dim),
                GBlock(8 * chn, 8 * chn, n_class=n_class, z_dim=z_dim),
                GBlock(8 * chn, 4 * chn, n_class=n_class, z_dim=z_dim),
                GBlock(4 * chn, 2 * chn, n_class=n_class, z_dim=z_dim),
                GBlock(2 * chn, 1 * chn, n_class=n_class, z_dim=z_dim),
                GBlock(1 * chn, 1 * chn, n_class=n_class, z_dim=z_dim),
            ]
        )

        self.sa_id = 4
        self.num_split = len(self.GBlock) + 1
        self.attention = SelfAttention(4 * chn)
        self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn)
        self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1))

    def forward(self, input, class_id):
        codes = torch.chunk(input, self.num_split, 1)
        class_emb = self.linear(class_id)  # (batch, 128)

        out = self.G_linear(codes[0])
        out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2)
        for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)):
            if i == self.sa_id:
                out = self.attention(out)
            condition = torch.cat([code, class_emb], 1)
            out = GBlock(out, condition)

        out = self.ScaledCrossReplicaBN(out)
        out = F.relu(out)
        out = self.colorize(out)
        return torch.tanh(out)
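

# --- Note (added for reference, not in the original port): each generator
# chunks its latent into (num_blocks + 1) equal pieces along dim 1. The first
# piece feeds G_linear; every remaining piece is concatenated with the 128-d
# class embedding to condition one GBlock, so the conditioning width is
# chunk_size + 128:
#   Generator128: 120 -> 6 chunks of 20, z_dim = 20 + 128 = 148
#   Generator256: 140 -> 7 chunks of 20, z_dim = 20 + 128 = 148 (GBlock default)
#   Generator512: 128 -> 8 chunks of 16, z_dim = 16 + 128 = 144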

class Discriminator(nn.Module):
    def __init__(self, n_class=1000, chn=96, debug=False):
        super().__init__()

        def conv(in_channel, out_channel, downsample=True):
            return GBlock(
                in_channel, out_channel, bn=False, upsample=False, downsample=downsample
            )

        if debug:
            chn = 8
        self.debug = debug

        self.pre_conv = nn.Sequential(
            SpectralNorm(nn.Conv2d(3, 1 * chn, 3, padding=1)),
            nn.ReLU(),
            SpectralNorm(nn.Conv2d(1 * chn, 1 * chn, 3, padding=1)),
            nn.AvgPool2d(2),
        )
        self.pre_skip = SpectralNorm(nn.Conv2d(3, 1 * chn, 1))

        self.conv = nn.Sequential(
            conv(1 * chn, 1 * chn, downsample=True),
            conv(1 * chn, 2 * chn, downsample=True),
            SelfAttention(2 * chn),
            conv(2 * chn, 2 * chn, downsample=True),
            conv(2 * chn, 4 * chn, downsample=True),
            conv(4 * chn, 8 * chn, downsample=True),
            conv(8 * chn, 8 * chn, downsample=True),
            conv(8 * chn, 16 * chn, downsample=True),
            conv(16 * chn, 16 * chn, downsample=False),
        )

        self.linear = SpectralNorm(nn.Linear(16 * chn, 1))

        self.embed = nn.Embedding(n_class, 16 * chn)
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.embed = SpectralNorm(self.embed)

    def forward(self, input, class_id):
        out = self.pre_conv(input)
        out += self.pre_skip(F.avg_pool2d(input, 2))
        out = self.conv(out)
        out = F.relu(out)
        out = out.view(out.size(0), out.size(1), -1)
        out = out.sum(2)  # global sum pooling
        out_linear = self.linear(out).squeeze(1)
        embed = self.embed(class_id)

        prod = (out * embed).sum(1)

        return out_linear + prod  # projection-discriminator score
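

# --- Illustrative sketch (not part of the original port). The discriminator's
# stack halves spatial size eight times (the pre_conv pool plus seven
# downsampling blocks), so it reduces 256x256 inputs to 1x1 features and pairs
# with Generator256. The input shapes and the hinge-style generator loss below
# are demo choices, not code from the original porting scripts.
if __name__ == "__main__":
    torch.manual_seed(0)
    disc = Discriminator(debug=True)
    images = torch.randn(2, 3, 256, 256)  # stand-in for generator output
    labels = torch.randint(0, 1000, (2,))  # integer class ids for nn.Embedding
    scores = disc(images, labels)  # (2,) projection scores
    g_loss = -scores.mean()  # hinge-GAN generator loss on fakes
    print(scores.shape, float(g_loss))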