"""Building blocks for a StyleGAN2-style generator, discriminator and encoder.

Unlike a plain StyleGAN2 generator, the generator here expects styles that have
already been passed through each layer's modulation network (see
``Generator.get_latent``), and its convolutions accept optional multiplicative
weight offsets (``weights_delta``).
"""
import functools
import math
import operator
import random

import torch
import torchvision
from torch import nn
from torch.autograd import Function
from torch.nn import functional as F

from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d


class PixelNorm(nn.Module):
    """Rescale each sample to unit root-mean-square along dim 1."""

    def forward(self, input):
        # 1e-8 keeps rsqrt finite for an all-zero input.
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)


class To4d(nn.Module):
    """Append two trailing singleton dims, e.g. (B, C) -> (B, C, 1, 1)."""

    def forward(self, input):
        return input.view(*input.size(), 1, 1)


def make_kernel(k):
    """Return a float32 FIR kernel normalized to sum to 1.

    A 1-D ``k`` is expanded to a separable 2-D kernel via its outer product.
    """
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[None, :] * k[:, None]
    k /= k.sum()
    return k


class Upsample(nn.Module):
    """Upsample by ``factor`` with a FIR filter through ``upfirdn2d``."""

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor
        # Scaled by factor**2 — presumably to preserve the mean intensity after
        # zero-insertion upsampling (upfirdn2d semantics, confirm in .op).
        kernel = make_kernel(kernel) * (factor ** 2)
        self.register_buffer('kernel', kernel)

        p = kernel.shape[0] - factor
        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2
        self.pad = (pad0, pad1)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)


class Downsample(nn.Module):
    """Downsample by ``factor`` with a FIR filter through ``upfirdn2d``."""

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer('kernel', kernel)

        p = kernel.shape[0] - factor
        pad0 = (p + 1) // 2
        pad1 = p // 2
        self.pad = (pad0, pad1)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)


class Blur(nn.Module):
    """Apply a FIR filter with explicit ``pad`` (no resampling).

    When ``upsample_factor > 1`` the kernel is multiplied by
    ``upsample_factor ** 2`` (used after a strided transposed convolution).
    """

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()
        kernel = make_kernel(kernel)
        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)
        self.register_buffer('kernel', kernel)
        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)


class EqualConv2d(nn.Module):
    """Conv2d with equalized learning rate.

    Weights are stored as ``N(0, 1) / lr_mul`` samples and multiplied by
    ``lr_mul / sqrt(fan_in)`` at every forward pass. The bias (if any) is
    multiplied by ``lr_mul``.
    """

    def __init__(
        self, in_channel, out_channel, kernel_size, groups=1, stride=1,
        padding=0, bias=True, lr_mul=1,
    ):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(
                out_channel, in_channel // groups, kernel_size, kernel_size
            ).div_(lr_mul)
        )
        self.scale = lr_mul / math.sqrt((in_channel // groups) * kernel_size ** 2)
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.lr_mul = lr_mul

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))
        else:
            self.bias = None

    def forward(self, input):
        # Use the lr_mul-scaled bias (previously computed but the raw bias was
        # passed, silently ignoring lr_mul for the bias term).
        bias = self.bias * self.lr_mul if self.bias is not None else None
        return F.conv2d(
            input,
            self.weight * self.scale,
            bias=bias,
            stride=self.stride,
            padding=self.padding,
            groups=self.groups,
        )

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
        )


class EqualLinear(nn.Module):
    """Linear layer with equalized learning rate.

    Args:
        in_dim, out_dim: feature sizes.
        bias: whether to learn a bias (initialized to ``bias_init``).
        bias_init: initial bias value.
        lr_mul: learning-rate multiplier applied to weight and bias.
        activation: if truthy (the value itself is not inspected), the output
            goes through ``fused_leaky_relu`` with the bias; otherwise a plain
            affine map is applied.
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.activation = activation
        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        weight = self.weight * self.scale
        # Guard: bias=False previously crashed on `None * lr_mul`.
        bias = self.bias * self.lr_mul if self.bias is not None else None

        if self.activation:
            out = F.linear(input, weight)
            return fused_leaky_relu(out, bias)
        return F.linear(input, weight, bias=bias)

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )


class ScaledLeakyReLU(nn.Module):
    """Leaky ReLU followed by multiplication by sqrt(2)."""

    def __init__(self, negative_slope=0.2):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, input):
        out = F.leaky_relu(input, negative_slope=self.negative_slope)
        return out * math.sqrt(2)


class ModulatedConv2d(nn.Module):
    """Convolution whose weights are scaled per sample by a style vector.

    ``forward`` expects the style to be already modulated (the output of
    ``get_latent``), not the raw latent. Per-sample weights are realized with a
    grouped convolution over a (1, batch * channels, H, W) view of the input.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=(1, 3, 3, 1),
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1
            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )
        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
        self.demodulate = demodulate

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
            f'upsample={self.upsample}, downsample={self.downsample})'
        )

    def get_latent(self, style):
        """Map a latent through this layer's modulation network."""
        return self.modulation(style)

    def forward(self, input, style, weights_delta=None):
        """Convolve ``input`` with style-modulated weights.

        Args:
            input: (batch, in_channel, H, W) tensor (unpacked from ``input.shape``).
            style: already-modulated style; reshaped to (batch, 1, in_channel, 1, 1).
            weights_delta: optional tensor applied as ``weight * (1 + weights_delta)``;
                must broadcast against the (1, out, in, k, k) weight — shape TODO confirm.
        """
        batch, in_channel, height, width = input.shape

        # Modulation is deliberately NOT applied here; callers pass the output
        # of get_latent().
        style = style.view(batch, 1, in_channel, 1, 1)
        if weights_delta is None:
            weight = self.scale * self.weight * style
        else:
            weight = self.scale * (self.weight * (1 + weights_delta) * style)

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            # conv_transpose2d wants (in, out, k, k) per group.
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)
        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
        else:
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out


class NoiseInjection(nn.Module):
    """Add noise scaled by a learned scalar (initialized to 0)."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is None:
            # Fresh N(0, 1) noise of shape (batch, 1, H, W), broadcast over channels.
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()
        return image + self.weight * noise


class ConstantInput(nn.Module):
    """Learned constant (1, channel, size, size) repeated to the batch size of ``input``."""

    def __init__(self, channel, size=4):
        super().__init__()
        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        batch = input.shape[0]
        return self.input.repeat(batch, 1, 1, 1)


class StyledConv(nn.Module):
    """ModulatedConv2d -> NoiseInjection -> FusedLeakyReLU."""

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=(1, 3, 3, 1),
        demodulate=True,
    ):
        super().__init__()
        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )
        self.noise = NoiseInjection()
        # FusedLeakyReLU carries the per-channel bias.
        self.activate = FusedLeakyReLU(out_channel)

    def get_latent(self, style):
        return self.conv.get_latent(style)

    def forward(self, input, style, noise=None, weights_delta=None):
        """Return ``(activated output, raw conv output before noise)``."""
        out_t = self.conv(input, style, weights_delta=weights_delta)
        out = self.noise(out_t, noise=noise)
        out = self.activate(out)
        return out, out_t


class ToRGB(nn.Module):
    """1x1 non-demodulated modulated conv to 3 channels, plus bias and upsampled skip."""

    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=(1, 3, 3, 1)):
        super().__init__()
        if upsample:
            self.upsample = Upsample(blur_kernel)
        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def get_latent(self, style):
        return self.conv.get_latent(style)

    def forward(self, input, style, skip=None, weights_delta=None):
        out = self.conv(input, style, weights_delta)
        out = out + self.bias
        if skip is not None:
            # Requires upsample=True at construction (self.upsample must exist).
            skip = self.upsample(skip)
            out = out + skip
        return out


class Generator(nn.Module):
    """StyleGAN2-style synthesis network driven by per-layer modulated styles.

    Style-layer order (shared by ``get_latent`` output, ``forward``'s ``styles``
    and ``weights_deltas``): conv1, to_rgb1, then for each resolution from 8 up
    to ``size``: upsampling conv, conv, to_rgb.
    """

    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=(1, 3, 3, 1),
        lr_mlp=0.01,
    ):
        super().__init__()
        self.size = size
        self.style_dim = style_dim

        layers = [PixelNorm()]
        for _ in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
                )
            )
        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        # log2 is exact for powers of two (math.log(x, 2) can be off by an ulp).
        self.log_size = int(math.log2(size))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()  # unused; kept for checkpoint compatibility
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        # Fixed noise buffers: one 4x4, then two per resolution 8..size.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]
            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )
            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )
            self.to_rgbs.append(ToRGB(out_channel, style_dim))
            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        """Sample a fresh list of noise maps matching the stored noise buffers."""
        device = self.input.input.device
        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
        return noises

    def mean_latent(self, n_latent):
        """Return per-layer mean modulated styles over ``n_latent`` random z's.

        Each list entry keeps a leading batch dim of 1; suitable as
        ``mean_latent`` for ``get_latent``'s truncation.
        """
        latent_in = torch.randn(n_latent, self.style_dim, device=self.input.input.device)
        latent = self.get_latent(latent_in)
        return [layer_style.mean(0, keepdim=True) for layer_style in latent]

    def get_w(self, input):
        """Map z through the style MLP, then apply ``fused_leaky_relu`` with slope arg 5.

        NOTE(review): 5 == 1 / 0.2, presumably to undo the MLP's final leaky ReLU;
        the bias is a zeros tensor with ``latent``'s full shape — confirm this is
        what ``fused_leaky_relu`` expects.
        """
        latent = self.style(input)
        # zeros_like already lives on latent's device; the old `.cuda()` broke
        # CPU use and could pick the wrong GPU.
        return fused_leaky_relu(latent, torch.zeros_like(latent), 5.)

    def get_latent(self, input, is_latent=False, truncation=1, mean_latent=None):
        """Return the list of per-layer modulated styles.

        Args:
            input: z codes, or (when ``is_latent``) per-layer latents indexed
                as ``input[:, layer]``.
            is_latent: skip the style MLP and use ``input`` directly.
            truncation: if < 1 and ``mean_latent`` is given, interpolate each
                layer's style toward ``mean_latent[layer]``.
            mean_latent: list as returned by ``mean_latent()``.
        """
        output = []
        if not is_latent:
            latent = self.style(input)
            # (B, style_dim) -> (B, n_latent, style_dim)
            latent = latent.unsqueeze(1).repeat(1, self.n_latent, 1)
        else:
            latent = input

        output.append(self.conv1.get_latent(latent[:, 0]))
        output.append(self.to_rgb1.get_latent(latent[:, 1]))

        i = 1
        for conv1, conv2, to_rgb in zip(self.convs[::2], self.convs[1::2], self.to_rgbs):
            output.append(conv1.get_latent(latent[:, i]))
            output.append(conv2.get_latent(latent[:, i + 1]))
            # to_rgb shares its latent index with the next block's first conv.
            output.append(to_rgb.get_latent(latent[:, i + 2]))
            i += 2

        if truncation < 1 and mean_latent is not None:
            output = [
                mean_latent[k] + truncation * (output[k] - mean_latent[k])
                for k in range(len(output))
            ]
        return output

    def forward(
        self,
        styles,
        stop_idx=99,
        is_cluster=False,
        noise=None,
        randomize_noise=False,
        weights_deltas=None,
    ):
        """Synthesize an image from per-layer modulated styles.

        Args:
            styles: list of modulated styles in style-layer order (see class docstring).
            stop_idx: when the running style-layer index reaches this value,
                return only the ``outputs`` list collected so far (not a tuple).
            is_cluster: unused.
            noise: list of per-layer noise maps; defaults to the stored buffers,
                or fresh noise when ``randomize_noise`` is True.
            randomize_noise: see ``noise``.
            weights_deltas: list of per-style-layer ``weights_delta`` values in
                style-layer order; defaults to all ``None``.

        Returns:
            ``(image clamped to [-1, 1], outputs)``. ``outputs[k]`` is the
            ``[pre-noise conv output, activated output]`` pair most recently
            produced before style layer ``k`` runs (both entries are the
            constant input for ``k == 0``).
        """
        total_convs = len(self.convs) + len(self.to_rgbs) + 2
        if weights_deltas is None:
            weights_deltas = [None] * total_convs

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
                ]

        outputs = []
        idx_count = 0
        latent = styles

        out = self.input(latent[0])
        outputs.append([out, out])
        if idx_count == stop_idx:
            return outputs

        out, out_t = self.conv1(
            out, latent[idx_count], noise=noise[0], weights_delta=weights_deltas[0]
        )
        outputs.append([out_t, out])
        idx_count += 1
        if idx_count == stop_idx:
            return outputs

        skip = self.to_rgb1(out, latent[idx_count], weights_delta=weights_deltas[1])

        weight_idx = 2
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            outputs.append([out_t, out])
            idx_count += 1
            if idx_count == stop_idx:
                return outputs

            out, out_t = conv1(
                out, latent[idx_count], noise=noise1,
                weights_delta=weights_deltas[weight_idx],
            )
            outputs.append([out_t, out])
            idx_count += 1
            if idx_count == stop_idx:
                return outputs

            out, out_t = conv2(
                out, latent[idx_count], noise=noise2,
                weights_delta=weights_deltas[weight_idx + 1],
            )
            outputs.append([out_t, out])
            idx_count += 1
            if idx_count == stop_idx:
                return outputs

            skip = to_rgb(
                out, latent[idx_count], skip, weights_delta=weights_deltas[weight_idx + 2]
            )
            weight_idx += 3

        image = skip.clamp(-1, 1)
        return image, outputs


class ConvLayer(nn.Sequential):
    """[Blur] -> EqualConv2d -> [FusedLeakyReLU | ScaledLeakyReLU].

    With ``downsample`` the conv uses stride 2 after a blur and no padding;
    otherwise stride 1 with ``kernel_size // 2`` padding. When ``activate`` is
    set, the bias lives in FusedLeakyReLU (if ``bias``) instead of the conv.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        groups=1,
        downsample=False,
        blur_kernel=(1, 3, 3, 1),
        bias=True,
        activate=True,
        lr_mul=1,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
            stride = 2
            padding = 0
        else:
            stride = 1
            padding = kernel_size // 2

        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                groups=groups,
                padding=padding,
                stride=stride,
                bias=bias and not activate,
                lr_mul=lr_mul,
            )
        )

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel, lr_mul=lr_mul))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)
        # Set after Module init rather than before it.
        self.padding = padding


class ResBlock(nn.Module):
    """Two 3x3 ConvLayers (second downsamples) plus a 1x1 downsampling skip, summed / sqrt(2)."""

    def __init__(self, in_channel, out_channel, blur_kernel=(1, 3, 3, 1)):
        super().__init__()
        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)
        skip = self.skip(input)
        return (out + skip) / math.sqrt(2)


class Discriminator(nn.Module):
    """Residual discriminator with a minibatch-stddev feature before the head.

    Returns one score per sample, shape (batch, 1).
    """

    def __init__(self, size, channel_multiplier=2, blur_kernel=(1, 3, 3, 1)):
        super().__init__()
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        convs = [ConvLayer(3, channels[size], 1)]
        log_size = int(math.log2(size))
        in_channel = channels[size]
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            convs.append(ResBlock(in_channel, out_channel, blur_kernel))
            in_channel = out_channel
        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1

        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        out = self.convs(input)

        batch, channel, height, width = out.shape
        # NOTE(review): the view below requires batch to be divisible by group.
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)
        out = out.view(batch, -1)
        return self.final_linear(out)


class VGGExtractor(torch.nn.Module):
    """Frozen ImageNet-pretrained VGG16 ``features[:23]`` with ImageNet normalization.

    Inputs are presumably in [-1, 1]; they are mapped to [0, 1] via
    ``(x + 1) / 2`` and standardized with the ImageNet mean/std. Inputs whose
    channel count is not 3 are tiled 3x along dim 1 (intended for 1-channel input).
    """

    def __init__(self, resize=False):
        super().__init__()
        vgg16 = torchvision.models.vgg16(pretrained=True).eval()
        blocks = vgg16.features[:23]
        # Freeze the parameters themselves; setting `requires_grad` on the
        # layer modules (as before) had no effect on their weights.
        for param in blocks.parameters():
            param.requires_grad = False
        self.blocks = blocks
        self.transform = torch.nn.functional.interpolate
        # Buffers, not trainable Parameters; state_dict keys are unchanged.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.resize = resize

    def forward(self, input):
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
        input = (input + 1) / 2
        input = (input - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
        return self.blocks(input)


class Encoder(nn.Module):
    """Image encoder (work in progress).

    NOTE(review): as written this class cannot run forward:
      * the ResBlock stack is built into a local ``convs`` that is never
        stored on ``self``, so its modules are not registered;
      * ``forward`` uses ``self.convs``, ``self.local_fc`` and
        ``self.global_fc``, none of which ``__init__`` defines, while
        ``fc_high`` / ``fc_mid`` / ``fc_low`` are defined but unused.
    ``size``, ``groups`` and ``channel_multiplier`` are unused. Left unchanged
    pending the intended design — TODO confirm with the author.
    """

    def __init__(self, size, groups, channel_multiplier=1, blur_kernel=(1, 3, 3, 1)):
        '''
        [16]: [14,15,16,17,18,19]
        [8]: [8,9,10,11,12,13]
        [4]: [0,1,2,3,4,5,6,7]
        '''
        super().__init__()
        in_channel = 3
        out_channel = 64
        convs = nn.ModuleList()
        # Channels: 64, 128, 256, 512, 1024, 1024.
        for _ in range(6):
            convs.append(ResBlock(in_channel, out_channel, blur_kernel))
            in_channel = out_channel
            out_channel = min(1024, in_channel * 2)

        self.fc_high = nn.Sequential(
            nn.AdaptiveAvgPool2d(4), nn.Flatten(),
            EqualLinear(512 * 4 * 4, 4 * 512 + 3 * 256 + 2 * 128),
        )
        self.fc_mid = nn.Sequential(
            nn.AdaptiveAvgPool2d(4), nn.Flatten(), EqualLinear(1024 * 4 * 4, 512 * 6)
        )
        self.fc_low = nn.Sequential(
            nn.AdaptiveAvgPool2d(4), nn.Flatten(), EqualLinear(1024 * 4 * 4, 512 * 5)
        )

    def forward(self, input):
        shared = self.convs(input)
        local = self.local_fc(shared)
        glob = self.global_fc(shared)
        return local.view(local.size(0), -1), glob