# Spaces: Runtime error  (page-status banner captured together with the file; not part of the module)
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from arcface_torch.backbones.iresnet import iresnet100 | |
from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper | |
class ShapeAwareIdentityExtractor(nn.Module):
    """Build a shape-aware identity vector from a source and a target face.

    Two pretrained networks are used: ``f_3d`` (a ``ReconNetWrapper`` that
    regresses 3DMM coefficients) and ``f_id`` (an ``iresnet100`` face
    recognition backbone). Their outputs are concatenated along dim 1 into a
    single feature vector. Both networks are switched to eval mode at
    construction time and are kept in eval mode by :meth:`train`.
    """

    # Number of leading 3DMM coefficients that are taken from the source face.
    # Per the Deep3DFaceRecon code referenced in ``forward`` these are presumably
    # the identity (shape) coefficients -- verify against that repository.
    _NUM_SHAPE_COEFFS = 80
    # Side length the images are resized to before being fed to ``f_id``.
    _ID_INPUT_SIZE = 112

    def __init__(self, identity_extractor_config):
        """
        Shape Aware Identity Extractor

        Parameters
        ----------
        identity_extractor_config : Dict[str, str]
            Must contain the following entries:

            f_3d_checkpoint_path : str
                Path to the 3D face reconstruction checkpoint, e.g.
                "model/Deep3DFaceRecon_pytorch/checkpoints/epoch_20.pth".
                The file must hold the weights under the "net_recon" key.
            f_id_checkpoint_path : str
                Path to the ArcFace face recognition checkpoint. The
                unofficial implementation uses
                https://onedrive.live.com/?authkey=%21AFZjr283nwZHqbA&id=4A83B6B633B029CC%215585&cid=4A83B6B633B029CC/backbone.pth

        Raises
        ------
        KeyError
            If either checkpoint path is missing from the config.
        """
        super().__init__()
        f_3d_checkpoint_path = identity_extractor_config["f_3d_checkpoint_path"]
        f_id_checkpoint_path = identity_extractor_config["f_id_checkpoint_path"]

        # NOTE(review): torch.load unpickles the checkpoint file; only point
        # these paths at checkpoints from a trusted source.

        # 3D face reconstruction model
        self.f_3d = ReconNetWrapper(net_recon="resnet50", use_last_fc=False)
        self.f_3d.load_state_dict(torch.load(f_3d_checkpoint_path, map_location="cpu")["net_recon"])
        self.f_3d.eval()

        # Face recognition model
        self.f_id = iresnet100(pretrained=False, fp16=False)
        self.f_id.load_state_dict(torch.load(f_id_checkpoint_path, map_location="cpu"))
        self.f_id.eval()

    def train(self, mode: bool = True):
        """Set the training flag on this module but keep both extractors in eval mode.

        ``nn.Module.train`` recurses into every child, so without this override
        a ``.train()`` call on this module (or on any parent that contains it)
        would silently undo the ``.eval()`` calls made in ``__init__``.

        Parameters
        ----------
        mode : bool
            Training flag applied to this module itself.

        Returns
        -------
        self
        """
        super().train(mode)
        self.f_3d.eval()
        self.f_id.eval()
        return self

    def _identity_embedding(self, image: torch.Tensor) -> torch.Tensor:
        """Return the L2-normalised ``f_id`` embedding of ``image``."""
        # Map [0, 1] inputs to [-1, 1] and resize to the recognition net's input size.
        resized = F.interpolate((image - 0.5) / 0.5, size=self._ID_INPUT_SIZE, mode="bicubic")
        return F.normalize(self.f_id(resized), dim=-1, p=2)

    def _fuse_coeffs(self, c_shape: torch.Tensor, c_t: torch.Tensor) -> torch.Tensor:
        """Splice the leading coefficients of ``c_shape`` onto the remaining ones of ``c_t``."""
        k = self._NUM_SHAPE_COEFFS
        return torch.cat((c_shape[:, :k], c_t[:, k:]), dim=1)

    def interp(self, i_source, i_target, shape_rate=1.0, id_rate=1.0):
        """
        Interpolate shape and identity information between source and target.

        Parameters
        ----------
        i_source : torch.Tensor
            Source face image batch (same convention as ``forward``).
        i_target : torch.Tensor
            Target face image batch (same convention as ``forward``).
        shape_rate : float
            Weight of the source in the 3DMM coefficient blend; the target
            gets ``1 - shape_rate``. ``1.0`` uses the source coefficients.
        id_rate : float
            Weight of the source in the identity embedding blend; the target
            gets ``1 - id_rate``. ``1.0`` uses the source embedding.

        Returns
        -------
        v_sid : torch.Tensor
            Fused coefficients concatenated with the blended identity
            embedding along dim 1. The blended embedding is a plain linear
            mix of two unit vectors and is not re-normalised.
        """
        c_s = self.f_3d(i_source)
        c_t = self.f_3d(i_target)
        c_interp = shape_rate * c_s + (1 - shape_rate) * c_t
        c_fuse = self._fuse_coeffs(c_interp, c_t)

        # blend the source and target identity features
        v_s = self._identity_embedding(i_source)
        v_t = self._identity_embedding(i_target)
        v_id = id_rate * v_s + (1 - id_rate) * v_t

        # concat new shape feature and blended identity
        return torch.cat((c_fuse, v_id), dim=1)

    def forward(self, i_source, i_target):
        """
        Parameters
        ----------
        i_source : torch.Tensor, shape (B, 3, H, W), in range [0, 1], source face image
        i_target : torch.Tensor, shape (B, 3, H, W), in range [0, 1], target face image

        Returns
        -------
        v_sid : torch.Tensor, fused shape and id features
        """
        # regress 3DMM coefficients
        c_s = self.f_3d(i_source)
        c_t = self.f_3d(i_target)

        # generate a new 3D face model: source's identity + target's posture and expression
        # from https://github.com/sicxu/Deep3DFaceRecon_pytorch/blob/f221678d4b49ca35f1275ba60f721ecb38a2cd19/models/networks.py#L85
        c_fuse = self._fuse_coeffs(c_s, c_t)

        # extract source face identity feature
        v_id = self._identity_embedding(i_source)

        # concat new shape feature and source identity
        return torch.cat((c_fuse, v_id), dim=1)