import torch.nn as nn
import torch.nn.functional as F


def _resample_layers(down_sample, up_sample):
    """Return the optional 2x resampling layer shared by main and side paths.

    Args:
        down_sample: if True, halve H and W with 2x2 average pooling.
        up_sample: if True (and ``down_sample`` is False), double H and W with
            bilinear interpolation.

    Returns:
        A list holding zero or one layer; ``down_sample`` wins when both flags
        are set.
    """
    if down_sample:
        return [nn.AvgPool2d(kernel_size=2)]
    if up_sample:
        # align_corners=False is the bilinear default; spelled out so the
        # behavior is explicit rather than implied.
        return [nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)]
    return []


def _norm_act_conv(in_channel, out_channel, norm, inplace_act=True):
    """Return the pre-activation unit ``[InstanceNorm2d] -> LeakyReLU(0.2) -> Conv3x3``.

    Args:
        in_channel: channels of the incoming tensor.
        out_channel: channels produced by the 3x3 convolution.
        norm: prepend an ``InstanceNorm2d`` when True.
        inplace_act: whether the LeakyReLU may overwrite its input tensor.

    Returns:
        A list of layers suitable for splicing into ``nn.Sequential``.
    """
    layers = [nn.InstanceNorm2d(in_channel)] if norm else []
    layers.append(nn.LeakyReLU(0.2, inplace=inplace_act))
    layers.append(nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1))
    return layers


class ResBlock(nn.Module):
    """Pre-activation residual block with optional 2x down/up-sampling.

    main path: [IN] -> LReLU -> Conv3x3 -> [resample] -> [IN] -> LReLU -> Conv3x3
    side path: Conv1x1 -> [resample]
    The output is the element-wise sum of the two paths.

    Args:
        in_channel: input channels.
        out_channel: output channels.
        down_sample: halve spatial size with average pooling (takes
            precedence over ``up_sample``).
        up_sample: double spatial size with bilinear interpolation.
        norm: insert ``InstanceNorm2d`` before each main-path activation.
    """

    def __init__(self, in_channel, out_channel, down_sample=False, up_sample=False, norm=True):
        super().__init__()
        # The layer order (and therefore the state_dict keys) is exactly the
        # same as the original hand-written lists.
        main_module_list = _norm_act_conv(
            in_channel,
            out_channel,
            norm,
            # Without a preceding InstanceNorm, this activation would run
            # directly on the block input ``x``. ``side_path`` also reads ``x``
            # (and the caller may still own it), so an in-place op here would
            # corrupt the shortcut and fail on leaf tensors that require grad.
            inplace_act=norm,
        )
        main_module_list += _resample_layers(down_sample, up_sample)
        # In-place is safe here: the activation follows a layer (Conv, pool,
        # upsample or InstanceNorm) that allocated a new tensor.
        main_module_list += _norm_act_conv(out_channel, out_channel, norm)
        self.main_path = nn.Sequential(*main_module_list)

        side_module_list = [nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0)]
        side_module_list += _resample_layers(down_sample, up_sample)
        self.side_path = nn.Sequential(*side_module_list)

    def forward(self, x):
        """Return ``main_path(x) + side_path(x)``; ``x`` itself is left unmodified."""
        x1 = self.main_path(x)
        x2 = self.side_path(x)
        return x1 + x2


class AdaIn(nn.Module):
    """Adaptive instance normalization driven by a conditioning vector.

    ``x`` is instance-normalized, then scaled and shifted per channel by two
    linear projections of ``style_vector``.

    Args:
        in_channel: number of channels of ``x``.
        vector_size: size of the last dimension of ``style_vector``.
    """

    def __init__(self, in_channel, vector_size):
        super().__init__()
        self.eps = 1e-5
        self.std_style_fc = nn.Linear(vector_size, in_channel)
        self.mean_style_fc = nn.Linear(vector_size, in_channel)

    def forward(self, x, style_vector):
        """Return ``scale * instance_norm(x) + bias`` with scale/bias from ``style_vector``.

        Assumes ``x`` is (B, C, H, W) and ``style_vector`` is (B, vector_size),
        so the projected (B, C, 1, 1) scale/bias broadcast over H and W.
        TODO: confirm no caller passes other ranks.
        """
        std_style = self.std_style_fc(style_vector).unsqueeze(-1).unsqueeze(-1)
        mean_style = self.mean_style_fc(style_vector).unsqueeze(-1).unsqueeze(-1)
        # Use the stored eps, so changing self.eps actually takes effect.
        # 1e-5 matches F.instance_norm's default, so numerics are unchanged.
        x = F.instance_norm(x, eps=self.eps)
        return std_style * x + mean_style


class AdaInResBlock(nn.Module):
    """Residual block whose main path is modulated by two ``AdaIn`` layers.

    main path: AdaIn -> LReLU -> Conv3x3 -> [2x upsample] -> AdaIn -> LReLU -> Conv3x3
    side path: Conv1x1 -> [2x upsample]

    Args:
        in_channel: input channels.
        out_channel: output channels.
        up_sample: double spatial size with bilinear interpolation.
        vector_size: size of the last dimension of the conditioning vector
            passed to ``forward``. Defaults to 257 + 512.
    """

    def __init__(self, in_channel, out_channel, up_sample=False, *, vector_size=257 + 512):
        super().__init__()
        self.vector_size = vector_size
        self.up_sample = up_sample
        self.adain1 = AdaIn(in_channel, self.vector_size)
        self.adain2 = AdaIn(out_channel, self.vector_size)
        # In-place activations are safe: each one runs on an AdaIn output,
        # which is a freshly allocated tensor and never the block input ``x``.
        self.main_path1 = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
            *_resample_layers(False, up_sample),
        )
        self.main_path2 = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
        )
        self.side_path = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0),
            *_resample_layers(False, up_sample),
        )

    def forward(self, x, id_vector):
        """Return the AdaIn-modulated main path plus the 1x1 shortcut.

        ``id_vector`` conditions both AdaIn layers.
        """
        x1 = self.main_path1(self.adain1(x, id_vector))
        x2 = self.side_path(x)
        x1 = self.main_path2(self.adain2(x1, id_vector))
        return x1 + x2


class UpSamplingBlock(nn.Module):
    """Upsample 256-channel features 4x and project them to two output heads.

    ``forward`` returns ``(i_r, m_r)``:
    - ``i_r`` has 3 channels (LeakyReLU -> Conv3x3).
    - ``m_r`` has 1 channel, squashed to [0, 1] by a Sigmoid.
    Both heads read the same upsampled feature map.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            ResBlock(256, 256, up_sample=True),
            ResBlock(256, 256, up_sample=True),
        )
        # Not in-place: forward() feeds the same features to m_r_net after
        # this head runs. An in-place activation would silently hand m_r_net
        # LeakyReLU-activated features, and the result would depend on the
        # order of the two head calls.
        # NOTE(review): weights trained with the in-place variant saw
        # activated inputs in m_r_net; re-validate such checkpoints.
        self.i_r_net = nn.Sequential(nn.LeakyReLU(0.2, inplace=False), nn.Conv2d(256, 3, 3, 1, 1))
        self.m_r_net = nn.Sequential(nn.Conv2d(256, 1, 3, 1, 1), nn.Sigmoid())

    def forward(self, x):
        """Return ``(i_r, m_r)`` computed from ``self.net(x)``."""
        x = self.net(x)
        i_r = self.i_r_net(x)
        m_r = self.m_r_net(x)
        return i_r, m_r