# HiFiFace-inference-demo / models/model_blocks.py
# xuehongyang
# ser
# 83d8d3c
import torch.nn as nn
import torch.nn.functional as F
class ResBlock(nn.Module):
    """Pre-activation residual block with optional 2x down- or up-sampling.

    Main path:
        [InstanceNorm] -> LeakyReLU(0.2) -> 3x3 conv (in -> out) -> [resample]
        -> [InstanceNorm] -> LeakyReLU(0.2) -> 3x3 conv (out -> out)
    Shortcut:
        1x1 conv (in -> out) -> [resample]
    The output is the element-wise sum of the two paths.

    Args:
        in_channel: number of channels of the input feature map.
        out_channel: number of channels of the output feature map.
        down_sample: if True, halve the spatial size with 2x2 average pooling.
        up_sample: if True (and ``down_sample`` is False), double the spatial
            size with bilinear interpolation.
        norm: if True, insert ``InstanceNorm2d`` before each activation.
    """

    def __init__(self, in_channel, out_channel, down_sample=False, up_sample=False, norm=True):
        super(ResBlock, self).__init__()
        main_module_list = []
        if norm:
            main_module_list.append(nn.InstanceNorm2d(in_channel))
        # Without norm, this activation would act directly on the block input.
        # Running it in place would overwrite the caller's tensor and feed the
        # activated values into the shortcut (forward evaluates main_path
        # first). After InstanceNorm the input is a fresh tensor, so in-place
        # is safe there.
        main_module_list += [
            nn.LeakyReLU(0.2, inplace=bool(norm)),
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
        ]
        main_module_list += self._resample_layers(down_sample, up_sample)
        if norm:
            main_module_list.append(nn.InstanceNorm2d(out_channel))
        # This activation always consumes a freshly computed tensor, so in-place is safe.
        main_module_list += [
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
        ]
        self.main_path = nn.Sequential(*main_module_list)

        side_module_list = [nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0)]
        side_module_list += self._resample_layers(down_sample, up_sample)
        self.side_path = nn.Sequential(*side_module_list)

    @staticmethod
    def _resample_layers(down_sample, up_sample):
        """Return the (possibly empty) resampling layer list shared by both paths."""
        if down_sample:
            return [nn.AvgPool2d(kernel_size=2)]
        if up_sample:
            # align_corners=False is what bilinear mode uses when it is left unset;
            # stating it explicitly avoids the legacy default-change warning.
            return [nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)]
        return []

    def forward(self, x):
        """Return ``main_path(x) + side_path(x)``; ``x`` is not modified."""
        x1 = self.main_path(x)
        x2 = self.side_path(x)
        return x1 + x2
class AdaIn(nn.Module):
    """Adaptive instance normalisation conditioned on a vector.

    ``x`` is instance-normalised, then scaled and shifted per channel. The
    scale and shift are predicted from ``style_vector`` by two linear layers
    (``vector_size`` -> ``in_channel``).

    Args:
        in_channel: number of channels of the feature map to modulate.
        vector_size: length of the conditioning vector (last dimension).
        eps: value added to the variance inside ``F.instance_norm``.
    """

    def __init__(self, in_channel, vector_size, eps=1e-5):
        super(AdaIn, self).__init__()
        self.eps = eps
        self.std_style_fc = nn.Linear(vector_size, in_channel)
        self.mean_style_fc = nn.Linear(vector_size, in_channel)

    def forward(self, x, style_vector):
        """Return ``scale(style_vector) * instance_norm(x) + shift(style_vector)``.

        Assumes ``x`` is (N, C, H, W) and ``style_vector`` is (N, vector_size);
        TODO confirm against callers. Two trailing singleton dims are appended
        to the predicted per-channel scale/shift so they broadcast over H, W.
        """
        std_style = self.std_style_fc(style_vector)
        mean_style = self.mean_style_fc(style_vector)
        std_style = std_style.unsqueeze(-1).unsqueeze(-1)
        mean_style = mean_style.unsqueeze(-1).unsqueeze(-1)
        # Previously self.eps was stored but never passed, so it had no effect.
        x = F.instance_norm(x, eps=self.eps)
        return std_style * x + mean_style
class AdaInResBlock(nn.Module):
    """Residual block whose main path is modulated by two ``AdaIn`` layers.

    Main path:
        AdaIn -> LeakyReLU(0.2) -> 3x3 conv (in -> out) -> [2x bilinear upsample]
        -> AdaIn -> LeakyReLU(0.2) -> 3x3 conv (out -> out)
    Shortcut:
        1x1 conv (in -> out) -> [2x bilinear upsample]
    The output is the element-wise sum of the two paths.

    Args:
        in_channel: number of channels of the input feature map.
        out_channel: number of channels of the output feature map.
        up_sample: if True, double the spatial size in both paths.
        vector_size: length of the conditioning vector fed to both AdaIn
            layers. Defaults to 257 + 512, the value previously hard-coded
            here. NOTE(review): the split presumably reflects two concatenated
            embeddings; confirm against the caller.
    """

    def __init__(self, in_channel, out_channel, up_sample=False, vector_size=257 + 512):
        super(AdaInResBlock, self).__init__()
        self.vector_size = vector_size
        self.up_sample = up_sample
        self.adain1 = AdaIn(in_channel, self.vector_size)
        self.adain2 = AdaIn(out_channel, self.vector_size)
        # In-place activations are safe in both main sub-paths: each consumes
        # AdaIn's freshly computed output, never the block input itself.
        main_module_list = [
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
        ]
        if up_sample:
            # align_corners=False matches the implicit bilinear default.
            main_module_list.append(nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False))
        self.main_path1 = nn.Sequential(*main_module_list)
        self.main_path2 = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
        )
        side_module_list = [nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0)]
        if up_sample:
            side_module_list.append(nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False))
        self.side_path = nn.Sequential(*side_module_list)

    def forward(self, x, id_vector):
        """Apply the AdaIn-modulated main path and the 1x1 shortcut; return their sum.

        ``id_vector`` is the conditioning vector passed to both AdaIn layers.
        """
        x1 = self.adain1(x, id_vector)
        x1 = self.main_path1(x1)
        x2 = self.side_path(x)
        x1 = self.adain2(x1, id_vector)
        x1 = self.main_path2(x1)
        return x1 + x2
class UpSamplingBlock(nn.Module):
    """Upsample a 256-channel feature map 4x and emit two prediction heads.

    Two ``ResBlock(256, 256, up_sample=True)`` stages each double the spatial
    size. Two heads then read the shared features:
      * ``i_r_net``: LeakyReLU(0.2) -> 3x3 conv to 3 channels;
      * ``m_r_net``: 3x3 conv to 1 channel -> sigmoid, giving values in (0, 1).
    Presumably these are an image and a blending mask; confirm with callers.
    """

    def __init__(self):
        super().__init__()
        stages = [ResBlock(256, 256, up_sample=True) for _ in range(2)]
        self.net = nn.Sequential(*stages)
        self.i_r_net = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 3, kernel_size=3, stride=1, padding=1),
        )
        self.m_r_net = nn.Sequential(
            nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(i_r, m_r)``: the 3-channel and 1-channel head outputs."""
        feat = self.net(x)
        # NOTE(review): i_r_net starts with an in-place LeakyReLU, so it
        # overwrites `feat`; m_r_net therefore reads the *activated* features.
        # Existing checkpoints may rely on this, so the evaluation order must
        # stay as is unless that is deliberately changed.
        img_branch = self.i_r_net(feat)
        mask_branch = self.m_r_net(feat)
        return img_branch, mask_branch