deepkyu's picture
Reinitialize demo with published github repository. With Gradio 4.x
16c8067
raw history blame
No virus
5.61 kB
import os
import subprocess
from pathlib import Path
import gradio as gr
from config import hparams as hp
from config import hparams_gradio as hp_gradio
from nota_wav2lip import Wav2LipModelComparisonGradio
# The inference device comes from the Gradio config (hp_gradio) rather than being probed here.
device = hp_gradio.device
print(f'Using {device} for inference.')

# Sample label -> file path mappings taken from the Gradio config.
video_label_dict = hp_gradio.sample.video
audio_label_dict = hp_gradio.sample.audio

# Optional download URLs for the two model checkpoints and the inference sample archive.
# A variable that is unset OR set to an empty string is normalized to None, so the
# download step skips it instead of attempting a fetch from an empty URL (which would
# leave an empty file at the target path).
LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL') or None
LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL') or None
LRS_INFERENCE_SAMPLE = os.getenv('LRS_INFERENCE_SAMPLE') or None
def _download(url, destination):
    """Fetch ``url`` into ``destination`` using wget.

    The data is written to a sibling ``<name>.part`` file and only moved onto
    ``destination`` when wget exits with status 0. A failed or interrupted download
    therefore never leaves a truncated file that a later ``exists()`` check would
    mistake for a finished one.

    Args:
        url: Source URL (taken from an environment variable).
        destination: Target file path (str or Path).

    Returns:
        True if the file was downloaded, False otherwise (a warning is printed).
    """
    destination = Path(destination)
    partial = destination.with_name(destination.name + '.part')
    try:
        destination.parent.mkdir(parents=True, exist_ok=True)
        # argv list instead of a shell string: the URL comes from the environment and
        # must not be interpreted by a shell.
        # NOTE(review): --no-check-certificate disables TLS verification; kept for
        # parity with the previous behavior, consider removing it.
        returncode = subprocess.run(
            ['wget', '--no-check-certificate', '-O', str(partial), url],
            check=False,
        ).returncode
    except OSError as err:  # e.g. wget is not installed
        print(f'Warning: could not run wget for {destination}: {err}')
        return False
    if returncode != 0:
        partial.unlink(missing_ok=True)
        print(f'Warning: failed to download {destination} (wget exit code {returncode}).')
        return False
    partial.replace(destination)
    return True


path_inference_sample = "sample.tar.gz"

# Download each missing asset whose URL is configured.
for _target, _url in (
    (hp.inference.model.wav2lip.checkpoint, LRS_ORIGINAL_URL),
    (hp.inference.model.nota_wav2lip.checkpoint, LRS_COMPRESSED_URL),
    (path_inference_sample, LRS_INFERENCE_SAMPLE),
):
    if not Path(_target).exists() and _url is not None:
        _download(_url, _target)

# Extract the sample archive whenever it is on disk (it may have been fetched on an
# earlier run); skip tar entirely when there is nothing to extract.
if Path(path_inference_sample).exists():
    subprocess.run(['tar', '-zxvf', path_inference_sample], check=False)
def _read_doc(relative_path):
    """Return the contents of a bundled docs/style file, decoded as UTF-8.

    An explicit encoding keeps the result independent of the platform locale.
    """
    return Path(relative_path).read_text(encoding='utf-8')


def _model_output_panel(title, video_label, num_params):
    """Render one model's result panel inside the current Gradio layout context.

    Args:
        title: Markdown heading shown at the top of the panel.
        video_label: Label of the output video component.
        num_params: Value displayed in the read-out "# Parameters" textbox.

    Returns:
        (output video, total-inference-time textbox, FPS textbox) — the components
        the matching generate button writes to.
    """
    with gr.Column(variant='panel'):
        gr.Markdown(title)
        output_video = gr.Video(label=video_label, interactive=False)
        with gr.Column():
            with gr.Row(equal_height=True):
                inference_time = gr.Textbox(value="", label="Total inference time (sec)")
                fps = gr.Textbox(value="", label="FPS")
            gr.Textbox(value=num_params, label="# Parameters")
    return output_video, inference_time, fps


if __name__ == "__main__":
    servicer = Wav2LipModelComparisonGradio(
        device=device,
        video_label_dict=video_label_dict,
        audio_label_list=audio_label_dict,
        default_video='v1',
        default_audio='a1'
    )

    # Register every configured sample; each video is paired with the .json file
    # that shares its path stem.
    for video_name in sorted(video_label_dict):
        video_stem = Path(video_label_dict[video_name])
        servicer.update_video(video_stem, video_stem.with_suffix('.json'),
                              name=video_name)

    for audio_name in sorted(audio_label_dict):
        audio_path = Path(audio_label_dict[audio_name])
        servicer.update_audio(audio_path, name=audio_name)

    with gr.Blocks(theme='nota-ai/theme', css=_read_doc('docs/main.css')) as demo:
        gr.Markdown(_read_doc('docs/header.md'))
        gr.Markdown(_read_doc('docs/description.md'))

        with gr.Row():
            with gr.Column(variant='panel'):
                gr.Markdown('## Select input video and audio', sanitize_html=False)
                # Previews of the currently selected samples
                sample_video = gr.Video(interactive=False, label="Input Video")
                sample_audio = gr.Audio(interactive=False, label="Input Audio")
                # Sample selectors
                video_selection = gr.components.Radio(video_label_dict,
                                                      type='value', label="Select an input video:")
                audio_selection = gr.components.Radio(audio_label_dict,
                                                      type='value', label="Select an input audio:")
                # Generate buttons, one per model
                with gr.Row(equal_height=True):
                    generate_original_button = gr.Button(value="Generate with Original Model", variant="primary")
                    generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary")

            original_model_output, original_model_inference_time, original_model_fps = _model_output_panel(
                '## Original Wav2Lip', "Original Model", servicer.params['wav2lip'])

            compressed_model_output, compressed_model_inference_time, compressed_model_fps = _model_output_panel(
                '## Compressed Wav2Lip (Ours)', "Compressed Model", servicer.params['nota_wav2lip'])

        # Switch the video and audio previews when a radio button is selected
        video_selection.change(fn=servicer.switch_video_samples, inputs=video_selection, outputs=sample_video)
        audio_selection.change(fn=servicer.switch_audio_samples, inputs=audio_selection, outputs=sample_audio)

        # Generate with the original model
        generate_original_button.click(servicer.generate_original_model,
                                       inputs=[video_selection, audio_selection],
                                       outputs=[original_model_output, original_model_inference_time, original_model_fps])
        # Generate with the compressed model
        generate_compressed_button.click(servicer.generate_compressed_model,
                                         inputs=[video_selection, audio_selection],
                                         outputs=[compressed_model_output, compressed_model_inference_time, compressed_model_fps])

        gr.Markdown(_read_doc('docs/footer.md'))

    demo.queue().launch()