deepkyu's picture
initial commit
1ba3df3
raw
history blame
No virus
1.53 kB
import torch
import torch.nn as nn
from models.module import ResidualBlocks
# Default value of ``in_channels`` for decoder constructors.
_DECODER_CHANNEL_DEFAULT = 512


class Decoder(nn.Module):
    """Base decoder that runs its registered layers one after another.

    Subclasses fill ``self.module``; this base class leaves it empty, so on its
    own the decoder returns its input unchanged.

    Args:
        hp: Hyper-parameter object. Not read here; accepted so that subclasses
            share one constructor signature.
        in_channels: Not read here (defaults to ``_DECODER_CHANNEL_DEFAULT``).
        out_channels: Not read here (defaults to 1).
    """

    def __init__(self, hp, in_channels=_DECODER_CHANNEL_DEFAULT, out_channels=1):
        super().__init__()
        # A ModuleList (rather than a plain list) registers appended layers as
        # submodules so their parameters are tracked by the parent module.
        self.module = nn.ModuleList()

    def forward(self, x):
        """Pass ``x`` through every layer in ``self.module``, in insertion order."""
        out = x
        for layer in self.module:
            out = layer(out)
        return out
class VanillaDecoder(Decoder):
    """Decoder made of an optional residual stack, upsampling stages and a tanh head.

    Layers are appended to ``self.module`` in this order and applied
    sequentially by ``Decoder.forward``:

    1. ``ResidualBlocks(in_channels, n_blocks=hp.decoder.residual_blocks)``,
       only when that count is greater than 0.
    2. ``hp.decoder.depth`` stages of
       ConvTranspose2d(k=3, stride=2, padding=1, output_padding=1, no bias)
       -> BatchNorm2d -> ReLU(inplace). Each stage halves the channel count
       (integer division). By the ConvTranspose2d output-size formula it
       doubles height and width.
    3. A 7x7 Conv2d with reflect padding 3 (spatial size preserved) mapping to
       ``out_channels``, followed by Tanh, so outputs lie in [-1, 1].

    NOTE(review): input is presumably (N, in_channels, H, W), and
    ``ResidualBlocks`` is assumed to keep the channel count at
    ``in_channels``. TODO confirm against models.module.

    Args:
        hp: Config object exposing ``hp.decoder.depth`` and
            ``hp.decoder.residual_blocks``.
        in_channels: Channels of the incoming feature map. Defaults to
            ``_DECODER_CHANNEL_DEFAULT``, matching the base class.
        out_channels: Channels of the produced output. Defaults to 1.

    Raises:
        ValueError: If ``hp.decoder.depth`` is negative, or if halving
            ``in_channels`` ``depth`` times leaves fewer than one channel.
    """

    def __init__(self, hp, in_channels=_DECODER_CHANNEL_DEFAULT, out_channels=1):
        # The base constructor creates the empty ``self.module`` ModuleList.
        super().__init__(hp, in_channels, out_channels)
        self.depth = hp.decoder.depth
        self.blocks = hp.decoder.residual_blocks

        # Fail early with a clear message instead of building float-sized
        # (negative depth) or zero-channel (depth too large) layers.
        if self.depth < 0:
            raise ValueError(f"hp.decoder.depth must be >= 0, got {self.depth}")
        final_channels = in_channels // (2 ** self.depth)
        if final_channels < 1:
            raise ValueError(
                f"in_channels={in_channels} is too small for depth={self.depth}: "
                f"halving it {self.depth} times leaves {final_channels} channels"
            )

        if self.blocks > 0:
            self.module.append(ResidualBlocks(in_channels, n_blocks=self.blocks))

        # Upsampling stages: each halves channels and doubles H and W.
        channels = in_channels
        for _ in range(self.depth):
            next_channels = channels // 2
            self.module.append(nn.Sequential(
                nn.ConvTranspose2d(channels, next_channels,
                                   kernel_size=3, stride=2,
                                   padding=1, output_padding=1,
                                   bias=False),
                nn.BatchNorm2d(next_channels),
                nn.ReLU(True),
            ))
            channels = next_channels

        # Output head: size-preserving 7x7 conv (reflect padding) + tanh.
        final = nn.Sequential(
            nn.Conv2d(channels, out_channels, kernel_size=7, padding=3,
                      padding_mode='reflect'),
            nn.Tanh(),
        )
        self.module.append(final)