import numpy as np from torch import nn from torch.nn import Conv2d, Sequential, Module from models.hyperstyle.encoders.helpers import SeparableBlock from models.stylegan2.model import EqualLinear # layer_idx: [kernel_size, in_channels, out_channels] PARAMETERS = { 0: [3, 512, 512], 1: [1, 512, 3], 2: [3, 512, 512], 3: [3, 512, 512], 4: [1, 512, 3], 5: [3, 512, 512], 6: [3, 512, 512], 7: [1, 512, 3], 8: [3, 512, 512], 9: [3, 512, 512], 10: [1, 512, 3], 11: [3, 512, 512], 12: [3, 512, 512], 13: [1, 512, 3], 14: [3, 512, 256], 15: [3, 256, 256], 16: [1, 256, 3], 17: [3, 256, 128], 18: [3, 128, 128], 19: [1, 128, 3], 20: [3, 128, 64], 21: [3, 64, 64], 22: [1, 64, 3], 23: [3, 64, 32], 24: [3, 32, 32], 25: [1, 32, 3] } TO_RGB_LAYERS = [1, 4, 7, 10, 13, 16, 19, 22, 25] class RefinementBlock(Module): def __init__(self, layer_idx, opts, n_channels=512, inner_c=256, spatial=16): super(RefinementBlock, self).__init__() self.layer_idx = layer_idx self.opts = opts self.kernel_size, self.in_channels, self.out_channels = PARAMETERS[self.layer_idx] self.spatial = spatial self.n_channels = n_channels self.inner_c = inner_c self.out_c = 512 num_pools = int(np.log2(self.spatial)) - 1 if self.kernel_size == 3: num_pools = num_pools - 1 self.modules = [] self.modules += [Conv2d(self.n_channels, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] for i in range(num_pools - 1): self.modules += [Conv2d(self.inner_c, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] self.modules += [Conv2d(self.inner_c, self.out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] self.convs = nn.Sequential(*self.modules) if layer_idx in TO_RGB_LAYERS: self.output = Sequential( Conv2d(self.out_c, self.in_channels * self.out_channels, kernel_size=1, stride=1, padding=0)) else: self.output = Sequential(nn.AdaptiveAvgPool2d((1, 1)), Conv2d(self.out_c, self.in_channels * self.out_channels, kernel_size=1, stride=1, padding=0)) def forward(self, x): x = self.convs(x) x = self.output(x) if self.layer_idx in TO_RGB_LAYERS: x = x.view(-1, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size) else: x = x.view(-1, self.out_channels, self.in_channels) x = x.unsqueeze(3).repeat(1, 1, 1, self.kernel_size).unsqueeze(4).repeat(1, 1, 1, 1, self.kernel_size) return x class HyperRefinementBlock(Module): def __init__(self, hypernet, n_channels=512, inner_c=128, spatial=16): super(HyperRefinementBlock, self).__init__() self.n_channels = n_channels self.inner_c = inner_c self.out_c = 512 num_pools = int(np.log2(spatial)) modules = [Conv2d(self.n_channels, self.inner_c, kernel_size=3, stride=1, padding=1), nn.LeakyReLU()] for i in range(num_pools - 1): modules += [Conv2d(self.inner_c, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] modules += [Conv2d(self.inner_c, self.out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] self.convs = nn.Sequential(*modules) self.linear = EqualLinear(self.out_c, self.out_c, lr_mul=1) self.hypernet = hypernet def forward(self, features): code = self.convs(features) code = code.view(-1, self.out_c) code = self.linear(code) weight_delta = self.hypernet(code) return weight_delta class RefinementBlockSeparable(Module): def __init__(self, layer_idx, opts, n_channels=512, inner_c=256, spatial=16): super(RefinementBlockSeparable, self).__init__() self.layer_idx = layer_idx self.kernel_size, self.in_channels, self.out_channels = PARAMETERS[self.layer_idx] self.spatial = spatial self.n_channels = n_channels self.inner_c = inner_c self.out_c = 512 num_pools = int(np.log2(self.spatial)) - 1 self.modules = [] self.modules += [Conv2d(self.n_channels, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] for i in range(num_pools - 1): self.modules += [Conv2d(self.inner_c, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] self.modules += [Conv2d(self.inner_c, self.out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] self.convs = nn.Sequential(*self.modules) self.opts = opts if self.layer_idx in TO_RGB_LAYERS: self.output = Sequential(Conv2d(self.out_c, self.in_channels * self.out_channels, kernel_size=1, stride=1, padding=0)) else: self.output = Sequential(SeparableBlock(input_size=self.out_c, kernel_channels_in=self.in_channels, kernel_channels_out=self.out_channels, kernel_size=self.kernel_size)) def forward(self, x): x = self.convs(x) x = self.output(x) if self.layer_idx in TO_RGB_LAYERS: x = x.view(-1, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size) return x