# -*- coding: utf-8 -*-
# @Time    : 2024/9/13 0:23
# @Project : FasterLivePortrait
# @FileName: api.py
import shutil
from typing import Optional, Dict, Any
import io
import os
import subprocess
import uvicorn
import cv2
import time
import numpy as np
import datetime
import platform
import pickle
from tqdm import tqdm
from pydantic import BaseModel
from fastapi import APIRouter, Depends, FastAPI, Request, Response, UploadFile
from fastapi import File, Body, Form
from omegaconf import OmegaConf
from fastapi.responses import StreamingResponse
from zipfile import ZipFile

from src.pipelines.faster_live_portrait_pipeline import FasterLivePortraitPipeline
from src.utils.utils import video_has_audio
from src.utils import logger
# model dir
project_dir = os.path.dirname(__file__)
checkpoints_dir = os.environ.get("FLIP_CHECKPOINT_DIR", os.path.join(project_dir, "checkpoints"))
log_dir = os.path.join(project_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
result_dir = os.path.join(project_dir, "results")
os.makedirs(result_dir, exist_ok=True)
logger_f = logger.get_logger("faster_liveportrait_api", log_file=os.path.join(log_dir, "log_run.log"))

app = FastAPI()
# initialized by the startup hook below (a bare `global pipe` at module level is a no-op)
pipe = None

if platform.system().lower() == 'windows':
    FFMPEG = "third_party/ffmpeg-7.0.1-full_build/bin/ffmpeg.exe"
else:
    FFMPEG = "ffmpeg"
def check_all_checkpoints_exist(infer_cfg):
    """
    Check whether every checkpoint referenced by the config exists on disk.
    Also rewrites "./checkpoints" prefixes to the configured checkpoints_dir in place.
    :return: True if every model path (or its .onnx counterpart) exists, else False.
    """

    def _rewrite_and_check(model_cfg):
        paths = model_cfg.model_path
        if isinstance(paths, str):
            model_cfg.model_path = paths.replace("./checkpoints", checkpoints_dir)
            path = model_cfg.model_path
            # model_path points at the TensorRT engine; swapping the 4-char extension
            # gives the ONNX counterpart
            return os.path.exists(path) or os.path.exists(path[:-4] + ".onnx")
        for i in range(len(paths)):
            paths[i] = paths[i].replace("./checkpoints", checkpoints_dir)
            if not (os.path.exists(paths[i]) or os.path.exists(paths[i][:-4] + ".onnx")):
                return False
        return True

    for name in infer_cfg.models:
        if not _rewrite_and_check(infer_cfg.models[name]):
            return False
    for name in infer_cfg.animal_models:
        if not _rewrite_and_check(infer_cfg.animal_models[name]):
            return False

    # XPose model and CLIP embedding caches are required for animal mode
    for rel_path in ("liveportrait_animal_onnx/xpose.pth",
                     "liveportrait_animal_onnx/clip_embedding_9.pkl",
                     "liveportrait_animal_onnx/clip_embedding_68.pkl"):
        if not os.path.exists(os.path.join(checkpoints_dir, rel_path)):
            return False
    return True
def convert_onnx_to_trt_models(infer_cfg):
    """Convert every missing TensorRT engine from its ONNX counterpart."""

    def _convert(trt_path):
        if os.path.exists(trt_path):
            return True
        onnx_path = trt_path[:-4] + ".onnx"
        convert_cmd = f"python scripts/onnx2trt.py -o {onnx_path}"
        logger_f.info(f"convert onnx model: {onnx_path}")
        # no check=True here: a non-zero exit is handled explicitly below instead of
        # raising CalledProcessError (which would make the else branch unreachable)
        result = subprocess.run(convert_cmd, shell=True)
        # check the result
        if result.returncode == 0:
            logger_f.info(f"convert onnx model: {onnx_path} successful")
            return True
        logger_f.error(f"convert onnx model: {onnx_path} failed")
        return False

    for models in (infer_cfg.models, infer_cfg.animal_models):
        for name in models:
            paths = models[name].model_path
            if isinstance(paths, str):
                paths = [paths]
            for trt_path in paths:
                if not _convert(trt_path):
                    return False
    return True
# register as a FastAPI startup hook so the pipeline is initialized before requests arrive
@app.on_event("startup")
async def startup_event():
    global pipe
    # default: use the trt models
    cfg_file = os.path.join(project_dir, "configs/trt_infer.yaml")
    infer_cfg = OmegaConf.load(cfg_file)
    checkpoints_exist = check_all_checkpoints_exist(infer_cfg)
    # first: download the models if they do not exist
    if not checkpoints_exist:
        download_cmd = f"huggingface-cli download warmshao/FasterLivePortrait --local-dir {checkpoints_dir}"
        logger_f.info(f"download model: {download_cmd}")
        result = subprocess.run(download_cmd, shell=True)
        # check the result
        if result.returncode == 0:
            logger_f.info(f"Download checkpoints to {checkpoints_dir} successful")
        else:
            logger_f.error(f"Download checkpoints to {checkpoints_dir} failed")
            exit(1)
    # second: convert the onnx models to trt
    convert_ret = convert_onnx_to_trt_models(infer_cfg)
    if not convert_ret:
        logger_f.error("convert onnx model to trt failed")
        exit(1)
    infer_cfg.infer_params.flag_pasteback = True
    pipe = FasterLivePortraitPipeline(cfg=infer_cfg, is_animal=True)
def run_with_video(source_image_path, driving_video_path, save_dir):
    global pipe
    ret = pipe.prepare_source(source_image_path, realtime=False)
    if not ret:
        logger_f.warning(f"no face in {source_image_path}! exit!")
        return
    vcap = cv2.VideoCapture(driving_video_path)
    fps = int(vcap.get(cv2.CAP_PROP_FPS))
    h, w = pipe.src_imgs[0].shape[:2]

    # render output videos
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    vsave_crop_path = os.path.join(save_dir,
                                   f"{os.path.basename(source_image_path)}-{os.path.basename(driving_video_path)}-crop.mp4")
    vout_crop = cv2.VideoWriter(vsave_crop_path, fourcc, fps, (512 * 2, 512))
    vsave_org_path = os.path.join(save_dir,
                                  f"{os.path.basename(source_image_path)}-{os.path.basename(driving_video_path)}-org.mp4")
    vout_org = cv2.VideoWriter(vsave_org_path, fourcc, fps, (w, h))

    infer_times = []
    motion_lst = []
    c_eyes_lst = []
    c_lip_lst = []

    frame_ind = 0
    while vcap.isOpened():
        ret, frame = vcap.read()
        if not ret:
            break
        t0 = time.time()
        first_frame = frame_ind == 0
        dri_crop, out_crop, out_org, dri_motion_info = pipe.run(frame, pipe.src_imgs[0], pipe.src_infos[0],
                                                                first_frame=first_frame)
        frame_ind += 1
        if out_crop is None:
            logger_f.warning(f"no face in driving frame:{frame_ind}")
            continue
        motion_lst.append(dri_motion_info[0])
        c_eyes_lst.append(dri_motion_info[1])
        c_lip_lst.append(dri_motion_info[2])

        infer_times.append(time.time() - t0)
        dri_crop = cv2.resize(dri_crop, (512, 512))
        out_crop = np.concatenate([dri_crop, out_crop], axis=1)
        out_crop = cv2.cvtColor(out_crop, cv2.COLOR_RGB2BGR)
        vout_crop.write(out_crop)
        out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR)
        vout_org.write(out_org)
    vcap.release()
    vout_crop.release()
    vout_org.release()

    if video_has_audio(driving_video_path):
        vsave_crop_path_new = os.path.splitext(vsave_crop_path)[0] + "-audio.mp4"
        subprocess.call(
            [FFMPEG, "-i", vsave_crop_path, "-i", driving_video_path,
             "-b:v", "10M", "-c:v",
             "libx264", "-map", "0:v", "-map", "1:a",
             "-c:a", "aac",
             "-pix_fmt", "yuv420p", vsave_crop_path_new, "-y", "-shortest"])
        vsave_org_path_new = os.path.splitext(vsave_org_path)[0] + "-audio.mp4"
        subprocess.call(
            [FFMPEG, "-i", vsave_org_path, "-i", driving_video_path,
             "-b:v", "10M", "-c:v",
             "libx264", "-map", "0:v", "-map", "1:a",
             "-c:a", "aac",
             "-pix_fmt", "yuv420p", vsave_org_path_new, "-y", "-shortest"])
        logger_f.info(vsave_crop_path_new)
        logger_f.info(vsave_org_path_new)
    else:
        logger_f.info(vsave_crop_path)
        logger_f.info(vsave_org_path)

    logger_f.info(
        "inference median time: {} ms/frame, mean time: {} ms/frame".format(np.median(infer_times) * 1000,
                                                                            np.mean(infer_times) * 1000))
    # save the driving motion template to a pkl
    template_dct = {
        'n_frames': len(motion_lst),
        'output_fps': fps,
        'motion': motion_lst,
        'c_eyes_lst': c_eyes_lst,
        'c_lip_lst': c_lip_lst,
    }
    template_pkl_path = os.path.join(save_dir,
                                     f"{os.path.basename(driving_video_path)}.pkl")
    with open(template_pkl_path, "wb") as fw:
        pickle.dump(template_dct, fw)
    logger_f.info(f"save driving motion pkl file at : {template_pkl_path}")
def run_with_pkl(source_image_path, driving_pickle_path, save_dir):
    global pipe
    ret = pipe.prepare_source(source_image_path, realtime=False)
    if not ret:
        logger_f.warning(f"no face in {source_image_path}! exit!")
        return
    with open(driving_pickle_path, "rb") as fin:
        dri_motion_infos = pickle.load(fin)

    fps = int(dri_motion_infos["output_fps"])
    h, w = pipe.src_imgs[0].shape[:2]

    # render output videos
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    vsave_crop_path = os.path.join(save_dir,
                                   f"{os.path.basename(source_image_path)}-{os.path.basename(driving_pickle_path)}-crop.mp4")
    vout_crop = cv2.VideoWriter(vsave_crop_path, fourcc, fps, (512, 512))
    vsave_org_path = os.path.join(save_dir,
                                  f"{os.path.basename(source_image_path)}-{os.path.basename(driving_pickle_path)}-org.mp4")
    vout_org = cv2.VideoWriter(vsave_org_path, fourcc, fps, (w, h))

    infer_times = []
    motion_lst = dri_motion_infos["motion"]
    # fall back to the c_d_* key names used by older motion templates
    c_eyes_lst = dri_motion_infos["c_eyes_lst"] if "c_eyes_lst" in dri_motion_infos else \
        dri_motion_infos["c_d_eyes_lst"]
    c_lip_lst = dri_motion_infos["c_lip_lst"] if "c_lip_lst" in dri_motion_infos else \
        dri_motion_infos["c_d_lip_lst"]

    frame_num = len(motion_lst)
    for frame_ind in tqdm(range(frame_num)):
        t0 = time.time()
        first_frame = frame_ind == 0
        dri_motion_info_ = [motion_lst[frame_ind], c_eyes_lst[frame_ind], c_lip_lst[frame_ind]]
        out_crop, out_org = pipe.run_with_pkl(dri_motion_info_, pipe.src_imgs[0], pipe.src_infos[0],
                                              first_frame=first_frame)
        if out_crop is None:
            logger_f.warning(f"no face in driving frame:{frame_ind}")
            continue

        infer_times.append(time.time() - t0)
        out_crop = cv2.cvtColor(out_crop, cv2.COLOR_RGB2BGR)
        vout_crop.write(out_crop)
        out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR)
        vout_org.write(out_org)
    vout_crop.release()
    vout_org.release()
    logger_f.info(vsave_crop_path)
    logger_f.info(vsave_org_path)
    logger_f.info(
        "inference median time: {} ms/frame, mean time: {} ms/frame".format(np.median(infer_times) * 1000,
                                                                            np.mean(infer_times) * 1000))
class LivePortraitParams(BaseModel):
    flag_pickle: bool = False
    flag_relative_input: bool = True
    flag_do_crop_input: bool = True
    flag_remap_input: bool = True
    driving_multiplier: float = 1.0
    flag_stitching: bool = True
    flag_crop_driving_video_input: bool = True
    flag_video_editing_head_rotation: bool = False
    flag_is_animal: bool = True
    scale: float = 2.3
    vx_ratio: float = 0.0
    vy_ratio: float = -0.125
    scale_crop_driving_video: float = 2.2
    vx_ratio_crop_driving_video: float = 0.0
    vy_ratio_crop_driving_video: float = -0.1
    driving_smooth_observation_variance: float = 1e-7
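
# For a quick look at the full form payload expected by the endpoint below, the defaults
# can be dumped as a dict (a usage sketch; `.dict()` is pydantic v1, `.model_dump()` is
# the v2 equivalent):
#   LivePortraitParams().dict()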
# NOTE: the route path below is an assumption of this sketch; without a route decorator
# the handler would never be registered with FastAPI.
@app.post("/predict/")
async def upload_files(
        source_image: Optional[UploadFile] = File(None),
        driving_video: Optional[UploadFile] = File(None),
        driving_pickle: Optional[UploadFile] = File(None),
        flag_is_animal: bool = Form(...),
        flag_pickle: bool = Form(...),
        flag_relative_input: bool = Form(...),
        flag_do_crop_input: bool = Form(...),
        flag_remap_input: bool = Form(...),
        driving_multiplier: float = Form(...),
        flag_stitching: bool = Form(...),
        flag_crop_driving_video_input: bool = Form(...),
        flag_video_editing_head_rotation: bool = Form(...),
        scale: float = Form(...),
        vx_ratio: float = Form(...),
        vy_ratio: float = Form(...),
        scale_crop_driving_video: float = Form(...),
        vx_ratio_crop_driving_video: float = Form(...),
        vy_ratio_crop_driving_video: float = Form(...),
        driving_smooth_observation_variance: float = Form(...)
):
    # build infer_params from the submitted form fields
    infer_params = LivePortraitParams(
        flag_is_animal=flag_is_animal,
        flag_pickle=flag_pickle,
        flag_relative_input=flag_relative_input,
        flag_do_crop_input=flag_do_crop_input,
        flag_remap_input=flag_remap_input,
        driving_multiplier=driving_multiplier,
        flag_stitching=flag_stitching,
        flag_crop_driving_video_input=flag_crop_driving_video_input,
        flag_video_editing_head_rotation=flag_video_editing_head_rotation,
        scale=scale,
        vx_ratio=vx_ratio,
        vy_ratio=vy_ratio,
        scale_crop_driving_video=scale_crop_driving_video,
        vx_ratio_crop_driving_video=vx_ratio_crop_driving_video,
        vy_ratio_crop_driving_video=vy_ratio_crop_driving_video,
        driving_smooth_observation_variance=driving_smooth_observation_variance
    )

    global pipe
    pipe.init_vars()
    if infer_params.flag_is_animal != pipe.is_animal:
        pipe.init_models(is_animal=infer_params.flag_is_animal)

    args_user = {
        'flag_relative_motion': infer_params.flag_relative_input,
        'flag_do_crop': infer_params.flag_do_crop_input,
        'flag_pasteback': infer_params.flag_remap_input,
        'driving_multiplier': infer_params.driving_multiplier,
        'flag_stitching': infer_params.flag_stitching,
        'flag_crop_driving_video': infer_params.flag_crop_driving_video_input,
        'flag_video_editing_head_rotation': infer_params.flag_video_editing_head_rotation,
        'src_scale': infer_params.scale,
        'src_vx_ratio': infer_params.vx_ratio,
        'src_vy_ratio': infer_params.vy_ratio,
        'dri_scale': infer_params.scale_crop_driving_video,
        'dri_vx_ratio': infer_params.vx_ratio_crop_driving_video,
        'dri_vy_ratio': infer_params.vy_ratio_crop_driving_video,
    }
    # update config from user input
    pipe.update_cfg(args_user)

    # save the uploaded files into a temp directory
    temp_dir = os.path.join(result_dir, f"temp-{datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')}")
    os.makedirs(temp_dir, exist_ok=True)
    if source_image and source_image.filename:
        source_image_path = os.path.join(temp_dir, source_image.filename)
        with open(source_image_path, "wb") as buffer:
            buffer.write(await source_image.read())  # write the uploaded bytes to disk
    else:
        source_image_path = None

    if driving_video and driving_video.filename:
        driving_video_path = os.path.join(temp_dir, driving_video.filename)
        with open(driving_video_path, "wb") as buffer:
            buffer.write(await driving_video.read())  # write the uploaded bytes to disk
    else:
        driving_video_path = None

    if driving_pickle and driving_pickle.filename:
        driving_pickle_path = os.path.join(temp_dir, driving_pickle.filename)
        with open(driving_pickle_path, "wb") as buffer:
            buffer.write(await driving_pickle.read())  # write the uploaded bytes to disk
    else:
        driving_pickle_path = None

    save_dir = os.path.join(result_dir, f"{datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')}")
    os.makedirs(save_dir, exist_ok=True)

    if infer_params.flag_pickle:
        if source_image_path and driving_pickle_path:
            run_with_pkl(source_image_path, driving_pickle_path, save_dir)
    else:
        if source_image_path and driving_video_path:
            run_with_video(source_image_path, driving_video_path, save_dir)

    # zip all result files and return them
    # create an in-memory byte stream with BytesIO
    zip_buffer = io.BytesIO()
    # compress the contents of save_dir into zip_buffer with ZipFile
    with ZipFile(zip_buffer, "w") as zip_file:
        for root, dirs, files in os.walk(save_dir):
            for file in files:
                file_path = os.path.join(root, file)
                # add the file to the zip archive
                zip_file.write(file_path, arcname=os.path.relpath(file_path, save_dir))
    # rewind the buffer so the full content can be read
    zip_buffer.seek(0)
    shutil.rmtree(temp_dir)
    shutil.rmtree(save_dir)
    # return the zip file via StreamingResponse
    return StreamingResponse(zip_buffer, media_type="application/zip",
                             headers={"Content-Disposition": "attachment; filename=output.zip"})
if __name__ == "__main__":
    # FLIP_PORT arrives from the environment as a string, so cast it to int for uvicorn
    uvicorn.run(app, host=os.environ.get("FLIP_IP", "127.0.0.1"), port=int(os.environ.get("FLIP_PORT", 9871)))