import os
import platform
import shutil
import uuid

import spaces  # Hugging Face Spaces helper (ZeroGPU integration on Spaces)
import torch
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import snapshot_download
from pydub import AudioSegment
from transformers import pipeline

from examples.get_examples import get_examples
from src.facerender.animate import AnimateFromCoeff
from src.facerender.pirender_animate import AnimateFromCoeff_PIRender
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
from src.test_audio2coeff import Audio2Coeff
from src.utils.init_path import init_path
from src.utils.preprocess import CropAndExtract

checkpoint_path = 'checkpoints'
config_path = 'src/config'
device = "cuda" if torch.cuda.is_available() else "mps" if platform.system() == 'Darwin' else "cpu"

os.environ['TORCH_HOME'] = checkpoint_path
snapshot_download(repo_id='vinthony/SadTalker-V002rc',
                  local_dir=checkpoint_path, local_dir_use_symlinks=True)

app = FastAPI()
os.makedirs("results", exist_ok=True)  # StaticFiles requires the directory to exist at mount time
app.mount("/results", StaticFiles(directory="results"), name="results")
templates = Jinja2Templates(directory="templates")

def mp3_to_wav(mp3_filename, wav_filename, frame_rate):
    # Decode the MP3 and resample it to the requested frame rate before export.
    AudioSegment.from_file(file=mp3_filename).set_frame_rate(
        frame_rate).export(wav_filename, format="wav")
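
# Usage sketch (hypothetical filenames; 16 kHz matches the rate used below):
#   mp3_to_wav("speech.mp3", "speech.wav", 16000)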

def get_pose_style_from_audio(audio_path):
    # NOTE: the default "sentiment-analysis" pipeline is a *text* model; passing
    # a file path here classifies the path string itself, not the audio content.
    # A transcription-based variant is sketched below.
    emotion_recognizer = pipeline("sentiment-analysis")
    results = emotion_recognizer(audio_path)
    emotion = results[0]["label"]
    pose_style_mapping = {
        "POSITIVE": 15,
        "NEGATIVE": 35,
        "NEUTRAL": 0,
    }
    return pose_style_mapping.get(emotion, 0)
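
# A minimal sketch of a text-grounded variant, not part of the original app:
# transcribe the audio with an ASR pipeline, then score the transcript. The
# default pipeline models and the helper name are illustrative assumptions.
def get_pose_style_from_transcript(audio_path):
    transcriber = pipeline("automatic-speech-recognition")
    transcript = transcriber(audio_path)["text"]  # ASR pipelines return {"text": ...}
    sentiment = pipeline("sentiment-analysis")(transcript)[0]["label"]
    return {"POSITIVE": 15, "NEGATIVE": 35, "NEUTRAL": 0}.get(sentiment, 0)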

def generate_video(source_image: str, driven_audio: str, preprocess: str = 'crop', still_mode: bool = False,
                   use_enhancer: bool = False, batch_size: int = 1, size: int = 256,
                   facerender: str = 'facevid2vid', exp_scale: float = 1.0, use_ref_video: bool = False,
                   ref_video: str = None, ref_info: str = None, use_idle_mode: bool = False,
                   length_of_audio: int = 0, use_blink: bool = True, result_dir: str = './results/') -> str:
    sadtalker_paths = init_path(
        checkpoint_path, config_path, size, False, preprocess)
    audio_to_coeff = Audio2Coeff(sadtalker_paths, device)
    preprocess_model = CropAndExtract(sadtalker_paths, device)
    # facevid2vid is not supported on MPS, so fall back to PIRender there.
    animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device) if facerender == 'facevid2vid' and device != 'mps' \
        else AnimateFromCoeff_PIRender(sadtalker_paths, device)

    time_tag = str(uuid.uuid4())
    save_dir = os.path.join(result_dir, time_tag)
    os.makedirs(save_dir, exist_ok=True)

    input_dir = os.path.join(save_dir, 'input')
    os.makedirs(input_dir, exist_ok=True)

    pic_path = os.path.join(input_dir, os.path.basename(source_image))
    shutil.move(source_image, input_dir)

    if driven_audio and os.path.isfile(driven_audio):
        audio_path = os.path.join(input_dir, os.path.basename(driven_audio))
        if audio_path.endswith('.mp3'):
            mp3_to_wav(driven_audio, audio_path.replace('.mp3', '.wav'), 16000)
            audio_path = audio_path.replace('.mp3', '.wav')
        else:
            shutil.move(driven_audio, input_dir)
    elif use_idle_mode:
        # Generate a silent WAV of the requested length for idle animation.
        audio_path = os.path.join(
            input_dir, 'idlemode_' + str(length_of_audio) + '.wav')
        AudioSegment.silent(
            duration=1000 * length_of_audio).export(audio_path, format="wav")
    else:
        # Without audio or idle mode, the reference video must supply everything.
        assert use_ref_video and ref_info == 'all'

    if use_ref_video and ref_info == 'all':
        # Extract the reference video's audio track and 3DMM coefficients.
        ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0]
        audio_path = os.path.join(save_dir, ref_video_videoname + '.wav')
        os.system(
            f"ffmpeg -y -hide_banner -loglevel error -i {ref_video} {audio_path}")
        ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname)
        os.makedirs(ref_video_frame_dir, exist_ok=True)
        ref_video_coeff_path, _, _ = preprocess_model.generate(
            ref_video, ref_video_frame_dir, preprocess, source_image_flag=False)
    else:
        ref_video_coeff_path = None

    first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
    os.makedirs(first_frame_dir, exist_ok=True)
    first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(
        pic_path, first_frame_dir, preprocess, True, size)
    if first_coeff_path is None:
        raise AttributeError("No face is detected")

    ref_pose_coeff_path, ref_eyeblink_coeff_path = None, None
    if use_ref_video:
        if ref_info == 'pose':
            ref_pose_coeff_path = ref_video_coeff_path
        elif ref_info == 'blink':
            ref_eyeblink_coeff_path = ref_video_coeff_path
        elif ref_info == 'pose+blink':
            ref_pose_coeff_path = ref_eyeblink_coeff_path = ref_video_coeff_path
        else:
            ref_pose_coeff_path = ref_eyeblink_coeff_path = None

    if use_ref_video and ref_info == 'all':
        coeff_path = ref_video_coeff_path
    else:
        batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path,
                         still=still_mode, idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink)
        pose_style = get_pose_style_from_audio(audio_path)
        coeff_path = audio_to_coeff.generate(
            batch, save_dir, pose_style, ref_pose_coeff_path)

    data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode,
                               preprocess=preprocess, size=size, expression_scale=exp_scale, facemodel=facerender)
    return_path = animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None,
                                              preprocess=preprocess, img_size=size)
    video_name = data['video_name']
    print(f'The generated video is named {video_name} in {save_dir}')
    return return_path
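
# Direct-call sketch (hypothetical paths; note the inputs are *moved* into the
# run directory, so pass copies you can afford to lose):
#   out_path = generate_video("face.png", "speech.wav", preprocess='crop', size=256)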

@app.post("/generate")
async def generate_video_api(source_image: UploadFile = File(...), driven_audio: UploadFile = File(None),
                             preprocess: str = Form('crop'), still_mode: bool = Form(False),
                             use_enhancer: bool = Form(False), batch_size: int = Form(1), size: int = Form(256),
                             facerender: str = Form('facevid2vid'), exp_scale: float = Form(1.0),
                             use_ref_video: bool = Form(False), ref_video: UploadFile = File(None),
                             ref_info: str = Form(None), use_idle_mode: bool = Form(False),
                             length_of_audio: int = Form(0), use_blink: bool = Form(True),
                             result_dir: str = Form('./results/')):
    # Stage the uploads in a temp directory so generate_video can move them.
    os.makedirs("temp", exist_ok=True)
    temp_source_image_path = f"temp/{source_image.filename}"
    with open(temp_source_image_path, "wb") as buffer:
        shutil.copyfileobj(source_image.file, buffer)

    if driven_audio is not None:
        temp_driven_audio_path = f"temp/{driven_audio.filename}"
        with open(temp_driven_audio_path, "wb") as buffer:
            shutil.copyfileobj(driven_audio.file, buffer)
    else:
        temp_driven_audio_path = None

    if ref_video is not None:
        temp_ref_video_path = f"temp/{ref_video.filename}"
        with open(temp_ref_video_path, "wb") as buffer:
            shutil.copyfileobj(ref_video.file, buffer)
    else:
        temp_ref_video_path = None

    video_path = generate_video(
        source_image=temp_source_image_path,
        driven_audio=temp_driven_audio_path,
        preprocess=preprocess,
        still_mode=still_mode,
        use_enhancer=use_enhancer,
        batch_size=batch_size,
        size=size,
        facerender=facerender,
        exp_scale=exp_scale,
        use_ref_video=use_ref_video,
        ref_video=temp_ref_video_path,
        ref_info=ref_info,
        use_idle_mode=use_idle_mode,
        length_of_audio=length_of_audio,
        use_blink=use_blink,
        result_dir=result_dir
    )
    shutil.rmtree("temp")
    return FileResponse(video_path)

@app.get("/")
async def root():
    # `html` is assigned below at module level and resolved at request time.
    return HTMLResponse(content=html)


# Inline HTML template for the form served at "/"
html = """ | |
<!DOCTYPE html> | |
<html lang="en"> | |
<head> | |
<meta charset="UTF-8"> | |
<meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
<title>SadTalker API</title> | |
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.6.2/dist/css/bootstrap.min.css"> | |
<script src="https://cdn.jsdelivr.net/npm/jquery@3.5.1/dist/jquery.slim.min.js"></script> | |
<script src="https://cdn.jsdelivr.net/npm/popper.js@1.16.1/dist/umd/popper.min.js"></script> | |
<script src="https://cdn.jsdelivr.net/npm/bootstrap@4.6.2/dist/js/bootstrap.min.js"></script> | |
</head> | |
<body> | |
<div class="container mt-5"> | |
<h1>SadTalker API</h1> | |
<form method="POST" action="/generate" enctype="multipart/form-data"> | |
<div class="form-group"> | |
<label for="source_image">Source Image:</label> | |
<input type="file" class="form-control-file" id="source_image" name="source_image" required> | |
</div> | |
<div class="form-group"> | |
<label for="driven_audio">Driving Audio:</label> | |
<input type="file" class="form-control-file" id="driven_audio" name="driven_audio"> | |
</div> | |
<div class="form-group"> | |
<label for="preprocess">Preprocess:</label> | |
<select class="form-control" id="preprocess" name="preprocess"> | |
<option value="crop">Crop</option> | |
<option value="resize">Resize</option> | |
<option value="full">Full</option> | |
<option value="extcrop">ExtCrop</option> | |
<option value="extfull">ExtFull</option> | |
</select> | |
</div> | |
<div class="form-check"> | |
<input type="checkbox" class="form-check-input" id="still_mode" name="still_mode"> | |
<label class="form-check-label" for="still_mode">Still Mode</label> | |
</div> | |
<div class="form-check"> | |
<input type="checkbox" class="form-check-input" id="use_enhancer" name="use_enhancer"> | |
<label class="form-check-label" for="use_enhancer">Use GFPGAN Enhancer</label> | |
</div> | |
<div class="form-group"> | |
<label for="batch_size">Batch Size:</label> | |
<input type="number" class="form-control" id="batch_size" name="batch_size" min="1" max="10" value="1"> | |
</div> | |
<div class="form-group"> | |
<label for="size">Face Model Resolution:</label> | |
<select class="form-control" id="size" name="size"> | |
<option value="256">256</option> | |
<option value="512">512</option> | |
</select> | |
</div> | |
<div class="form-group"> | |
<label for="facerender">Face Render:</label> | |
<select class="form-control" id="facerender" name="facerender"> | |
<option value="facevid2vid">FaceVid2Vid</option> | |
<option value="pirender">PIRender</option> | |
</select> | |
</div> | |
<div class="form-group"> | |
<label for="exp_scale">Expression Scale:</label> | |
<input type="number" class="form-control" id="exp_scale" name="exp_scale" min="0" max="3" step="0.1" value="1.0"> | |
</div> | |
<div class="form-check"> | |
<input type="checkbox" class="form-check-input" id="use_ref_video" name="use_ref_video"> | |
<label class="form-check-label" for="use_ref_video">Use Reference Video</label> | |
</div> | |
<div class="form-group"> | |
<label for="ref_video">Reference Video:</label> | |
<input type="file" class="form-control-file" id="ref_video" name="ref_video"> | |
</div> | |
<div class="form-group"> | |
<label for="ref_info">Reference Video Information:</label> | |
<select class="form-control" id="ref_info" name="ref_info"> | |
<option value="pose">Pose</option> | |
<option value="blink">Blink</option> | |
<option value="pose+blink">Pose + Blink</option> | |
<option value="all">All</option> | |
</select> | |
</div> | |
<div class="form-check"> | |
<input type="checkbox" class="form-check-input" id="use_idle_mode" name="use_idle_mode"> | |
<label class="form-check-label" for="use_idle_mode">Use Idle Animation</label> | |
</div> | |
<div class="form-group"> | |
<label for="length_of_audio">Length of Audio (seconds):</label> | |
<input type="number" class="form-control" id="length_of_audio" name="length_of_audio" min="0" value="0"> | |
</div> | |
<div class="form-check"> | |
<input type="checkbox" class="form-check-input" id="use_blink" name="use_blink" checked> | |
<label class="form-check-label" for="use_blink">Use Eye Blink</label> | |
</div> | |
<button type="submit" class="btn btn-primary">Generate</button> | |
</form> | |
</div> | |
</body> | |
</html> | |
""" | |

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
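
# Example client call (a sketch; assumes the server is running locally and that
# `face.png` and `speech.wav` exist in the working directory):
#
#   import requests
#   with open("face.png", "rb") as img, open("speech.wav", "rb") as wav:
#       resp = requests.post(
#           "http://localhost:7860/generate",
#           files={"source_image": img, "driven_audio": wav},
#           data={"preprocess": "crop", "size": "256"},
#       )
#   with open("result.mp4", "wb") as out:
#       out.write(resp.content)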