sutra-avatar-v2 / base_task_executor.py
import os
import random
import re
import shutil
import time
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
import gradio as gr
from elevenlabs_helper import ElevenLabsHelper
# ---
talk_key = "talk"
# ---
valid_image_exts = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp")


def is_image(file_path):
    return file_path.lower().endswith(valid_image_exts)


def get_formatted_datetime_name() -> str:
    d = datetime.now()
    return d.strftime("d%y%m%d-t%H%M%S")


def get_name_ext(filepath):
    filepath = os.path.abspath(filepath)
    _, name_ext = os.path.split(filepath)
    name, ext = os.path.splitext(name_ext)
    return name, ext


def sanitize_string(string):
    # Keep only alphanumeric characters and truncate so names stay filesystem-friendly.
    sanitized_string = re.sub(r"[^A-Za-z0-9]", "", string)
    max_len = 15
    return sanitized_string[:max_len]


def get_output_video_name(
    input_base_path, input_driving_path, base_motion_expression, input_driving_audio_path, tag=""
):
    # Compose an output name of the form <tag>--b-<base>[--d-<driving>][--a-<audio>].
    if not tag:
        tag = get_formatted_datetime_name()
    base_name, _ = get_name_ext(input_base_path)
    base_name = sanitize_string(base_name)
    driving_name = ""
    if input_driving_path:
        driving_name, _ = get_name_ext(input_driving_path)
        driving_name = sanitize_string(driving_name)
    elif base_motion_expression and is_image(input_base_path):
        driving_name = base_motion_expression
    audio_name = ""
    if input_driving_audio_path:
        audio_name, _ = get_name_ext(input_driving_audio_path)
        audio_name = sanitize_string(audio_name)
    output_video_name = f"{tag}--b-{base_name}"
    if driving_name:
        output_video_name += f"--d-{driving_name}"
    if audio_name:
        output_video_name += f"--a-{audio_name}"
    return output_video_name


def generate_random_integer(num_digits):
    # Seed the global RNG from the current time (ms) and draw an integer with up to num_digits digits.
    current_time = int(time.time() * 1000)
    random.seed(current_time)
    lower_bound = 0
    upper_bound = (10**num_digits) - 1
    return random.randint(lower_bound, upper_bound)


def get_unique_name(maxd=4, delim="-"):
    # Combine the trailing digits of the PID, the nanosecond clock, and a random integer.
    pid = os.getpid()
    pid_str = str(pid)[-maxd:]
    time_ns = time.time_ns()
    time_str = str(time_ns)[-maxd:]
    rint = generate_random_integer(maxd)
    rint_str = str(rint).zfill(maxd)
    return delim.join([pid_str, time_str, rint_str])


def mkdir_p(path: str) -> None:
    if not Path(path).exists():
        Path(path).mkdir(parents=True)
# ---
class BaseTaskExecutor(ABC):
    def __init__(self):
        self.tmp_dir = "/tmp/gradio"

    def execute_task(
        self, input_base_path, base_motion_expression, input_driving_audio_path, driving_text_input, driving_voice_input
    ):
        tag = get_unique_name()
        output_dir = os.path.join(self.tmp_dir, tag)
        mkdir_p(output_dir)

        # Determine whether driving audio comes from an uploaded file or from TTS.
        do_dafile = input_driving_audio_path is not None and os.path.exists(input_driving_audio_path)
        do_datts = driving_text_input and driving_voice_input
        do_talk = do_dafile or do_datts

        if base_motion_expression:
            if talk_key not in base_motion_expression and do_talk:
                gr.Warning(
                    f"Ignoring Driving Audio since an expressive Base Motion is selected: {base_motion_expression}"
                )
                do_dafile = False
                do_datts = False
                do_talk = False
            if talk_key in base_motion_expression and not do_talk:
                gr.Warning("Selected a talking Base Motion but no Driving Audio was provided.")
        else:
            base_motion_expression = ""

        if do_datts:
            if do_dafile:
                gr.Warning(
                    "Ignoring Audio File input since TTS is selected.\n"
                    "Clear the undesired input if this is not intended."
                )
            output_audio_file = os.path.join(output_dir, f"{tag}.mp3")
            ElevenLabsHelper.generate_voice(driving_text_input, driving_voice_input, output_audio_file)
            input_driving_audio_path = output_audio_file

        if not do_talk:
            input_driving_audio_path = ""

        if input_base_path is not None and os.path.exists(input_base_path):
            input_driving_path = ""
            request_id = get_unique_name(maxd=8, delim="")
            output_video_path = os.path.join(
                self.tmp_dir,
                get_output_video_name(
                    input_base_path, input_driving_path, base_motion_expression, input_driving_audio_path
                )
                + ".mp4",
            )
            result, output_video_path = self.generate(
                input_base_path,
                input_driving_path,
                base_motion_expression,
                input_driving_audio_path,
                output_video_path,
                request_id,
            )
            success = result["success"]
            messages = result["messages"]
            self.clean(output_dir)
            if success:
                return output_video_path, gr.update(visible=True), messages
            else:
                gr.Info("Task could not be completed", duration=4)
                return None, gr.update(visible=False), f"ERROR\n\n{messages}"
        else:
            self.clean(output_dir)
            raise gr.Error("No source selected!", duration=6)

    @abstractmethod
    def generate(
        self,
        input_base_path,
        input_driving_path,
        base_motion_expression,
        input_driving_audio_path,
        output_video_path,
        request_id,
    ):
        pass

    def clean(self, output_dir):
        if os.path.isdir(output_dir):
            shutil.rmtree(output_dir)
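# ---
# Illustrative sketch only (not part of the app; the class name below is
# hypothetical): a concrete executor subclasses BaseTaskExecutor and implements
# generate() with the signature called from execute_task, returning a result
# dict containing "success" and "messages" together with the (possibly updated)
# output video path.
class EchoTaskExecutor(BaseTaskExecutor):
    def generate(
        self,
        input_base_path,
        input_driving_path,
        base_motion_expression,
        input_driving_audio_path,
        output_video_path,
        request_id,
    ):
        # A real implementation would synthesize the avatar video here; this
        # placeholder just copies the base input so the return contract is visible.
        shutil.copy(input_base_path, output_video_path)
        return {"success": True, "messages": f"Request {request_id} completed."}, output_video_path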