gavinyuan
add: app.py
de68d44
raw
history blame
18.1 kB
import os
import uuid
import glob
import shutil
from pathlib import Path
from multiprocessing.pool import Pool
import gradio as gr
import torch
from torchvision import transforms
import cv2
import numpy as np
from PIL import Image
import tqdm
# from modules.networks.faceshifter import FSGenerator
# from inference.alignment import norm_crop, norm_crop_with_M, paste_back
# from inference.utils import save, get_5_from_98, get_detector, get_lmk
# from inference.PIPNet.lib.tools import get_lmk_model, demo_image
# from inference.landmark_smooth import kalman_filter_landmark, savgol_filter_landmark
# from tricks import Trick
# make_abs_path = lambda fn: os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), fn))
#
#
# fs_model_name = 'faceshifter'
# in_size = 512
#
# mouth_net_param = {
# "use": True,
# "feature_dim": 128,
# "crop_param": (28, 56, 84, 112),
# "weight_path": "../../modules/third_party/arcface/weights/mouth_net_28_56_84_112.pth",
# }
# trick = Trick()
#
# T = transforms.Compose(
# [
# transforms.ToTensor(),
# transforms.Normalize(0.5, 0.5),
# ]
# )
# tensor2pil_transform = transforms.ToPILImage()
#
#
# def extract_generator(ckpt: str, pt: str):
# print(f'[extract_generator] loading ckpt...')
# from trainer.faceshifter.faceshifter_pl import FaceshifterPL512, FaceshifterPL
# import yaml
# with open(make_abs_path('../../trainer/faceshifter/config.yaml'), 'r') as f:
# config = yaml.load(f, Loader=yaml.FullLoader)
# config['mouth_net'] = mouth_net_param
#
# if in_size == 256:
# net = FaceshifterPL(n_layers=3, num_D=3, config=config)
# elif in_size == 512:
# net = FaceshifterPL512(n_layers=3, num_D=3, config=config, verbose=False)
# else:
# raise ValueError('Not supported in_size.')
# checkpoint = torch.load(ckpt, map_location="cpu", )
# net.load_state_dict(checkpoint["state_dict"], strict=False)
# net.eval()
#
# G = net.generator
# torch.save(G.state_dict(), pt)
# print(f'[extract_generator] extracted from {ckpt}, pth saved to {pt}')
#
#
# ''' load model '''
# if fs_model_name == 'faceshifter':
# # pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_t38.pth")
# # pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_t512_6.pth")
# # ckpt_path = "/apdcephfs/share_1290939/gavinyuan/out/triplet512_6/epoch=3-step=128999.ckpt"
# pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_t512_4.pth")
# ckpt_path = "/apdcephfs/share_1290939/gavinyuan/out/triplet512_4/epoch=2-step=185999.ckpt"
# if not os.path.exists(pt_path) or 't512' in pt_path:
# extract_generator(ckpt_path, pt_path)
# fs_model = FSGenerator(
# make_abs_path("../../modules/third_party/arcface/weights/ms1mv3_arcface_r100_fp16/backbone.pth"),
# mouth_net_param=mouth_net_param,
# in_size=in_size,
# downup=in_size == 512,
# )
# fs_model.load_state_dict(torch.load(pt_path, "cpu"), strict=True)
# fs_model.eval()
#
# @torch.no_grad()
# def infer_batch_to_img(i_s, i_t, post: bool = False):
# i_r = fs_model(i_s, i_t)[0] # x, id_vector, att
#
# if post:
# target_hair_mask = trick.get_any_mask(i_t, par=[0, 17])
# target_hair_mask = trick.smooth_mask(target_hair_mask)
# i_r = target_hair_mask * i_t + (target_hair_mask * (-1) + 1) * i_r
# i_r = trick.finetune_mouth(i_s, i_t, i_r) if in_size == 256 else i_r
#
# img_r = trick.tensor_to_arr(i_r)[0]
# return img_r
#
# elif fs_model_name == 'simswap_triplet' or fs_model_name == 'simswap_vanilla':
# from modules.networks.simswap import Generator_Adain_Upsample
# sw_model = Generator_Adain_Upsample(
# input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=False,
# mouth_net_param=mouth_net_param
# )
# if fs_model_name == 'simswap_triplet':
# pt_path = make_abs_path("../ffplus/extracted_ckpt/G_mouth1_st5.pth")
# ckpt_path = make_abs_path("/apdcephfs/share_1290939/gavinyuan/out/"
# "simswap_triplet_5/epoch=12-step=782999.ckpt")
# elif fs_model_name == 'simswap_vanilla':
# pt_path = make_abs_path("../ffplus/extracted_ckpt/G_tmp_sv4_off.pth")
# ckpt_path = make_abs_path("/apdcephfs/share_1290939/gavinyuan/out/"
# "simswap_vanilla_4/epoch=694-step=1487999.ckpt")
# else:
# pt_path = None
# ckpt_path = None
# sw_model.load_state_dict(torch.load(pt_path, "cpu"), strict=False)
# sw_model.eval()
# fs_model = sw_model
#
# from trainer.simswap.simswap_pl import SimSwapPL
# import yaml
# with open(make_abs_path('../../trainer/simswap/config.yaml'), 'r') as f:
# config = yaml.load(f, Loader=yaml.FullLoader)
# config['mouth_net'] = mouth_net_param
# net = SimSwapPL(config=config, use_official_arc='off' in pt_path)
#
# checkpoint = torch.load(ckpt_path, map_location="cpu")
# net.load_state_dict(checkpoint["state_dict"], strict=False)
# net.eval()
# sw_mouth_net = net.mouth_net # maybe None
# sw_netArc = net.netArc
# fs_model = fs_model.cuda()
# sw_mouth_net = sw_mouth_net.cuda() if sw_mouth_net is not None else sw_mouth_net
# sw_netArc = sw_netArc.cuda()
#
# @torch.no_grad()
# def infer_batch_to_img(i_s, i_t, post: bool = False):
# i_r = fs_model(source=i_s, target=i_t, net_arc=sw_netArc, mouth_net=sw_mouth_net,)
# if post:
# target_hair_mask = trick.get_any_mask(i_t, par=[0, 17])
# target_hair_mask = trick.smooth_mask(target_hair_mask)
# i_r = target_hair_mask * i_t + (target_hair_mask * (-1) + 1) * i_r
# i_r = i_r.clamp(-1, 1)
# i_r = trick.tensor_to_arr(i_r)[0]
# return i_r
#
# elif fs_model_name == 'simswap_official':
# from simswap.image_infer import SimSwapOfficialImageInfer
# fs_model = SimSwapOfficialImageInfer()
# pt_path = 'Simswap Official'
# mouth_net_param = {
# "use": False
# }
#
# @torch.no_grad()
# def infer_batch_to_img(i_s, i_t):
# i_r = fs_model.image_infer(source_tensor=i_s, target_tensor=i_t)
# i_r = i_r.clamp(-1, 1)
# return i_r
#
# else:
# raise ValueError('Not supported fs_model_name.')
#
#
# print(f'[demo] model loaded from {pt_path}')
def swap_image(
source_image,
target_path,
out_path,
transform,
G,
align_source="arcface",
align_target="set1",
gpu_mode=True,
paste_back=True,
use_post=False,
use_gpen=False,
in_size=256,
):
name = target_path.split("/")[-1]
name = "out_" + name
if isinstance(G, torch.nn.Module):
G.eval()
if gpu_mode:
G = G.cuda()
source_img = np.array(Image.open(source_image).convert("RGB"))
net, detector = get_lmk_model()
lmk = get_5_from_98(demo_image(source_img, net, detector)[0])
source_img = norm_crop(source_img, lmk, in_size, mode=align_source, borderValue=0.0)
source_img = transform(source_img).unsqueeze(0)
target = np.array(Image.open(target_path).convert("RGB"))
original_target = target.copy()
lmk = get_5_from_98(demo_image(target, net, detector)[0])
target, M = norm_crop_with_M(target, lmk, in_size, mode=align_target, borderValue=0.0)
target = transform(target).unsqueeze(0)
if gpu_mode:
target = target.cuda()
source_img = source_img.cuda()
cv2.imwrite('cropped_source.png', trick.tensor_to_arr(source_img)[0, :, :, ::-1])
cv2.imwrite('cropped_target.png', trick.tensor_to_arr(target)[0, :, :, ::-1])
# both inputs should be 512
result = infer_batch_to_img(source_img, target, post=use_post)
cv2.imwrite('result.png', result[:, :, ::-1])
os.makedirs(out_path, exist_ok=True)
Image.fromarray(result.astype(np.uint8)).save(os.path.join(out_path, name))
save((result, M, original_target, os.path.join(out_path, "paste_back_" + name), None),
trick=trick, use_post=use_gpen)
def process_video(
source_image,
target_path,
out_path,
transform,
G,
align_source="arcface",
align_target="set1",
gpu_mode=True,
frames=9999999,
use_tddfav2=False,
landmark_smooth="kalman",
):
if isinstance(G, torch.nn.Module):
G.eval()
if gpu_mode:
G = G.cuda()
''' Target video to frames (.png) '''
fps = 25.0
if not os.path.isdir(target_path):
vidcap = cv2.VideoCapture(target_path)
fps = vidcap.get(cv2.CAP_PROP_FPS)
try:
for match in glob.glob(os.path.join("./tmp/", "*.png")):
os.remove(match)
for match in glob.glob(os.path.join(out_path, "*.png")):
os.remove(match)
except Exception as e:
print(e)
os.makedirs("./tmp/", exist_ok=True)
os.system(
f"ffmpeg -i {target_path} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 ./tmp/frame_%05d.png"
)
target_path = "./tmp/"
globbed_images = sorted(glob.glob(os.path.join(target_path, "*.png")))
''' Get target landmarks '''
print('[Extracting target landmarks...]')
if not use_tddfav2:
align_net, align_detector = get_lmk_model()
else:
align_net, align_detector = get_detector(gpu_mode=gpu_mode)
target_lmks = []
for frame_path in tqdm.tqdm(globbed_images):
target = np.array(Image.open(frame_path).convert("RGB"))
lmk = demo_image(target, align_net, align_detector)
lmk = lmk[0]
target_lmks.append(lmk)
''' Landmark smoothing '''
target_lmks = np.array(target_lmks, np.float32) # (#frames, 98, 2)
if landmark_smooth == 'kalman':
target_lmks = kalman_filter_landmark(target_lmks,
process_noise=0.01,
measure_noise=0.01).astype(np.int)
elif landmark_smooth == 'savgol':
target_lmks = savgol_filter_landmark(target_lmks).astype(np.int)
elif landmark_smooth == 'cancel':
target_lmks = target_lmks.astype(np.int)
else:
raise KeyError('Not supported landmark_smooth choice')
''' Crop source image '''
source_img = np.array(Image.open(source_image).convert("RGB"))
if not use_tddfav2:
lmk = get_5_from_98(demo_image(source_img, align_net, align_detector)[0])
else:
lmk = get_lmk(source_img, align_net, align_detector)
source_img = norm_crop(source_img, lmk, in_size, mode=align_source, borderValue=0.0)
source_img = transform(source_img).unsqueeze(0)
if gpu_mode:
source_img = source_img.cuda()
''' Process by frames '''
targets = []
t_facial_masks = []
Ms = []
original_frames = []
names = []
count = 0
for image in tqdm.tqdm(globbed_images):
names.append(os.path.join(out_path, Path(image).name))
target = np.array(Image.open(image).convert("RGB"))
original_frames.append(target)
''' Crop target frames '''
lmk = get_5_from_98(target_lmks[count])
target, M = norm_crop_with_M(target, lmk, in_size, mode=align_target, borderValue=0.0)
target = transform(target).unsqueeze(0) # in [-1,1]
if gpu_mode:
target = target.cuda()
''' Finetune paste masks '''
target_facial_mask = trick.get_any_mask(target,
par=[1, 2, 3, 4, 5, 6, 10, 11, 12, 13]).squeeze() # in [0,1]
target_facial_mask = target_facial_mask.cpu().numpy().astype(np.float)
target_facial_mask = trick.finetune_mask(target_facial_mask, target_lmks) # in [0,1]
t_facial_masks.append(target_facial_mask)
''' Face swapping '''
with torch.no_grad():
if 'faceshifter' in fs_model_name:
output = G(source_img, target)
target_hair_mask = trick.get_any_mask(target, par=[0, 17])
target_hair_mask = trick.smooth_mask(target_hair_mask)
output = target_hair_mask * target + (target_hair_mask * (-1) + 1) * output
output = trick.finetune_mouth(source_img, target, output)
elif 'simswap' in fs_model_name and 'official' not in fs_model_name:
output = fs_model(source=source_img, target=target,
net_arc=sw_netArc, mouth_net=sw_mouth_net,)
if 'vanilla' not in fs_model_name:
target_hair_mask = trick.get_any_mask(target, par=[0, 17])
target_hair_mask = trick.smooth_mask(target_hair_mask)
output = target_hair_mask * target + (target_hair_mask * (-1) + 1) * output
output = trick.finetune_mouth(source_img, target, output)
output = output.clamp(-1, 1)
elif 'simswap_official' in fs_model_name:
output = fs_model.image_infer(source_tensor=source_img, target_tensor=target)
output = output.clamp(-1, 1)
if isinstance(output, tuple):
target = output[0][0] * 0.5 + 0.5
else:
target = output[0] * 0.5 + 0.5
targets.append(np.array(tensor2pil_transform(target)))
Ms.append(M)
count += 1
if count > frames:
break
os.makedirs(out_path, exist_ok=True)
return targets, t_facial_masks, Ms, original_frames, names, fps
def swap_image_gr(img1, img2, use_post=False, use_gpen=False, gpu_mode=True):
root_dir = make_abs_path("./online_data")
req_id = uuid.uuid1().hex
data_dir = os.path.join(root_dir, req_id)
os.makedirs(data_dir, exist_ok=True)
source_path = os.path.join(data_dir, "source.png")
target_path = os.path.join(data_dir, "target.png")
filename = "paste_back_out_target.png"
out_path = os.path.join(data_dir, filename)
cv2.imwrite(source_path, img1[:, :, ::-1])
cv2.imwrite(target_path, img2[:, :, ::-1])
swap_image(
source_path,
target_path,
data_dir,
T,
fs_model,
gpu_mode=gpu_mode,
align_target='ffhq',
align_source='ffhq',
use_post=use_post,
use_gpen=use_gpen,
in_size=in_size,
)
out = cv2.imread(out_path)[..., ::-1]
return out
def swap_video_gr(img1, target_path, use_gpu=True, frames=9999999):
root_dir = make_abs_path("./online_data")
req_id = uuid.uuid1().hex
data_dir = os.path.join(root_dir, req_id)
os.makedirs(data_dir, exist_ok=True)
source_path = os.path.join(data_dir, "source.png")
cv2.imwrite(source_path, img1[:, :, ::-1])
out_dir = os.path.join(data_dir, "out")
out_name = "output.mp4"
targets, t_facial_masks, Ms, original_frames, names, fps = process_video(
source_path,
target_path,
out_dir,
T,
fs_model,
gpu_mode=use_gpu,
frames=frames,
align_target='ffhq',
align_source='ffhq',
use_tddfav2=False,
)
pool_process = 170
audio = True
concat = False
if pool_process <= 1:
for target, M, original_target, name, t_facial_mask in tqdm.tqdm(
zip(targets, Ms, original_frames, names, t_facial_masks)
):
if M is None or target is None:
Image.fromarray(original_target.astype(np.uint8)).save(name)
continue
Image.fromarray(paste_back(np.array(target), M, original_target, t_facial_mask)).save(name)
else:
with Pool(pool_process) as pool:
pool.map(save, zip(targets, Ms, original_frames, names, t_facial_masks))
video_save_path = os.path.join(out_dir, out_name)
if audio:
print("use audio")
os.system(
f"ffmpeg -y -r {fps} -i {out_dir}/frame_%05d.png -i {target_path}"
f" -map 0:v:0 -map 1:a:0? -c:a copy -c:v libx264 -r {fps} -crf 10 -pix_fmt yuv420p {video_save_path}"
)
else:
print("no audio")
os.system(
f"ffmpeg -y -r {fps} -i ./tmp/frame_%05d.png "
f"-c:v libx264 -r {fps} -crf 10 -pix_fmt yuv420p {video_save_path}"
)
# ffmpeg -i left.mp4 -i right.mp4 -filter_complex hstack output.mp4
if concat:
concat_video_save_path = os.path.join(out_dir, "concat_" + out_name)
os.system(
f"ffmpeg -y -i {target_path} -i {video_save_path} -filter_complex hstack {concat_video_save_path}"
)
# delete tmp file
shutil.rmtree("./tmp/")
for match in glob.glob(os.path.join(out_dir, "*.png")):
os.remove(match)
print(video_save_path)
return video_save_path
if __name__ == "__main__":
with gr.Blocks() as demo:
gr.Markdown("SuperSwap")
with gr.Tab("Image"):
with gr.Row(equal_height=True):
with gr.Column(scale=3):
image1_input = gr.Image()
image2_input = gr.Image()
use_post = gr.Checkbox(label="后处理")
use_gpen = gr.Checkbox(label="超分增强")
with gr.Column(scale=2):
image_output = gr.Image()
image_button = gr.Button("换脸")
with gr.Tab("Video"):
with gr.Row(equal_height=True):
with gr.Column(scale=3):
image3_input = gr.Image()
video_input = gr.Video()
with gr.Column(scale=2):
video_output = gr.Video()
video_button = gr.Button("换脸")
image_button.click(
swap_image_gr,
inputs=[image1_input, image2_input, use_post, use_gpen],
outputs=image_output,
)
video_button.click(
swap_video_gr,
inputs=[image3_input, video_input],
outputs=video_output,
)
demo.launch(server_name="0.0.0.0", server_port=7860)