# coding: utf-8

"""
Entry point of the Gradio demo for LivePortrait video animation.
"""

import tyro
import gradio as gr
import spaces
import os.path as osp
from src.utils.helper import load_description
from src.gradio_pipeline import GradioPipeline
from src.config.crop_config import CropConfig
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig


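# Keep only the kwargs whose names exist as attributes on target_class, so a
# single flat CLI namespace can populate several config classes.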
def partial_fields(target_class, kwargs):
    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})


# set the tyro CLI accent color
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)

# specify configs for inference
inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # use attributes of args to initialize InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__)  # use attributes of args to initialize CropConfig

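# construct the pipeline that performs the actual animation inference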
gradio_pipeline = GradioPipeline(
    inference_cfg=inference_cfg,
    crop_cfg=crop_cfg,
    args=args
)

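# on Hugging Face Spaces, spaces.GPU allocates a GPU for the duration of each call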
@spaces.GPU()
def gpu_wrapped_execute_s_video(*args, **kwargs):
    return gradio_pipeline.execute_s_video(*args, **kwargs)

# assets
title_md = "assets/gradio_title.md"
example_portrait_dir = "assets/examples/source"
example_video_dir = "assets/examples/driving"

#################### interface logic ####################

# Define the output components up front so the Clear button can reference them;
# they are rendered into the layout below via .render()

output_video1 = gr.Video()
output_video_concat1 = gr.Video()

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    
    with gr.Row():
        # Examples
        gr.Markdown("""
        ## 🤗 This is the Gradio demo of LivePortrait for video.
        Upload a source portrait video, or record one with your webcam (any aspect ratio). Then upload a driving video (1:1 aspect ratio, or any aspect ratio with **do crop (driving video)** checked).
        """)
    # for video portrait
    with gr.Row():
        with gr.Accordion(open=True, label="Video Portrait"):
            source_video_input = gr.Video()
            gr.Examples(
                examples=[
                    [osp.join(example_video_dir, "d0.mp4")],
                    [osp.join(example_video_dir, "d18.mp4")],
                    [osp.join(example_video_dir, "d19.mp4")],
                    [osp.join(example_video_dir, "d14.mp4")],
                    [osp.join(example_video_dir, "d6.mp4")],
                ],
                inputs=[source_video_input],
                cache_examples=False,
            )
        with gr.Accordion(open=True, label="Driving Video"):
            video_input = gr.Video()
            gr.Examples(
                examples=[
                    [osp.join(example_video_dir, "d0.mp4")],
                    [osp.join(example_video_dir, "d18.mp4")],
                    [osp.join(example_video_dir, "d19.mp4")],
                    [osp.join(example_video_dir, "d14.mp4")],
                    [osp.join(example_video_dir, "d6.mp4")],
                ],
                inputs=[video_input],
                cache_examples=False,
            )
    with gr.Row():
        with gr.Accordion(open=False, label="source Animation Instructions and Options"):
            gr.Markdown(load_description("assets/gradio_description_animation.md"))
            with gr.Row():
                flag_relative_input = gr.Checkbox(value=True, label="relative motion")
                flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
                flag_remap_input = gr.Checkbox(value=True, label="paste-back")
                flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving video)")
    with gr.Row():
        with gr.Column():
            process_button_source_animation = gr.Button("🚀 Animate video", variant="primary")
        with gr.Column():
            process_button_reset = gr.ClearButton([source_video_input, video_input, output_video1, output_video_concat1], value="🧹 Clear")
    with gr.Row():
        with gr.Column():
            with gr.Accordion(open=True, label="The animated video in the original image space"):
                output_video1.render()
        with gr.Column():
            with gr.Accordion(open=True, label="The animated video"):
                output_video_concat1.render()

    # bind the animate button to the GPU-wrapped pipeline call
    process_button_source_animation.click(
        fn=gpu_wrapped_execute_s_video,
        inputs=[
            source_video_input,
            video_input,
            flag_relative_input,
            flag_do_crop_input,
            flag_remap_input,
            flag_crop_driving_video_input
        ],
        outputs=[output_video1, output_video_concat1],
        show_progress=True
    )


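# start the Gradio server (binds to http://127.0.0.1:7860 by default)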
demo.launch()