| import os, sys |
| import gradio as gr |
| import torch |
| import shutil |
| from src.gradio_demo import SadTalker |
|
|
def sadtalker_demo():
    """Build and return the SadTalker Gradio interface.

    Initializes the SadTalker model lazily (weights load on first use) and
    constructs a gr.Blocks UI with an image/audio upload panel, generation
    settings, and a video output panel.

    Returns:
        gr.Blocks: the assembled interface, ready for .queue()/.launch().
    """
    checkpoint_path = 'checkpoints'
    config_path = 'src/config'

    # Initialization can fail when model files are missing; keep the UI
    # usable and surface the error per-request instead of crashing at import.
    try:
        sad_talker = SadTalker(checkpoint_path, config_path, lazy_load=True)
    except Exception as e:
        print(f"Warning: Could not initialize SadTalker: {e}")
        sad_talker = None

    def generate_video(source_image, driven_audio, preprocess_type, is_still_mode, enhancer, batch_size, size_of_image, pose_style):
        """Run SadTalker inference; returns a video path or an error string."""
        if sad_talker is None:
            return "Error: SadTalker not initialized. Please ensure all model files are uploaded."

        try:
            return sad_talker.test(
                source_image=source_image,
                driven_audio=driven_audio,
                preprocess=preprocess_type,
                still_mode=is_still_mode,
                use_enhancer=enhancer,
                batch_size=batch_size,
                size=size_of_image,
                pose_style=pose_style
            )
        except Exception as e:
            return f"Error generating video: {str(e)}"

    # BUGFIX: the UI must live inside a gr.Blocks context bound to
    # `sadtalker_interface`; previously that name was never defined and the
    # final `return` raised NameError.
    with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
        with gr.Row().style(equal_height=False):
            # Left column: source image and driving audio inputs.
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="sadtalker_source_image"):
                    with gr.TabItem('Upload image'):
                        with gr.Row():
                            source_image = gr.Image(
                                label="Source image",
                                source="upload",
                                type="filepath",
                                elem_id="img2img_image"
                            ).style(width=512)

                with gr.Tabs(elem_id="sadtalker_driven_audio"):
                    with gr.TabItem('Upload Audio'):
                        with gr.Column(variant='panel'):
                            driven_audio = gr.Audio(
                                label="Input audio",
                                source="upload",
                                type="filepath"
                            )

            # Right column: generation settings and the rendered result.
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="sadtalker_checkbox"):
                    with gr.TabItem('Settings'):
                        gr.Markdown("""
                        Need help? Please visit our [best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md) for more details
                        """)
                        with gr.Column(variant='panel'):
                            pose_style = gr.Slider(
                                minimum=0,
                                maximum=46,
                                step=1,
                                label="Pose style",
                                value=0
                            )
                            size_of_image = gr.Radio(
                                [256, 512],
                                value=256,
                                label='Face model resolution',
                                info="Use 256/512 model?"
                            )
                            preprocess_type = gr.Radio(
                                ['crop', 'resize','full', 'extcrop', 'extfull'],
                                value='crop',
                                label='preprocess',
                                info="How to handle input image?"
                            )
                            is_still_mode = gr.Checkbox(
                                label="Still Mode (fewer head motion, works with preprocess `full`)"
                            )
                            batch_size = gr.Slider(
                                label="Batch size in generation",
                                step=1,
                                minimum=1,  # BUGFIX: gradio defaults minimum to 0, which is an invalid batch size
                                maximum=10,
                                value=2
                            )
                            enhancer = gr.Checkbox(
                                label="GFPGAN as Face enhancer"
                            )
                            submit = gr.Button(
                                'Generate',
                                elem_id="sadtalker_generate",
                                variant='primary'
                            )

                with gr.Tabs(elem_id="sadtalker_generated"):
                    gen_video = gr.Video(
                        label="Generated video",
                        format="mp4"
                    ).style(width=512)

        # Wire the Generate button to inference inside the Blocks context.
        submit.click(
            fn=generate_video,
            inputs=[
                source_image,
                driven_audio,
                preprocess_type,
                is_still_mode,
                enhancer,
                batch_size,
                size_of_image,
                pose_style
            ],
            outputs=[gen_video]
        )

    return sadtalker_interface
|
|
if __name__ == "__main__":
    # Build the interface, enable request queueing (long-running inference),
    # then start the local Gradio server.
    app = sadtalker_demo()
    app.queue()
    app.launch()