import torch
import torch.nn as nn

from models.module import Conv2d, StyleAttentionBlock

_ENCODER_CHANNEL_DEFAULT = 256


class Encoder(nn.Module):
    """Base encoder: applies each block in ``self.module`` in order.

    Subclasses are expected to populate ``self.module`` with their own
    layer stack, or override ``forward`` entirely.
    """

    def __init__(self, hp, in_channels=1, out_channels=_ENCODER_CHANNEL_DEFAULT):
        super().__init__()
        self.hp = hp
        self.module = nn.ModuleList()

    def forward(self, x):
        for block in self.module:
            x = block(x)
        return x
class ContentVanillaEncoder(Encoder):
    """Plain convolutional content encoder.

    A 7x7 reflect-padded stem is followed by ``depth`` stride-2
    convolutions; the channel width doubles at every downsampling step
    until it reaches ``out_channels``.
    """

    def __init__(self, hp, in_channels, out_channels):
        super().__init__(hp, in_channels, out_channels)
        self.depth = hp.encoder.content.depth
        assert out_channels // (2 ** self.depth) >= in_channels * 2, \
            "out_channels is too small for this depth; increase it"
        self.module = nn.ModuleList()
        # Stem: map the input to the narrowest channel width.
        self.module.append(
            Conv2d(in_channels, out_channels // (2 ** self.depth),
                   kernel_size=7, padding=3, padding_mode='reflect', bias=False)
        )
        # Downsampling stages: halve the spatial size, double the channels.
        for layer_idx in range(1, self.depth + 1):
            self.module.append(
                Conv2d(out_channels // (2 ** (self.depth - layer_idx + 1)),
                       out_channels // (2 ** (self.depth - layer_idx)),
                       kernel_size=3, stride=2, padding=1, bias=False)
            )
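
# Worked example (illustrative, not from the repo config): with depth=2 and
# out_channels=256, a (B, 1, H, W) input flows through channel widths
# 1 -> 64 -> 128 -> 256 while the spatial size shrinks H/1 -> H/2 -> H/4,
# assuming the Conv2d wrapper in models.module keeps standard conv shape
# semantics.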
class StyleVanillaEncoder(Encoder):
    """Style encoder: a shared convolutional stack applied to each of the
    K style reference images, followed by a style attention block that
    aggregates the per-reference features.
    """

    def __init__(self, hp, in_channels, out_channels):
        super().__init__(hp, in_channels, out_channels)
        self.depth = hp.encoder.style.depth
        assert out_channels // (2 ** self.depth) >= in_channels * 2, \
            "out_channels is too small for this depth; increase it"
        encoder_module = [
            # Stem: 7x7 reflect-padded convolution into the narrowest width.
            Conv2d(in_channels, out_channels // (2 ** self.depth),
                   kernel_size=7, padding=3, padding_mode='reflect', bias=False)
        ]
        # Downsampling stages: halve the spatial size, double the channels.
        for layer_idx in range(1, self.depth + 1):
            encoder_module.append(
                Conv2d(out_channels // (2 ** (self.depth - layer_idx + 1)),
                       out_channels // (2 ** (self.depth - layer_idx)),
                       kernel_size=3, stride=2, padding=1, bias=False)
            )
        # Attribute assignment registers submodules, same as add_module().
        self.encoder_module = nn.Sequential(*encoder_module)
        self.attention_module = StyleAttentionBlock(out_channels)

    def forward(self, x):
        # x: (B, K, H, W) -- K single-channel style references per sample.
        B, K, H, W = x.size()
        # Fold the K references into the batch so the shared conv stack
        # sees a (B * K, 1, H, W) tensor.
        out = self.encoder_module(x.view(-1, 1, H, W))
        # The attention block regroups the B * K feature maps and pools
        # over the K references.
        out = self.attention_module(out, B, K)
        return out
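
# Minimal smoke-test sketch (not part of the original file). The ``hp``
# namespace below is a stand-in for the project's real hyperparameter
# object, and the expected shapes assume Conv2d / StyleAttentionBlock from
# models.module follow standard conv / attention shape semantics.
if __name__ == "__main__":
    from types import SimpleNamespace

    hp = SimpleNamespace(encoder=SimpleNamespace(
        content=SimpleNamespace(depth=2),
        style=SimpleNamespace(depth=2),
    ))

    content_enc = ContentVanillaEncoder(hp, in_channels=1, out_channels=256)
    style_enc = StyleVanillaEncoder(hp, in_channels=1, out_channels=256)

    content = torch.randn(4, 1, 64, 64)  # one content glyph per sample
    styles = torch.randn(4, 3, 64, 64)   # K=3 style references per sample

    print(content_enc(content).shape)  # expected (4, 256, 16, 16) at depth=2
    print(style_enc(styles).shape)     # shape depends on StyleAttentionBlock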