import math import torch from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module from models.hyperstyle.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE from models.stylegan2.model import EqualLinear class WEncoder(Module): def __init__(self, num_layers, mode='ir', opts=None): super(WEncoder, self).__init__() print('Using WEncoder') assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' blocks = get_blocks(num_layers) if mode == 'ir': unit_module = bottleneck_IR elif mode == 'ir_se': unit_module = bottleneck_IR_SE self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)) self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1)) self.linear = EqualLinear(512, 512, lr_mul=1) modules = [] for block in blocks: for bottleneck in block: modules.append(unit_module(bottleneck.in_channel, bottleneck.depth, bottleneck.stride)) self.body = Sequential(*modules) log_size = int(math.log(opts.output_size, 2)) self.style_count = 2 * log_size - 2 def forward(self, x): x = self.input_layer(x) x = self.body(x) x = self.output_pool(x) x = x.view(-1, 512) x = self.linear(x) return x.repeat(self.style_count, 1, 1).permute(1, 0, 2)