HiFiFace-inference-demo / models /shape_aware_identity_model.py
xuehongyang
ser
83d8d3c
import torch
import torch.nn as nn
import torch.nn.functional as F
from arcface_torch.backbones.iresnet import iresnet100
from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper
class ShapeAwareIdentityExtractor(nn.Module):
def __init__(self, identity_extractor_config):
"""
Shape Aware Identity Extractor
Parameters:
----------
identity_extractor_config: Dict[str, str]
必须包含以下内容:
f_3d_checkpoint_path: str
3D人脸重建模型路径,如"model/Deep3DFaceRecon_pytorch/checkpoints/epoch_20.pth"
f_id_checkpoint_path: str
arcface人脸识别模型路径
非官方实现用的是https://onedrive.live.com/?authkey=%21AFZjr283nwZHqbA&id=4A83B6B633B029CC%215585&cid=4A83B6B633B029CC/backbone.pth
"""
super(ShapeAwareIdentityExtractor, self).__init__()
f_3d_checkpoint_path = identity_extractor_config["f_3d_checkpoint_path"]
f_id_checkpoint_path = identity_extractor_config["f_id_checkpoint_path"]
# 3D人脸重建模型
self.f_3d = ReconNetWrapper(net_recon="resnet50", use_last_fc=False)
self.f_3d.load_state_dict(torch.load(f_3d_checkpoint_path, map_location="cpu")["net_recon"])
self.f_3d.eval()
# 人脸识别模型
self.f_id = iresnet100(pretrained=False, fp16=False)
self.f_id.load_state_dict(torch.load(f_id_checkpoint_path, map_location="cpu"))
self.f_id.eval()
@torch.no_grad()
def interp(self, i_source, i_target, shape_rate=1.0, id_rate=1.0):
"""
插值shape和id信息
"""
c_s = self.f_3d(i_source)
c_t = self.f_3d(i_target)
c_interp = shape_rate * c_s + (1 - shape_rate) * c_t
c_fuse = torch.cat((c_interp[:, :80], c_t[:, 80:]), dim=1)
# extract source face identity feature
v_s = F.normalize(self.f_id(F.interpolate((i_source - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)
v_t = F.normalize(self.f_id(F.interpolate((i_target - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)
v_id = id_rate * v_s + (1 - id_rate) * v_t
# concat new shape feature and source identity
v_sid = torch.cat((c_fuse, v_id), dim=1)
return v_sid
def forward(self, i_source, i_target):
"""
Parameters:
-----------
i_source: torch.Tensor, shape (B, 3, H, W), in range [0, 1], source face image
i_target: torch.Tensor, shape (B, 3, H, W), in range [0, 1], target face image
Returns:
--------
v_sid: torch.Tensor, fused shape and id features
"""
# regress 3DMM coefficients
c_s = self.f_3d(i_source)
c_t = self.f_3d(i_target)
# generate a new 3D face model: source's identity + target's posture and expression
# from https://github.com/sicxu/Deep3DFaceRecon_pytorch/blob/f221678d4b49ca35f1275ba60f721ecb38a2cd19/models/networks.py#L85
c_fuse = torch.cat((c_s[:, :80], c_t[:, 80:]), dim=1)
# extract source face identity feature
v_id = F.normalize(self.f_id(F.interpolate((i_source - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)
# concat new shape feature and source identity
v_sid = torch.cat((c_fuse, v_id), dim=1)
return v_sid