import torch

from modules.base import BaseModule
from modules.linear_modulation import FeatureWiseAffine
from modules.interpolation import InterpolationBlock
from modules.layers import Conv1dWithInitialization


class BasicModulationBlock(BaseModule):
    """
    Linear-modulation stage of the UBlock: FiLM, then activation, then a 3-tap conv.

    ``forward`` applies, in order:
        - ``FeatureWiseAffine`` driven by the externally supplied ``scale`` / ``shift``
        - ``LeakyReLU`` with negative slope 0.2
        - ``Conv1dWithInitialization`` with kernel size 3, stride 1 and
          ``padding == dilation``, mapping ``n_channels`` -> ``n_channels``

    Args:
        n_channels: number of input and output channels of the convolution.
        dilation: dilation of the convolution; also used as its padding.
    """

    def __init__(self, n_channels, dilation):
        super().__init__()
        # Creation order is deliberate: it fixes submodule registration order
        # (state_dict layout) and, if any submodule initialises its weights
        # randomly, the order in which those random draws happen.
        self.featurewise_affine = FeatureWiseAffine()
        self.leaky_relu = torch.nn.LeakyReLU(0.2)
        self.convolution = Conv1dWithInitialization(
            in_channels=n_channels,
            out_channels=n_channels,
            kernel_size=3,
            stride=1,
            # With a 3-tap kernel, padding == dilation keeps the time length
            # unchanged, assuming Conv1dWithInitialization follows
            # torch.nn.Conv1d semantics -- TODO confirm.
            padding=dilation,
            dilation=dilation,
        )

    def forward(self, x, scale, shift):
        """Modulate ``x`` with ``scale``/``shift``, apply LeakyReLU, then convolve."""
        modulated = self.featurewise_affine(x, scale, shift)
        return self.convolution(self.leaky_relu(modulated))


class UpsamplingBlock(BaseModule):
    """
    Upsampling UBlock made of two residual blocks.

    First residual block (maps ``in_channels`` -> ``out_channels``):
        main: LeakyReLU(0.2) -> InterpolationBlock(scale_factor=factor, 'linear')
              -> 3-tap conv with dilation ``dilations[0]`` -> LeakyReLU(0.2)
              -> BasicModulationBlock with dilation ``dilations[1]``
        skip: InterpolationBlock(scale_factor=factor, 'linear') -> 1x1 conv

    Second residual block (channel count stays ``out_channels``):
        main: BasicModulationBlock with dilation ``dilations[2]``
              -> BasicModulationBlock with dilation ``dilations[3]``
        skip: identity

    Args:
        in_channels: channel count of the input ``x``.
        out_channels: channel count produced by both residual blocks.
        factor: ``scale_factor`` given to both ``InterpolationBlock`` instances
            (presumably the time-axis upsampling ratio -- confirm).
        dilations: indexed at positions 0..3, so it needs at least four
            entries; one dilation per 3-tap convolution listed above.
    """

    def __init__(self, in_channels, out_channels, factor, dilations):
        super().__init__()

        # --- First residual block, main branch -------------------------------
        # Everything is constructed in the same order as the registered
        # layout below so that weight-init draws happen in a fixed sequence.
        upsampling_path = torch.nn.Sequential(
            torch.nn.LeakyReLU(0.2),
            InterpolationBlock(
                scale_factor=factor,
                mode='linear',
                align_corners=False,
            ),
            Conv1dWithInitialization(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=dilations[0],
                dilation=dilations[0],
            ),
            # NOTE(review): this activation runs right before the modulation
            # block's FiLM step, so the branch is conv -> LReLU -> FiLM ->
            # LReLU -> conv. Kept as-is; confirm it is intended.
            torch.nn.LeakyReLU(0.2),
        )
        first_modulation = BasicModulationBlock(out_channels, dilation=dilations[1])
        self.first_block_main_branch = torch.nn.ModuleDict({
            'upsampling': upsampling_path,
            'modulation': first_modulation,
        })

        # --- First residual block, skip branch -------------------------------
        # Same interpolation as the main branch plus a 1x1 conv so the skip
        # output has ``out_channels`` channels and can be added to it.
        self.first_block_residual_branch = torch.nn.Sequential(
            InterpolationBlock(
                scale_factor=factor,
                mode='linear',
                align_corners=False,
            ),
            Conv1dWithInitialization(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
            ),
        )

        # --- Second residual block, main branch (identity skip) --------------
        self.second_block_main_branch = torch.nn.ModuleDict({
            'modulation_0': BasicModulationBlock(out_channels, dilation=dilations[2]),
            'modulation_1': BasicModulationBlock(out_channels, dilation=dilations[3]),
        })

    def forward(self, x, scale, shift):
        """
        Run both residual blocks and return the result.

        The same ``scale``/``shift`` pair is fed to all three modulation stages.

        NOTE(review): ``x`` is presumably (batch, in_channels, time), and
        ``scale``/``shift`` must be broadcast-compatible with the upsampled
        activations -- confirm against FeatureWiseAffine and the caller.
        """
        first_main = self.first_block_main_branch
        second_main = self.second_block_main_branch

        # First (upsampling) residual block: main branch, then skip branch.
        upsampled = first_main['upsampling'](x)
        hidden = (
            first_main['modulation'](upsampled, scale, shift)
            + self.first_block_residual_branch(x)
        )

        # Second residual block: two chained modulation stages over an
        # identity skip connection.
        refined = second_main['modulation_0'](hidden, scale, shift)
        refined = second_main['modulation_1'](refined, scale, shift)
        return hidden + refined