Spaces:
Runtime error
Runtime error
import argparse | |
import os | |
import cv2 | |
import gradio as gr | |
import kornia | |
import numpy as np | |
import torch | |
from loguru import logger | |
from benchmark.face_pipeline import alignFace | |
from benchmark.face_pipeline import FaceDetector | |
from benchmark.face_pipeline import inverse_transform_batch | |
from benchmark.face_pipeline import SoftErosion | |
from configs.train_config import TrainConfig | |
from models.model import HifiFace | |
class ImageSwap:
    """Swap the face of a source image onto the face found in a target image.

    The pipeline is: detect and align one face per image, run the HifiFace
    model on the aligned crops, warp the swapped crop (and its blending mask)
    back into the target frame, and alpha-blend it over the target image.
    """

    def __init__(self, cfg, model=None):
        """Build the detector, aligner, swap model and mask smoother.

        Args:
            cfg: Object exposing ``device``, ``face_detector_weights``,
                ``model_path`` and ``model_idx`` attributes.
            model: Optional already-constructed model. When ``None`` a
                ``HifiFace`` instance is created and its checkpoint is loaded
                from ``(cfg.model_path, cfg.model_idx)``.
        """
        self.device = cfg.device
        self.facedetector = FaceDetector(cfg.face_detector_weights, device=self.device)
        self.alignface = alignFace()
        # detect_and_align resets this list when no face is found; create it
        # here so the attribute exists regardless of which path ran first.
        self.kps_window = []

        opt = TrainConfig()
        opt.use_ddp = False
        checkpoint = (cfg.model_path, cfg.model_idx)
        if model is None:
            self.model = HifiFace(
                opt.identity_extractor_config,
                is_training=False,
                device=self.device,
                load_checkpoint=checkpoint,
            )
        else:
            self.model = model
        self.model.eval()

        self.smooth_mask = SoftErosion(kernel_size=7, threshold=0.9, iterations=7).to(self.device)

    def _geometry_transfrom_warp_affine(self, swapped_image, inv_att_transforms, frame_size, square_mask):
        """Warp the swapped crop and its mask back into the target frame.

        Args:
            swapped_image: Batched image tensor produced by the model.
            inv_att_transforms: Batched inverse alignment matrices.
            frame_size: Output size forwarded to kornia as ``dsize``;
                ``inference`` passes ``(target.shape[0], target.shape[1])``.
            square_mask: Batched blending mask matching ``swapped_image``.

        Returns:
            Tuple ``(warped_image, warped_mask)``.
        """
        # The image is padded by replicating its border, while the mask is
        # padded with zeros so that blending falls back to the target image
        # everywhere outside the warped crop.
        swapped_image = kornia.geometry.transform.warp_affine(
            swapped_image,
            inv_att_transforms,
            frame_size,
            mode="bilinear",
            padding_mode="border",
            align_corners=True,
            fill_value=torch.zeros(3),
        )
        square_mask = kornia.geometry.transform.warp_affine(
            square_mask,
            inv_att_transforms,
            frame_size,
            mode="bilinear",
            padding_mode="zeros",
            align_corners=True,
            fill_value=torch.zeros(3),
        )
        return swapped_image, square_mask

    def detect_and_align(self, image):
        """Detect the highest-scoring face in ``image`` and align it.

        Args:
            image: Image array accepted by the face detector (assumed to be
                height x width x channels - TODO confirm against FaceDetector).

        Returns:
            ``(aligned, warp_mat)`` where ``aligned`` is a float tensor of
            shape ``(1, channels, 256, 256)`` on ``self.device`` with values
            divided by 255, and ``warp_mat`` is the alignment matrix returned
            by the aligner. ``(None, None)`` when the detector reports no
            score.
        """
        detection = self.facedetector(image)
        if detection.score is None:
            self.kps_window = []
            return None, None

        # Keep only the face the detector is most confident about.
        max_score_ind = np.argmax(detection.score, axis=0)
        kps = detection.key_points[max_score_ind]

        align_img, warp_mat = self.alignface.align_face(image, kps, 256)
        align_img = cv2.resize(align_img, (256, 256))
        # HWC -> CHW, then add the batch dimension.
        align_img = align_img.transpose(2, 0, 1)
        align_img = torch.from_numpy(align_img).unsqueeze(0).to(self.device).float()
        align_img = align_img / 255.0
        return align_img, warp_mat

    def inference(self, source_face, target_face, shape_rate, id_rate, iterations=1):
        """Swap the source face into the target image.

        Args:
            source_face: Image array providing the identity.
            target_face: Image array whose face is replaced; the result has
                the same height and width.
            shape_rate: Forwarded to the model's ``forward``.
            id_rate: Forwarded to the model's ``forward``.
            iterations: Number of times the model is applied, each pass
                feeding the previous swapped crop back in as the target.
                Must be at least 1.

        Returns:
            The blended image as a ``uint8`` array in height x width x
            channels layout, or ``None`` when an input is missing or no face
            is detected in either image.

        Raises:
            ValueError: If ``iterations`` is smaller than 1.
        """
        # Without at least one pass there is no swapped face or mask to blend.
        if iterations < 1:
            raise ValueError(f"iterations must be >= 1, got {iterations}")
        if source_face is None or target_face is None:
            logger.warning("source or target image is missing")
            return None

        src, _ = self.detect_and_align(source_face)
        if src is None:
            logger.warning("no face in src_img")
            return None

        align_target, warp_mat = self.detect_and_align(target_face)
        if align_target is None:
            logger.warning("no face in target_img")
            return None

        logger.info("start swapping")
        frame_size = (target_face.shape[0], target_face.shape[1])
        with torch.no_grad():
            for _ in range(iterations):
                swapped_face, m_r = self.model.forward(src, align_target, shape_rate, id_rate)
                swapped_face = torch.clamp(swapped_face, 0, 1)
                align_target = swapped_face

            # Only the mask from the final pass is used for blending.
            smooth_face_mask, _ = self.smooth_mask(m_r)

            warp_batch = torch.from_numpy(warp_mat).float().unsqueeze(0)
            inverse_warp_mat = inverse_transform_batch(warp_batch, device=self.device)
            swapped_face, smooth_face_mask = self._geometry_transfrom_warp_affine(
                swapped_face, inverse_warp_mat, frame_size, smooth_face_mask
            )

            target_tensor = (
                torch.from_numpy(target_face.transpose(2, 0, 1)).unsqueeze(0).to(self.device).float() / 255.0
            )
            # Alpha-blend: swapped face where the mask is 1, target elsewhere.
            result_face = (1 - smooth_face_mask) * target_tensor + smooth_face_mask * swapped_face
            result_face = torch.clamp(result_face * 255.0, 0.0, 255.0).type(dtype=torch.uint8)
            result_face = result_face.detach().cpu().numpy()

        # NCHW -> NHWC, then drop the batch dimension.
        return result_face.transpose(0, 2, 3, 1)[0]
class ConfigPath:
    """Default runtime settings consumed by ``ImageSwap``.

    ``main`` overrides ``model_path``, ``model_idx`` and ``device`` from the
    command line; ``face_detector_weights`` is always taken from here.
    """

    # Device string handed to the face detector and the swap model.
    device: str = "cuda"

    # Together these form the ``(path, index)`` checkpoint pair given to the model.
    model_path: str = ""
    model_idx: int = 80000

    # Location of the face detector weights file.
    face_detector_weights: str = "/data/useful_ckpt/face_detector/face_detector_scrfd_10g_bnkps.onnx"
def main():
    """Parse command-line options and serve the face-swap demo with Gradio.

    Options that are not given on the command line keep the defaults declared
    on ``ConfigPath``. The call blocks inside ``demo.launch`` while the web
    server is running.
    """
    cfg = ConfigPath()
    # NOTE(review): description/epilog are still the argparse documentation
    # placeholders; replace them with real help text when convenient.
    parser = argparse.ArgumentParser(
        prog="benchmark", description="What the program does", epilog="Text at the bottom of help"
    )
    # Defaults come from ConfigPath so that omitting an option neither crashes
    # (int(None) for --model_idx) nor overwrites the config value with None.
    parser.add_argument("-m", "--model_path", default=cfg.model_path)
    parser.add_argument("-i", "--model_idx", type=int, default=cfg.model_idx)
    parser.add_argument("-d", "--device", default=cfg.device)
    args = parser.parse_args()

    cfg.model_path = args.model_path
    cfg.model_idx = args.model_idx
    cfg.device = args.device
    infer = ImageSwap(cfg)

    def inference(source_face, target_face, shape_rate, id_rate):
        # Expose exactly the four UI inputs; ``iterations`` keeps its default.
        return infer.inference(source_face, target_face, shape_rate, id_rate)

    output = gr.Image(shape=None, label="换脸结果")
    demo = gr.Interface(
        fn=inference,
        inputs=[
            gr.Image(shape=None, label="选脸图"),
            gr.Image(shape=None, label="目标图"),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=1.0,
                step=0.1,
                label="3d结构相似度(1.0表示完全替换)",
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=1.0,
                step=0.1,
                label="人脸特征相似度(1.0表示完全替换)",
            ),
        ],
        outputs=output,
        title="HiConFace人脸融合系统",
        description="v1.0: developed by yiwise CV group",
    )
    # The original followed this with a bare ``infer.inference()`` call, which
    # always raised TypeError (missing arguments) once the server shut down.
    demo.launch(server_name="0.0.0.0", server_port=7860)
# Start the demo only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()