import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import swapae.util as util
from swapae.models.networks import BaseNetwork
from swapae.models.networks.stylegan2_layers import (
    ResBlock,
    ConvLayer,
    ToRGB,
    EqualLinear,
    Blur,
    Upsample,
    make_kernel,
)
from swapae.models.networks.stylegan2_op import upfirdn2d


class ToSpatialCode(torch.nn.Module):
    """Project a feature map to ``outch`` channels and then upsample it.

    The input goes through two 1x1 ``ConvLayer`` modules
    (``inch -> inch // 2 -> outch``) and is then passed through
    ``self.upsample`` ``int(log2(scale))`` times.
    """

    def __init__(self, inch, outch, scale):
        """Build the projection and upsampling layers.

        Args:
            inch: number of input channels.
            outch: number of output channels.
            scale: total upsampling factor; only ``int(log2(scale))`` is used,
                so a power of two is presumably expected (not enforced here).
        """
        super().__init__()
        bottleneck = inch // 2
        self.conv1 = ConvLayer(inch, bottleneck, 1, activate=True, bias=True)
        self.conv2 = ConvLayer(bottleneck, outch, 1, activate=False, bias=True)
        self.scale = scale
        self.upsample = Upsample([1, 3, 3, 1], 2)
        # NOTE(review): `blur` and the `kernel` buffer are never read by
        # forward() in this class. They are still registered here; presumably
        # kept so that existing checkpoints keep loading -- confirm before
        # removing.
        self.blur = Blur([1, 3, 3, 1], pad=(2, 1))
        self.register_buffer('kernel', make_kernel([1, 3, 3, 1]))

    def forward(self, x):
        """Return the projected and upsampled version of ``x``."""
        projected = self.conv2(self.conv1(x))
        num_upsamples = int(np.log2(self.scale))
        for _ in range(num_upsamples):
            projected = self.upsample(projected)
        return projected


class StyleGAN2ResnetEncoder(BaseNetwork):
    """Encoder that maps an image to a spatial code and a global code.

    Spatial Code refers to the Structure Code, and Global Code refers to the
    Texture Code of the paper.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the encoder's ``--netE_*`` options on ``parser``.

        Args:
            parser: object exposing ``add_argument`` (argparse-style).
            is_train: accepted for interface compatibility; not used here.

        Returns:
            The same ``parser`` that was passed in.
        """
        encoder_flags = (
            ("--netE_scale_capacity", 1.0, float),
            ("--netE_num_downsampling_sp", 4, int),
            ("--netE_num_downsampling_gl", 2, int),
            ("--netE_nc_steepness", 2.0, float),
        )
        for flag, default_value, value_type in encoder_flags:
            parser.add_argument(flag, default=default_value, type=value_type)
        return parser

    def __init__(self, opt):
        """Assemble the encoder from the options object.

        Reads ``use_antialias``, ``netE_num_downsampling_sp``,
        ``netE_num_downsampling_gl``, ``spatial_code_ch`` and
        ``global_code_ch`` from ``self.opt`` (plus the fields used by
        :meth:`nc`). Sub-modules are registered in a fixed order and under
        fixed names, both of which determine the state-dict layout.
        """
        super().__init__(opt)
        options = self.opt

        # If antialiasing is used, create a very lightweight Gaussian kernel.
        if options.use_antialias:
            blur_kernel = [1, 2, 1]
        else:
            blur_kernel = [1]

        num_sp = options.netE_num_downsampling_sp
        num_gl = options.netE_num_downsampling_gl

        self.add_module("FromRGB", ConvLayer(3, self.nc(0), 1))

        # Shared trunk: one ResBlock per spatial downsampling step.
        sp_trunk = nn.Sequential()
        for level in range(num_sp):
            block_name = "ResBlockDownBy%d" % (2 ** level)
            sp_trunk.add_module(
                block_name,
                ResBlock(self.nc(level), self.nc(level + 1), blur_kernel,
                         reflection_pad=True),
            )
        self.DownToSpatialCode = sp_trunk

        # Head producing the spatial (structure) code. This is a plain
        # nn.Sequential of 1x1 ConvLayers; the module-level ToSpatialCode
        # class is not used for it.
        sp_width = self.nc(num_sp)
        sp_head = nn.Sequential(
            ConvLayer(sp_width, sp_width, 1, activate=True, bias=True),
            ConvLayer(sp_width, options.spatial_code_ch, kernel_size=1,
                      activate=False, bias=True),
        )
        self.add_module("ToSpatialCode", sp_head)

        # Extra downsampling convolutions leading to the global (texture)
        # code; channel indices continue from where the spatial trunk ended.
        gl_trunk = nn.Sequential()
        for step in range(num_gl):
            depth = num_sp + step
            gl_trunk.add_module(
                "ConvLayerDownBy%d" % (2 ** depth),
                ConvLayer(self.nc(depth), self.nc(depth + 1), kernel_size=3,
                          blur_kernel=[1], downsample=True, pad=0),
            )
        self.DownToGlobalCode = gl_trunk

        gl_width = self.nc(num_sp + num_gl)
        self.add_module(
            "ToGlobalCode",
            nn.Sequential(EqualLinear(gl_width, options.global_code_ch)),
        )

    def nc(self, idx):
        """Return the channel count used at depth ``idx``.

        Computed as ``round(netE_nc_steepness ** (5 + idx) *
        netE_scale_capacity)``; with the default options (steepness 2.0,
        capacity 1.0) this gives 32 at ``idx == 0`` and doubles per step.
        """
        width = self.opt.netE_nc_steepness ** (5 + idx)
        width = width * self.opt.netE_scale_capacity
        # NOTE: an earlier variant capped this value at opt.global_code_ch;
        # that cap is disabled, so the width grows without an upper bound.
        return round(width)

    def forward(self, x, extract_features=False):
        """Encode ``x`` into a spatial code and a global code.

        Args:
            x: input passed to ``FromRGB`` (a 3-channel ConvLayer), so
                presumably an image batch shaped (N, 3, H, W) -- confirm
                against callers.
            extract_features: when true, additionally return an intermediate
                feature map resized to 7x7.

        Returns:
            ``(sp, gl)``, or ``(sp, gl, feature)`` when ``extract_features``
            is true. ``sp`` and ``gl`` are passed through ``util.normalize``.
        """
        stem = self.FromRGB(x)
        midpoint = self.DownToSpatialCode(stem)
        sp = self.ToSpatialCode(midpoint)

        feature = None
        if extract_features:
            # Reflect-pad one pixel on the left and top so that the first
            # (unpadded) global-code conv yields exactly half the spatial
            # code's resolution; the assert below checks that.
            padded_midpoint = F.pad(midpoint, (1, 0, 1, 0), mode='reflect')
            first_global_layer = self.DownToGlobalCode[0]
            feature = first_global_layer(padded_midpoint)
            half_height = sp.size(2) // 2
            half_width = sp.size(3) // 2
            assert feature.size(2) == half_height and \
                feature.size(3) == half_width
            feature = F.interpolate(
                feature, size=(7, 7), mode='bilinear', align_corners=False)

        # Global code: downsample further, average over the two spatial
        # dimensions, then apply the linear head.
        pooled = self.DownToGlobalCode(midpoint).mean(dim=(2, 3))
        gl = self.ToGlobalCode(pooled)

        sp = util.normalize(sp)
        gl = util.normalize(gl)

        if extract_features:
            return sp, gl, feature
        return sp, gl