Spaces:
Runtime error
Runtime error
File size: 1,671 Bytes
92ec8d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import math
import torch
from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module
from models.hyperstyle.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE
from models.stylegan2.model import EqualLinear
class WEncoder(Module):
    """Encode an image batch into one W-space vector repeated for every style input.

    The input goes through a Conv/BN/PReLU stem, then an IR or IR-SE residual
    body built from ``get_blocks(num_layers)``. It is then globally
    average-pooled, projected by ``EqualLinear(512, 512)``, and tiled
    ``style_count`` times along a new style axis.

    Args:
        num_layers: Backbone depth forwarded to ``get_blocks``; one of 50, 100, 152.
        mode: ``'ir'`` selects ``bottleneck_IR`` units, ``'ir_se'`` selects
            ``bottleneck_IR_SE`` units.
        opts: Options object. Only ``opts.output_size`` is read here, and it is
            required even though the default is ``None``. The default is kept
            so the signature stays compatible.

    Raises:
        AssertionError: if ``num_layers`` or ``mode`` is unsupported.
        ValueError: if ``opts`` is ``None``.
    """

    def __init__(self, num_layers, mode='ir', opts=None):
        super(WEncoder, self).__init__()
        print('Using WEncoder')
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        # Fail fast. opts.output_size is needed below, and a missing opts would
        # otherwise surface as an AttributeError after the backbone is built.
        if opts is None:
            raise ValueError('WEncoder requires opts with an output_size attribute')
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = EqualLinear(512, 512, lr_mul=1)
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)
        self.style_count = self._compute_style_count(opts.output_size)

    @staticmethod
    def _compute_style_count(output_size):
        """Return ``2 * log2(output_size) - 2``, the number of style slots to emit.

        Uses ``math.log2`` instead of ``math.log(x, 2)``. The two-argument form
        divides two rounded logarithms, and ``int()`` would truncate a result
        that lands just below an integer. Assumes ``output_size`` is a power of
        two; TODO confirm against callers.
        """
        log_size = int(math.log2(output_size))
        return 2 * log_size - 2

    def forward(self, x):
        """Map an image batch to a tiled W latent.

        Args:
            x: Batch fed to a ``Conv2d(3, 64, ...)`` stem, so presumably
                ``(B, 3, H, W)``. TODO confirm the expected H and W with callers.

        Returns:
            Tensor of shape ``(B, style_count, D)``, where every style slot holds
            the same projected vector. D is presumably 512, given
            ``EqualLinear(512, 512)``.
        """
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_pool(x)
        # Pooling leaves (B, C, 1, 1). This assumes the body's last stage has 512 channels.
        x = x.view(-1, 512)
        x = self.linear(x)
        # Add a style axis and tile along it. The values match
        # repeat(style_count, 1, 1).permute(1, 0, 2), but the result is contiguous.
        return x.unsqueeze(1).repeat(1, self.style_count, 1)
|