File size: 4,467 Bytes
45b0642
 
 
 
 
 
 
 
b7fc745
45b0642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7fc745
783a96c
45b0642
 
 
 
 
 
25f737f
45b0642
33b60f0
 
 
 
 
 
 
45b0642
33b60f0
0aa927c
45b0642
b7fc745
2ee4092
b7fc745
 
2ee4092
45b0642
2ee4092
 
0aa927c
45b0642
 
 
 
 
 
 
 
 
 
 
0aa927c
 
45b0642
 
 
 
 
 
 
 
 
 
 
 
0aa927c
 
 
 
 
 
 
2ee4092
 
45b0642
4b50302
 
45b0642
 
25f737f
45b0642
 
 
 
 
 
 
 
 
 
 
2ee4092
 
45b0642
 
 
2ee4092
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# coding: utf-8

"""
The entrance of the gradio
"""

import dataclasses
import os.path as osp

import gradio as gr
import spaces
import tyro

from src.config.argument_config import ArgumentConfig
from src.config.crop_config import CropConfig
from src.config.inference_config import InferenceConfig
from src.gradio_pipeline import GradioPipeline
from src.utils.helper import load_description


def partial_fields(target_class, kwargs):
    """Construct *target_class* from the subset of *kwargs* it accepts.

    For dataclasses the accepted keys are the declared fields. This also
    covers fields created with ``field(default_factory=...)`` or fields
    without defaults, which have no class-level attribute and would be
    silently dropped by a plain ``hasattr`` check. For non-dataclass
    targets, fall back to class-attribute lookup (the original behavior).

    Args:
        target_class: class to instantiate (typically a config dataclass).
        kwargs: candidate keyword arguments; unknown keys are ignored.

    Returns:
        A new ``target_class`` instance built from the matching keys.
    """
    if dataclasses.is_dataclass(target_class):
        accepted = {f.name for f in dataclasses.fields(target_class)}
    else:
        accepted = {k for k in kwargs if hasattr(target_class, k)}
    return target_class(**{k: v for k, v in kwargs.items() if k in accepted})


# set tyro theme
tyro.extras.set_accent_color("bright_cyan")
# parse the process's command-line arguments into an ArgumentConfig instance
args = tyro.cli(ArgumentConfig)

# specify configs for inference: each config class picks up only the CLI
# attributes it declares (see partial_fields above)
inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # use attributes of args to initialize InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__)  # use attributes of args to initialize CropConfig

# build the pipeline once at import time; it is shared by every Gradio request
gradio_pipeline = GradioPipeline(
    inference_cfg=inference_cfg,
    crop_cfg=crop_cfg,
    args=args
)

@spaces.GPU()
def gpu_wrapped_execute_s_video(*call_args, **call_kwargs):
    """Forward all arguments to the pipeline's source-video execution.

    The ``spaces.GPU`` decorator schedules this call on a ZeroGPU worker;
    the wrapper itself adds no logic.
    """
    return gradio_pipeline.execute_s_video(*call_args, **call_kwargs)

# assets
title_md = "assets/gradio_title.md"  # markdown rendered as the page header
# NOTE(review): example_portrait_dir is never referenced below — the source
# examples are drawn from example_video_dir instead; confirm this is intended.
example_portrait_dir = "assets/examples/source"
example_video_dir = "assets/examples/driving"

#################### interface logic ####################

# Constrain the main column to 1400px and center it horizontally.
# FIX: the original rule used "max_width" (underscore), which is not a valid
# CSS property name, so browsers silently ignored it and the container was
# never width-constrained. The correct property is "max-width".
css = """
#col-container {
    max-width: 1400px;
    margin: 0 auto;
}
"""
# Top-level Gradio UI; `demo` is queued and launched at the bottom of the file.
with gr.Blocks(css=css) as demo:

    with gr.Column(elem_id="col-container"):
        gr.HTML(load_description(title_md))

        gr.Markdown("""
        ## πŸ€— This is the gradio demo for Vid2Vid LivePortrait.
        Please upload or use a webcam to get a Source Portrait Video (any aspect ratio) and upload a Driving Video (1:1 aspect ratio, or any aspect ratio with do crop (driving video) checked).
        """)

    with gr.Row():
        with gr.Column():
        
            # Source portrait video input with clickable example clips.
            # NOTE(review): examples come from example_video_dir (driving
            # clips), not example_portrait_dir — confirm this is intended.
            source_video_input = gr.Video(label="Portrait Video Source")
            gr.Examples(
                examples=[
                    [osp.join(example_video_dir, "d0.mp4")],
                    [osp.join(example_video_dir, "d18.mp4")],
                    [osp.join(example_video_dir, "d19.mp4")],
                    [osp.join(example_video_dir, "d14.mp4")],
                    [osp.join(example_video_dir, "d6.mp4")],
                ],
                inputs=[source_video_input],
                cache_examples=False,
            )
            gr.Markdown("β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”")  # visual divider between the two inputs
            # Driving video input; reuses the same example clips.
            video_input = gr.Video(label="Driving Portrait Video")
            gr.Examples(
                examples=[
                    [osp.join(example_video_dir, "d0.mp4")],
                    [osp.join(example_video_dir, "d18.mp4")],
                    [osp.join(example_video_dir, "d19.mp4")],
                    [osp.join(example_video_dir, "d14.mp4")],
                    [osp.join(example_video_dir, "d6.mp4")],
                ],
                inputs=[video_input],
                cache_examples=False,
            )
            # NOTE(review): the same description markdown is rendered here AND
            # inside the accordion just below — one of the two looks redundant.
            gr.Markdown(load_description("assets/gradio_description_animation.md"))
            with gr.Accordion(open=False, label="source Animation Instructions and Options"):
                gr.Markdown(load_description("assets/gradio_description_animation.md"))
                # Inference toggles; forwarded in order to the click handler below.
                with gr.Row():
                    flag_relative_input = gr.Checkbox(value=True, label="relative motion")
                    flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
                    flag_remap_input = gr.Checkbox(value=True, label="paste-back")
                    flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving video)")
            with gr.Row():
                process_button_source_animation = gr.Button("πŸš€ Animate video", variant="primary")
        with gr.Column():
            output_video1 = gr.Video(label="The animated video in the original image space")
            output_video_concat1 = gr.Video(label="The animated video")

    # binding functions for buttons
   
    # Wire the button to the GPU-wrapped pipeline call. The positional order
    # of `inputs` must match the parameter order expected by
    # gradio_pipeline.execute_s_video (defined in src/gradio_pipeline.py).
    process_button_source_animation.click(
        fn=gpu_wrapped_execute_s_video,
        inputs=[
            source_video_input,
            video_input,
            flag_relative_input,
            flag_do_crop_input,
            flag_remap_input,
            flag_crop_driving_video_input
        ],
        outputs=[output_video1, output_video_concat1],
        show_progress=True,
        show_api=False
    )


# Queue requests (at most 10 waiting) and launch with the REST API hidden.
demo.queue(max_size=10).launch(show_api=False)