|
|
|
|
|
"""
|
|
SPADE decoder (G) defined in the paper, which takes the warped feature as input and generates the animated image.
|
|
"""
|
|
|
|
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
from .util import SPADEResnetBlock
|
|
|
|
|
|
class SPADEDecoder(nn.Module):
    """Decoder that maps a warped feature map to a 3-channel image.

    The input feature is first widened by a 3x3 convolution, passed through
    six ``SPADEResnetBlock`` layers, upsampled twice (each followed by another
    ``SPADEResnetBlock``), and finally projected to 3 channels and squashed
    with a sigmoid. Every ``SPADEResnetBlock`` receives the original input
    feature as its second argument.
    """

    def __init__(self, upscale=1, max_features=256, block_expansion=64, out_channels=64, num_down_blocks=2):
        """Build the decoder layers.

        Args:
            upscale: When ``None`` or ``<= 1`` the output head is a plain
                3x3 convolution. Any larger value selects a convolution
                followed by ``nn.PixelShuffle``.
                NOTE: the pixel-shuffle factor is fixed at 2 regardless of
                the actual value passed here.
            max_features: Upper bound on the number of input channels.
            block_expansion: Base channel width; the input width is
                ``block_expansion * 2 ** num_down_blocks`` capped at
                ``max_features``.
            out_channels: Channel count fed into the final image head.
            num_down_blocks: Exponent used to derive the input channel count.
                Must be at least 1.

        Raises:
            ValueError: If ``num_down_blocks`` is smaller than 1.
        """
        # Initialise nn.Module before touching any attribute on ``self``.
        super().__init__()

        # Previously a non-positive value left ``input_channels`` unbound and
        # failed with an opaque UnboundLocalError; fail early and clearly.
        if num_down_blocks < 1:
            raise ValueError(
                f"num_down_blocks must be >= 1, got {num_down_blocks!r}"
            )

        self.upscale = upscale

        # Closed form of the value the last loop iteration used to produce.
        input_channels = min(max_features, block_expansion * (2 ** num_down_blocks))

        norm_G = 'spadespectralinstance'
        label_num_channels = input_channels
        hidden_channels = 2 * input_channels

        # Attribute names and creation order are kept as-is so that existing
        # state_dict keys and weight-initialisation order do not change.
        self.fc = nn.Conv2d(input_channels, hidden_channels, 3, padding=1)
        self.G_middle_0 = SPADEResnetBlock(hidden_channels, hidden_channels, norm_G, label_num_channels)
        self.G_middle_1 = SPADEResnetBlock(hidden_channels, hidden_channels, norm_G, label_num_channels)
        self.G_middle_2 = SPADEResnetBlock(hidden_channels, hidden_channels, norm_G, label_num_channels)
        self.G_middle_3 = SPADEResnetBlock(hidden_channels, hidden_channels, norm_G, label_num_channels)
        self.G_middle_4 = SPADEResnetBlock(hidden_channels, hidden_channels, norm_G, label_num_channels)
        self.G_middle_5 = SPADEResnetBlock(hidden_channels, hidden_channels, norm_G, label_num_channels)
        self.up_0 = SPADEResnetBlock(hidden_channels, input_channels, norm_G, label_num_channels)
        self.up_1 = SPADEResnetBlock(input_channels, out_channels, norm_G, label_num_channels)
        self.up = nn.Upsample(scale_factor=2)

        if self.upscale is None or self.upscale <= 1:
            self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1)
        else:
            # 3 * (2 * 2) channels are rearranged by PixelShuffle(2) into
            # 3 channels at twice the spatial resolution.
            self.conv_img = nn.Sequential(
                nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1),
                nn.PixelShuffle(upscale_factor=2)
            )

    def forward(self, feature):
        """Decode ``feature`` into an image tensor.

        Args:
            feature: Input feature map. Its channel count must match the
                ``input_channels`` derived in ``__init__`` (256 with the
                default arguments), since it is fed straight into ``self.fc``.

        Returns:
            A 3-channel tensor with values in (0, 1) from the final sigmoid.
            The two ``nn.Upsample(scale_factor=2)`` passes enlarge the map
            by 4x, and the pixel-shuffle head (when enabled) by a further 2x
            -- assuming the SPADE blocks from ``.util`` preserve spatial
            size; TODO confirm.
        """
        # The untouched input is passed to every SPADE block as its second
        # argument.
        seg = feature

        x = self.fc(feature)
        x = self.G_middle_0(x, seg)
        x = self.G_middle_1(x, seg)
        x = self.G_middle_2(x, seg)
        x = self.G_middle_3(x, seg)
        x = self.G_middle_4(x, seg)
        x = self.G_middle_5(x, seg)

        # The same parameter-free upsampling layer is reused for both stages.
        x = self.up(x)
        x = self.up_0(x, seg)
        x = self.up(x)
        x = self.up_1(x, seg)

        # Leaky ReLU with negative slope 0.2 before the image head.
        x = self.conv_img(F.leaky_relu(x, 2e-1))
        x = torch.sigmoid(x)

        return x