| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from models.model_blocks import AdaInResBlock |
| from models.model_blocks import ResBlock |
| from models.model_blocks import UpSamplingBlock |
|
|
|
|
class SemanticFaceFusionModule(nn.Module):
    """Semantic Face Fusion Module, intended to preserve lighting and background.

    Blends the low-level encoder and decoder feature maps under a soft mask
    predicted from the decoder features. It then produces a low-resolution
    and a full-resolution re-targeted image. In both images, regions outside
    the predicted mask are taken from the target image.
    """

    def __init__(self):
        super().__init__()

        # Residual block applied to the encoder features before blending.
        self.sigma = ResBlock(256, 256)
        # Single-channel soft mask in [0, 1] (Sigmoid output) from decoder features.
        self.low_mask_predict = nn.Sequential(nn.Conv2d(256, 1, 3, 1, 1), nn.Sigmoid())
        # AdaIN residual blocks conditioned on the identity vector v_sid.
        self.z_fuse_block_1 = AdaInResBlock(256, 256)
        self.z_fuse_block_2 = AdaInResBlock(256, 256)

        # Projects the fused features to a 3-channel low-resolution image.
        # The activation must NOT be in-place: z_fuse is also passed to
        # self.f_up afterwards, and an in-place LeakyReLU would silently
        # overwrite it (and can break autograd for the producing op).
        # Layer order is unchanged, so "i_low_block.1.*" state_dict keys still load.
        self.i_low_block = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=False),
            nn.Conv2d(256, 3, 3, 1, 1),
        )

        # Upsamples the fused features; returns a (image, mask) pair.
        self.f_up = UpSamplingBlock()

    def forward(self, target_image, z_enc, z_dec, v_sid):
        """
        Parameters:
        ----------
        target_image: target face image (full resolution)
        z_enc: low-level encoder feature map at 1/4 of the original image size
        z_dec: low-level decoder feature map at 1/4 of the original image size
        v_sid: the 3D shape aware identity vector

        Returns:
        --------
        i_r: re-target image
        i_low: 1/4 size retarget image
        m_r: face mask
        m_low: 1/4 size face mask
        """
        z_enc = self.sigma(z_enc)

        # Soft mask at feature resolution, predicted from the decoder features.
        m_low = self.low_mask_predict(z_dec)

        # Where the mask is ~1 take decoder features, where ~0 keep encoder features.
        z_fuse = m_low * z_dec + (1 - m_low) * z_enc

        # Condition the fused features on the identity vector.
        z_fuse = self.z_fuse_block_1(z_fuse, v_sid)
        z_fuse = self.z_fuse_block_2(z_fuse, v_sid)

        # Low-resolution image. Outside the mask, fall back to the target
        # downsampled to exactly i_low's spatial size. `size=` matches the old
        # `scale_factor=0.25` result when H and W are divisible by 4, and avoids
        # a broadcast/shape error when they are not.
        i_low = self.i_low_block(z_fuse)
        target_low = F.interpolate(target_image, size=i_low.shape[-2:])
        i_low = m_low * i_low + (1 - m_low) * target_low

        # Full-resolution image and mask, composited over the original target.
        i_r, m_r = self.f_up(z_fuse)
        i_r = m_r * i_r + (1 - m_r) * target_image

        return i_r, i_low, m_r, m_low
|
|
|
|
if __name__ == "__main__":
    # Smoke test: one forward pass on random tensors, then print output shapes.
    # The feature maps here are 64x64, i.e. 1/4 of the 256x256 input image.
    import torch

    batch = 1
    target = torch.randn(batch, 3, 256, 256)
    enc_feat = torch.randn(batch, 256, 64, 64)
    dec_feat = torch.randn(batch, 256, 64, 64)
    identity = torch.randn(batch, 769)

    fusion = SemanticFaceFusionModule()
    outputs = fusion(target, enc_feat, dec_feat, identity)
    # Same output as printing (i_r, i_low, m_r, m_low) shapes separated by spaces.
    print(*(out.shape for out in outputs))
|
|