File size: 4,383 Bytes
45b0642 b7fc745 45b0642 b7fc745 783a96c 45b0642 25f737f 45b0642 33b60f0 45b0642 33b60f0 0aa927c 45b0642 b7fc745 2ee4092 b7fc745 2ee4092 45b0642 2ee4092 0aa927c 45b0642 0aa927c 45b0642 0aa927c 2ee4092 45b0642 4b50302 45b0642 25f737f 45b0642 2ee4092 45b0642 2ee4092 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
# coding: utf-8
"""
Entry point of the Gradio demo app.
"""
import tyro
import gradio as gr
import spaces
import os.path as osp
from src.utils.helper import load_description
from src.gradio_pipeline import GradioPipeline
from src.config.crop_config import CropConfig
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
def partial_fields(target_class, kwargs):
    """Instantiate ``target_class`` from the subset of ``kwargs`` it accepts.

    Keys the class does not recognise are silently dropped. This lets a broad
    argument namespace (e.g. parsed CLI arguments) be used to build a narrower
    config object.

    A key is accepted when the class exposes an attribute of that name (the
    original rule). For dataclasses, a key is also accepted when it names a
    field that ``__init__`` takes. Fields without a default, or with a
    ``default_factory``, are not class attributes, so a ``hasattr`` check
    alone would drop them.

    Args:
        target_class: the class to instantiate.
        kwargs: mapping of candidate keyword arguments.

    Returns:
        A new ``target_class`` instance built from the accepted keys.
    """
    import dataclasses

    init_field_names = set()
    if isinstance(target_class, type) and dataclasses.is_dataclass(target_class):
        init_field_names = {f.name for f in dataclasses.fields(target_class) if f.init}

    selected = {
        k: v
        for k, v in kwargs.items()
        if k in init_field_names or hasattr(target_class, k)
    }
    return target_class(**selected)
# Style tyro's CLI output, then parse the command line into an ArgumentConfig.
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)

# Derive the inference and crop configs from the parsed CLI attributes;
# partial_fields keeps only the keys each config class recognises.
inference_cfg = partial_fields(InferenceConfig, vars(args))
crop_cfg = partial_fields(CropConfig, vars(args))

# Module-level pipeline instance shared by the request handlers.
gradio_pipeline = GradioPipeline(inference_cfg=inference_cfg, crop_cfg=crop_cfg, args=args)
@spaces.GPU()
def gpu_wrapped_execute_s_video(*call_args, **call_kwargs):
    """Forward a UI request to ``gradio_pipeline.execute_s_video``.

    Decorated with ``spaces.GPU()`` (presumably so Hugging Face Spaces attaches
    a GPU for the duration of the call — confirm against the ``spaces`` docs).
    All positional and keyword arguments are passed through unchanged, and the
    pipeline's return value is returned as-is.
    """
    run_pipeline = gradio_pipeline.execute_s_video
    return run_pipeline(*call_args, **call_kwargs)
# Asset locations (relative paths).
title_md, example_portrait_dir, example_video_dir = (
    "assets/gradio_title.md",   # markdown loaded for the page header
    "assets/examples/source",   # source-portrait examples (not used by the UI in this file)
    "assets/examples/driving",  # driving-video examples
)
#################### interface logic ####################
# Stylesheet passed to gr.Blocks: centre the main column and cap its width.
# The property must be spelled `max-width`; CSS has no `max_width` property,
# so a rule written that way is ignored by the browser.
css = """
#col-container {
    max-width: 1400px;
    margin: 0 auto;
}
"""
# Basenames of the example clips in `example_video_dir`, listed once and reused by
# both example galleries below. NOTE(review): the "Portrait Video Source" gallery
# also draws from the driving-video directory; confirm that is intended.
example_video_names = ("d0", "d18", "d19", "d14", "d6")

with gr.Blocks(css=css) as demo:
    # `col-container` is the element id targeted by the `css` rules.
    with gr.Column(elem_id="col-container"):
        gr.HTML(load_description(title_md))
        gr.Markdown("""
## 🤗 This is the gradio demo for Vid2Vid LivePortrait.
Please upload or use a webcam to get a Source Portrait Video (any aspect ratio) and upload a Driving Video (1:1 aspect ratio, or any aspect ratio with do crop (driving video) checked).
""")
        with gr.Row():
            # Left column: the two input videos, animation options and the run button.
            with gr.Column():
                source_video_input = gr.Video(label="Portrait Video Source")
                gr.Examples(
                    # Build a fresh list per gallery so the two never share one list object.
                    examples=[[osp.join(example_video_dir, f"{name}.mp4")] for name in example_video_names],
                    inputs=[source_video_input],
                    cache_examples=False,
                )
                # Visual divider between the source and driving inputs.
                gr.Markdown("──────────")
                video_input = gr.Video(label="Driving Portrait Video")
                gr.Examples(
                    examples=[[osp.join(example_video_dir, f"{name}.mp4")] for name in example_video_names],
                    inputs=[video_input],
                    cache_examples=False,
                )
                with gr.Accordion(open=False, label="source Animation Instructions and Options"):
                    gr.Markdown(load_description("assets/gradio_description_animation.md"))
                    with gr.Row():
                        flag_relative_input = gr.Checkbox(value=True, label="relative motion")
                        flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
                        flag_remap_input = gr.Checkbox(value=True, label="paste-back")
                        flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving video)")
                with gr.Row():
                    process_button_source_animation = gr.Button("🚀 Animate video", variant="primary")
            # Right column: results.
            with gr.Column():
                output_video1 = gr.Video(label="The animated video in the original image space")
                output_video_concat1 = gr.Video(label="The animated video")

    # Wire the button to the GPU-wrapped pipeline. The order of `inputs` is the
    # positional-argument order forwarded to `execute_s_video`.
    process_button_source_animation.click(
        fn=gpu_wrapped_execute_s_video,
        inputs=[
            source_video_input,
            video_input,
            flag_relative_input,
            flag_do_crop_input,
            flag_remap_input,
            flag_crop_driving_video_input,
        ],
        outputs=[output_video1, output_video_concat1],
        show_progress=True,
        show_api=False,
    )

# Cap the pending-request queue at 10 and hide the API docs link.
demo.queue(max_size=10).launch(show_api=False)
|