|  | import numpy as np | 
					
						
						|  | import torch | 
					
						
						|  | import torch.nn as nn | 
					
						
						|  | import torch.nn.functional as F | 
					
						
						|  | from torch.nn.utils.spectral_norm import spectral_norm | 
					
						
						|  |  | 
					
						
						|  | from basicsr.utils.registry import ARCH_REGISTRY | 
					
						
						|  | from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization | 
					
						
						|  | from .vgg_arch import VGGFeatureExtractor | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class SFTUpBlock(nn.Module): | 
					
						
						|  | """Spatial feature transform (SFT) with upsampling block. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | in_channel (int): Number of input channels. | 
					
						
						|  | out_channel (int): Number of output channels. | 
					
						
						|  | kernel_size (int): Kernel size in convolutions. Default: 3. | 
					
						
						|  | padding (int): Padding in convolutions. Default: 1. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, in_channel, out_channel, kernel_size=3, padding=1): | 
					
						
						|  | super(SFTUpBlock, self).__init__() | 
					
						
						|  | self.conv1 = nn.Sequential( | 
					
						
						|  | Blur(in_channel), | 
					
						
						|  | spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)), | 
					
						
						|  | nn.LeakyReLU(0.04, True), | 
					
						
						|  |  | 
					
						
						|  | ) | 
					
						
						|  | self.convup = nn.Sequential( | 
					
						
						|  | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), | 
					
						
						|  | spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)), | 
					
						
						|  | nn.LeakyReLU(0.2, True), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.scale_block = nn.Sequential( | 
					
						
						|  | spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True), | 
					
						
						|  | spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1))) | 
					
						
						|  | self.shift_block = nn.Sequential( | 
					
						
						|  | spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True), | 
					
						
						|  | spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), nn.Sigmoid()) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x, updated_feat): | 
					
						
						|  | out = self.conv1(x) | 
					
						
						|  |  | 
					
						
						|  | scale = self.scale_block(updated_feat) | 
					
						
						|  | shift = self.shift_block(updated_feat) | 
					
						
						|  | out = out * scale + shift | 
					
						
						|  |  | 
					
						
						|  | out = self.convup(out) | 
					
						
						|  | return out | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @ARCH_REGISTRY.register() | 
					
						
						|  | class DFDNet(nn.Module): | 
					
						
						|  | """DFDNet: Deep Face Dictionary Network. | 
					
						
						|  |  | 
					
						
						|  | It only processes faces with 512x512 size. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | num_feat (int): Number of feature channels. | 
					
						
						|  | dict_path (str): Path to the facial component dictionary. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, num_feat, dict_path): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.parts = ['left_eye', 'right_eye', 'nose', 'mouth'] | 
					
						
						|  |  | 
					
						
						|  | channel_sizes = [128, 256, 512, 512] | 
					
						
						|  | self.feature_sizes = np.array([256, 128, 64, 32]) | 
					
						
						|  | self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4'] | 
					
						
						|  | self.flag_dict_device = False | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.dict = torch.load(dict_path) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.vgg_extractor = VGGFeatureExtractor( | 
					
						
						|  | layer_name_list=self.vgg_layers, | 
					
						
						|  | vgg_type='vgg19', | 
					
						
						|  | use_input_norm=True, | 
					
						
						|  | range_norm=True, | 
					
						
						|  | requires_grad=False) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.attn_blocks = nn.ModuleDict() | 
					
						
						|  | for idx, feat_size in enumerate(self.feature_sizes): | 
					
						
						|  | for name in self.parts: | 
					
						
						|  | self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8) | 
					
						
						|  | self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4) | 
					
						
						|  | self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2) | 
					
						
						|  | self.upsample3 = SFTUpBlock(num_feat * 2, num_feat) | 
					
						
						|  | self.upsample4 = nn.Sequential( | 
					
						
						|  | spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat), | 
					
						
						|  | UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh()) | 
					
						
						|  |  | 
					
						
						|  | def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size): | 
					
						
						|  | """swap the features from the dictionary.""" | 
					
						
						|  |  | 
					
						
						|  | part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone() | 
					
						
						|  |  | 
					
						
						|  | part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False) | 
					
						
						|  |  | 
					
						
						|  | dict_feat = adaptive_instance_normalization(dict_feat, part_resize_feat) | 
					
						
						|  |  | 
					
						
						|  | similarity_score = F.conv2d(part_resize_feat, dict_feat) | 
					
						
						|  | similarity_score = F.softmax(similarity_score.view(-1), dim=0) | 
					
						
						|  |  | 
					
						
						|  | select_idx = torch.argmax(similarity_score) | 
					
						
						|  | swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4]) | 
					
						
						|  |  | 
					
						
						|  | attn = self.attn_blocks[f'{part_name}_' + str(f_size)](swap_feat - part_feat) | 
					
						
						|  | attn_feat = attn * swap_feat | 
					
						
						|  |  | 
					
						
						|  | updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat | 
					
						
						|  | return updated_feat | 
					
						
						|  |  | 
					
						
						|  | def put_dict_to_device(self, x): | 
					
						
						|  | if self.flag_dict_device is False: | 
					
						
						|  | for k, v in self.dict.items(): | 
					
						
						|  | for kk, vv in v.items(): | 
					
						
						|  | self.dict[k][kk] = vv.to(x) | 
					
						
						|  | self.flag_dict_device = True | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x, part_locations): | 
					
						
						|  | """ | 
					
						
						|  | Now only support testing with batch size = 0. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | x (Tensor): Input faces with shape (b, c, 512, 512). | 
					
						
						|  | part_locations (list[Tensor]): Part locations. | 
					
						
						|  | """ | 
					
						
						|  | self.put_dict_to_device(x) | 
					
						
						|  |  | 
					
						
						|  | vgg_features = self.vgg_extractor(x) | 
					
						
						|  |  | 
					
						
						|  | updated_vgg_features = [] | 
					
						
						|  | batch = 0 | 
					
						
						|  | for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes): | 
					
						
						|  | dict_features = self.dict[f'{f_size}'] | 
					
						
						|  | vgg_feat = vgg_features[vgg_layer] | 
					
						
						|  | updated_feat = vgg_feat.clone() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for part_idx, part_name in enumerate(self.parts): | 
					
						
						|  | location = (part_locations[part_idx][batch] // (512 / f_size)).int() | 
					
						
						|  | updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name, | 
					
						
						|  | f_size) | 
					
						
						|  |  | 
					
						
						|  | updated_vgg_features.append(updated_feat) | 
					
						
						|  |  | 
					
						
						|  | vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4']) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3]) | 
					
						
						|  | upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2]) | 
					
						
						|  | upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1]) | 
					
						
						|  | upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0]) | 
					
						
						|  | out = self.upsample4(upsampled_feat) | 
					
						
						|  |  | 
					
						
						|  | return out | 
					
						
						|  |  |