ethanNeuralImage's picture
models
92ec8d3
raw
history blame
No virus
1.67 kB
import math
import torch
from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module
from models.hyperstyle.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE
from models.stylegan2.model import EqualLinear
class WEncoder(Module):
    """ResNet-IR encoder that maps an image to one 512-d code, repeated per style slot.

    The input goes through a 3->64 stem convolution, a stack of IR / IR-SE
    bottleneck units built from ``get_blocks(num_layers)``, global average
    pooling and an ``EqualLinear(512, 512)`` projection. The resulting single
    512-d vector is repeated ``style_count`` times, so every style slot
    receives the same code.

    Args:
        num_layers: Backbone depth passed to ``get_blocks``; one of 50, 100, 152.
        mode: ``'ir'`` selects ``bottleneck_IR`` units, ``'ir_se'`` selects
            ``bottleneck_IR_SE`` units.
        opts: Options object that must provide ``output_size``; the number of
            repeated codes is ``2 * floor(log2(output_size)) - 2``. Despite the
            ``None`` default this argument is required.

    Raises:
        AssertionError: if ``num_layers`` or ``mode`` is unsupported.
        ValueError: if ``opts`` is missing or has no ``output_size``.
    """

    def __init__(self, num_layers, mode='ir', opts=None):
        super(WEncoder, self).__init__()
        print('Using WEncoder')
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        # Validate opts before building anything: with the None default the
        # old code only failed after constructing the whole backbone, and with
        # an opaque AttributeError.
        output_size = getattr(opts, 'output_size', None)
        if output_size is None:
            raise ValueError('WEncoder requires opts with an output_size attribute')
        self.style_count = self._style_count(output_size)

        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        # Submodules are registered in the original order (input_layer,
        # output_pool, linear, body) so parameters() ordering is unchanged.
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = EqualLinear(512, 512, lr_mul=1)
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)

    @staticmethod
    def _style_count(output_size):
        """Return ``2 * floor(log2(output_size)) - 2``, the number of repeated codes.

        ``math.log2`` is used rather than ``math.log(x, 2)`` because it is exact
        for powers of two, so ``int()`` cannot truncate a result like 9.999...
        down to 9. Presumably this matches the style-input count of a StyleGAN2
        generator at resolution ``output_size`` (TODO confirm against the generator).
        """
        return 2 * int(math.log2(output_size)) - 2

    def forward(self, x):
        """Encode ``x`` and return the same 512-d code for every style slot.

        Args:
            x: Image batch with 3 channels (fixed by the stem ``Conv2d``);
                presumably shaped (B, 3, H, W). The spatial size is not fixed
                here (TODO confirm the expected resolution).

        Returns:
            Contiguous tensor of shape (B, style_count, 512).
        """
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_pool(x)  # spatial dims pooled to 1x1
        x = x.view(-1, 512)  # body must end with 512 channels for this and self.linear
        x = self.linear(x)
        # unsqueeze + repeat gives (B, style_count, 512) directly and contiguously;
        # repeat(...).permute(1, 0, 2) held the same values in a non-contiguous
        # view, which breaks later .view() calls.
        return x.unsqueeze(1).repeat(1, self.style_count, 1)