import torch
from torch import nn
from torch.nn import functional as F
from math import sqrt


class EqualLR:
    """Equalized learning rate: store the raw weight as `<name>_orig` and
    rescale it by the He constant sqrt(2 / fan_in) on every forward pass."""

    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        weight = getattr(module, self.name + '_orig')
        # fan_in = input channels * kernel height * kernel width
        fan_in = weight.data.size(1) * weight.data[0][0].numel()
        return weight * sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        fn = EqualLR(name)

        weight = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        weight = self.compute_weight(module)
        setattr(module, self.name, weight)


def equal_lr(module, name='weight'):
    EqualLR.apply(module, name)
    return module


class PixelNorm(nn.Module):
    """Normalize each spatial position's feature vector to unit RMS."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)


class EqualConv2d(nn.Module):
    """nn.Conv2d with N(0, 1) init and equalized learning rate."""

    def __init__(self, *args, **kwargs):
        super().__init__()

        conv = nn.Conv2d(*args, **kwargs)
        conv.weight.data.normal_()
        conv.bias.data.zero_()
        self.conv = equal_lr(conv)

    def forward(self, input):
        return self.conv(input)


class EqualConvTranspose2d(nn.Module):
    """nn.ConvTranspose2d with N(0, 1) init and equalized learning rate.
    Additional module for OOGAN usage."""

    def __init__(self, *args, **kwargs):
        super().__init__()

        conv = nn.ConvTranspose2d(*args, **kwargs)
        conv.weight.data.normal_()
        conv.bias.data.zero_()
        self.conv = equal_lr(conv)

    def forward(self, input):
        return self.conv(input)


class EqualLinear(nn.Module):
    """nn.Linear with N(0, 1) init and equalized learning rate."""

    def __init__(self, in_dim, out_dim):
        super().__init__()

        linear = nn.Linear(in_dim, out_dim)
        linear.weight.data.normal_()
        linear.bias.data.zero_()
        self.linear = equal_lr(linear)

    def forward(self, input):
        return self.linear(input)


class ConvBlock(nn.Module):
    """Two equalized conv layers, each followed by optional PixelNorm and
    LeakyReLU. The second conv may use a different kernel/padding (used by
    the discriminator's final 4x4 block)."""

    def __init__(self, in_channel, out_channel, kernel_size, padding,
                 kernel_size2=None, padding2=None, pixel_norm=True):
        super().__init__()

        pad1 = padding
        pad2 = padding if padding2 is None else padding2

        kernel1 = kernel_size
        kernel2 = kernel_size if kernel_size2 is None else kernel_size2

        convs = [EqualConv2d(in_channel, out_channel, kernel1, padding=pad1)]
        if pixel_norm:
            convs.append(PixelNorm())
        convs.append(nn.LeakyReLU(0.1))
        convs.append(EqualConv2d(out_channel, out_channel, kernel2, padding=pad2))
        if pixel_norm:
            convs.append(PixelNorm())
        convs.append(nn.LeakyReLU(0.1))

        self.conv = nn.Sequential(*convs)

    def forward(self, input):
        out = self.conv(input)
        return out


def upscale(feat):
    return F.interpolate(feat, scale_factor=2, mode='bilinear', align_corners=False)


class Generator(nn.Module):
    """Progressive generator: `step` selects the output resolution
    (1 -> 8x8, 2 -> 16x16, ..., 6 -> 256x256) and `alpha` in [0, 1) blends
    the new stage's RGB output with the upscaled previous stage during
    fade-in."""

    def __init__(self, input_code_dim=128, in_channel=128, pixel_norm=True, tanh=True):
        super().__init__()
        self.input_dim = input_code_dim
        self.tanh = tanh
        # Project the latent code to a 4x4 feature map.
        self.input_layer = nn.Sequential(
            EqualConvTranspose2d(input_code_dim, in_channel, 4, 1, 0),
            PixelNorm(),
            nn.LeakyReLU(0.1))

        self.progression_4 = ConvBlock(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm)
        self.progression_8 = ConvBlock(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm)
        self.progression_16 = ConvBlock(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm)
        self.progression_32 = ConvBlock(in_channel, in_channel, 3, 1, pixel_norm=pixel_norm)
        self.progression_64 = ConvBlock(in_channel, in_channel // 2, 3, 1, pixel_norm=pixel_norm)
        self.progression_128 = ConvBlock(in_channel // 2, in_channel // 4, 3, 1, pixel_norm=pixel_norm)
        self.progression_256 = ConvBlock(in_channel // 4, in_channel // 4, 3, 1, pixel_norm=pixel_norm)

        self.to_rgb_8 = EqualConv2d(in_channel, 3, 1)
        self.to_rgb_16 = EqualConv2d(in_channel, 3, 1)
        self.to_rgb_32 = EqualConv2d(in_channel, 3, 1)
        self.to_rgb_64 = EqualConv2d(in_channel // 2, 3, 1)
        self.to_rgb_128 = EqualConv2d(in_channel // 4, 3, 1)
        self.to_rgb_256 = EqualConv2d(in_channel // 4, 3, 1)

        self.max_step = 6

    def progress(self, feat, module):
        # Upscale 2x, then run the next stage's conv block.
        out = F.interpolate(feat, scale_factor=2, mode='bilinear', align_corners=False)
        out = module(out)
        return out

    def output(self, feat1, feat2, module1, module2, alpha):
        # During fade-in, blend the upscaled previous-stage RGB with the
        # current stage's RGB.
        if 0 <= alpha < 1:
            skip_rgb = upscale(module1(feat1))
            out = (1 - alpha) * skip_rgb + alpha * module2(feat2)
        else:
            out = module2(feat2)
        if self.tanh:
            return torch.tanh(out)
        return out

    def forward(self, input, step=0, alpha=-1):
        # Valid steps are 1..6; larger values are clamped to max_step.
        if step > self.max_step:
            step = self.max_step

        out_4 = self.input_layer(input.view(-1, self.input_dim, 1, 1))
        out_4 = self.progression_4(out_4)
        out_8 = self.progress(out_4, self.progression_8)
        if step == 1:
            if self.tanh:
                return torch.tanh(self.to_rgb_8(out_8))
            return self.to_rgb_8(out_8)

        out_16 = self.progress(out_8, self.progression_16)
        if step == 2:
            return self.output(out_8, out_16, self.to_rgb_8, self.to_rgb_16, alpha)

        out_32 = self.progress(out_16, self.progression_32)
        if step == 3:
            return self.output(out_16, out_32, self.to_rgb_16, self.to_rgb_32, alpha)

        out_64 = self.progress(out_32, self.progression_64)
        if step == 4:
            return self.output(out_32, out_64, self.to_rgb_32, self.to_rgb_64, alpha)

        out_128 = self.progress(out_64, self.progression_128)
        if step == 5:
            return self.output(out_64, out_128, self.to_rgb_64, self.to_rgb_128, alpha)

        out_256 = self.progress(out_128, self.progression_256)
        if step == 6:
            return self.output(out_128, out_256, self.to_rgb_128, self.to_rgb_256, alpha)


class Discriminator(nn.Module):
    """Progressive discriminator: mirrors the generator, consuming images at
    the resolution selected by `step` and fading in new stages via `alpha`.
    The final 4x4 block appends a minibatch-stddev feature map."""

    def __init__(self, feat_dim=128):
        super().__init__()
        self.progression = nn.ModuleList([
            ConvBlock(feat_dim // 4, feat_dim // 4, 3, 1),
            ConvBlock(feat_dim // 4, feat_dim // 2, 3, 1),
            ConvBlock(feat_dim // 2, feat_dim, 3, 1),
            ConvBlock(feat_dim, feat_dim, 3, 1),
            ConvBlock(feat_dim, feat_dim, 3, 1),
            ConvBlock(feat_dim, feat_dim, 3, 1),
            ConvBlock(feat_dim + 1, feat_dim, 3, 1, 4, 0)])

        self.from_rgb = nn.ModuleList([
            EqualConv2d(3, feat_dim // 4, 1),
            EqualConv2d(3, feat_dim // 4, 1),
            EqualConv2d(3, feat_dim // 2, 1),
            EqualConv2d(3, feat_dim, 1),
            EqualConv2d(3, feat_dim, 1),
            EqualConv2d(3, feat_dim, 1),
            EqualConv2d(3, feat_dim, 1)])

        self.n_layer = len(self.progression)

        self.linear = EqualLinear(feat_dim, 1)

    def forward(self, input, step=0, alpha=-1):
        # Walk from the entry stage for this resolution down to the 4x4 block.
        for i in range(step, -1, -1):
            index = self.n_layer - i - 1

            if i == step:
                out = self.from_rgb[index](input)

            if i == 0:
                # Minibatch standard deviation: append the batch's mean
                # feature stddev as one extra 4x4 channel.
                out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8)
                mean_std = out_std.mean()
                mean_std = mean_std.expand(out.size(0), 1, 4, 4)
                out = torch.cat([out, mean_std], 1)

            out = self.progression[index](out)

            if i > 0:
                # Downscale 2x (bilinear here; avg_pool2d is the usual alternative).
                # out = F.avg_pool2d(out, 2)
                out = F.interpolate(out, scale_factor=0.5, mode='bilinear', align_corners=False)

                if i == step and 0 <= alpha < 1:
                    # Fade-in: blend with the previous stage's from_rgb path.
                    # skip_rgb = F.avg_pool2d(input, 2)
                    skip_rgb = F.interpolate(input, scale_factor=0.5, mode='bilinear', align_corners=False)
                    skip_rgb = self.from_rgb[index + 1](skip_rgb)
                    out = (1 - alpha) * skip_rgb + alpha * out

        out = out.squeeze(2).squeeze(2)
        out = self.linear(out)

        return out
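

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original training code): one
# generator and one discriminator forward pass at the 32x32 stage (step=3),
# midway through fade-in (alpha=0.5). The batch size, latent dimension, and
# step/alpha values below are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    latent_dim = 128
    net_g = Generator(input_code_dim=latent_dim, in_channel=128)
    net_d = Discriminator(feat_dim=128)

    z = torch.randn(4, latent_dim)            # batch of 4 latent codes
    fake = net_g(z, step=3, alpha=0.5)        # step=3 -> 32x32 images
    print(fake.shape)                         # torch.Size([4, 3, 32, 32])

    score = net_d(fake, step=3, alpha=0.5)    # one realism score per image
    print(score.shape)                        # torch.Size([4, 1])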