File size: 3,524 Bytes
63074f5
714bf26
 
63074f5
714bf26
b1373ae
75453c0
 
 
 
 
b1373ae
 
63074f5
714bf26
 
 
b1373ae
 
 
63074f5
 
b1373ae
63074f5
 
 
 
b1373ae
63074f5
 
b1373ae
 
 
75453c0
63074f5
 
 
bd2c038
687b293
 
 
 
b1373ae
 
 
63074f5
 
 
 
b1373ae
 
 
75453c0
 
687b293
b1373ae
 
 
 
 
 
72907d1
 
b1373ae
 
 
 
 
 
714bf26
 
b1373ae
 
 
75453c0
 
b1373ae
 
63074f5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from model import Model
import gradio as gr
import os
# True only when SPACE_AUTHOR_NAME is "PAIR" (presumably the hosted Hugging
# Face Space deployment). create_demo uses it to hide the memory-tuning sliders.
on_huggingspace = "PAIR" == os.environ.get("SPACE_AUTHOR_NAME")

# One example row per pose sequence: [motion label, prompt]. Every motion is
# paired with the same prompt.
examples = [
    [f'Motion {motion_idx}', "An astronaut dancing in the outer space"]
    for motion_idx in range(1, 6)
]


def create_demo(model: Model):
    """Build the Gradio Blocks UI for text- and pose-conditioned video generation.

    The layout is a title row plus a row with three columns:

    * a gallery of pose-sequence GIFs, a hidden textbox holding the selected
      motion label, and a Markdown line echoing the current selection;
    * the prompt box, the Run button and an "Advanced options" accordion;
    * the output image.

    Args:
        model: object whose ``process_controlnet_pose`` method is the callback
            for both the Run button and the examples table.

    Returns:
        The constructed ``gr.Blocks`` instance (not launched).
    """
    # Gallery items dance1..dance5, captioned with the same "Motion N" scheme
    # that pose_gallery_callback derives from the clicked item's index.
    pose_gifs = [
        (f'__assets__/poses_skeleton_gifs/dance{i}.gif', f"Motion {i}")
        for i in range(1, 6)
    ]

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Text and Pose Conditional Video Generation')

        with gr.Row():
            gr.Markdown(
                'Selection: **one motion** and a **prompt**, or use the examples below.')
            with gr.Column():
                gallery_pose_sequence = gr.Gallery(
                    label="Pose Sequence", value=pose_gifs).style(grid=[2], height="auto")
                # Hidden holder for the selected motion label. It is the first
                # model input, and the examples table writes into it.
                input_video_path = gr.Textbox(
                    label="Pose Sequence", visible=False, value="Motion 1")
                gr.Markdown("## Selection")
                pose_sequence_selector = gr.Markdown(
                    'Pose Sequence: **Motion 1**')
            with gr.Column():
                prompt = gr.Textbox(label='Prompt')
                # A Gradio Button's caption is its `value`, not its `label`.
                run_button = gr.Button(value='Run')
                with gr.Accordion('Advanced options', open=False):
                    watermark = gr.Radio(["Picsart AI Research", "Text2Video-Zero",
                                         "None"], label="Watermark", value='Picsart AI Research')
                    # The memory-tuning sliders are hidden on the PAIR Space
                    # deployment; their default values are still passed.
                    chunk_size = gr.Slider(
                        label="Chunk size", minimum=2, maximum=16, value=2, step=1, visible=not on_huggingspace,
                        info="Number of frames processed at once. Reduce for lower memory usage.")
                    merging_ratio = gr.Slider(
                        label="Merging ratio", minimum=0.0, maximum=0.9, step=0.1, value=0.0, visible=not on_huggingspace,
                        info="Ratio of how many tokens are merged. The higher the more compression (less memory and faster inference).")
            with gr.Column():
                result = gr.Image(label="Generated Video")

        # Selecting a gallery item writes "Motion N" into the hidden textbox.
        # That change then refreshes the Markdown selection echo.
        input_video_path.change(on_video_path_update,
                                None, pose_sequence_selector)
        gallery_pose_sequence.select(
            pose_gallery_callback, None, input_video_path)

        # Positional order must match model.process_controlnet_pose's parameters.
        inputs = [
            input_video_path,
            prompt,
            chunk_size,
            watermark,
            merging_ratio,
        ]

        # Example rows only fill the first two inputs (motion, prompt). They are
        # not cached and do not auto-run when clicked.
        gr.Examples(examples=examples,
                    inputs=inputs,
                    outputs=result,
                    fn=model.process_controlnet_pose,
                    cache_examples=False,
                    run_on_click=False,
                    )

        run_button.click(fn=model.process_controlnet_pose,
                         inputs=inputs,
                         outputs=result,)

    return demo


def on_video_path_update(evt: gr.EventData):
    """Render the Markdown that echoes the currently selected pose sequence.

    The output uses the same ``Pose Sequence: **<motion>**`` wording as the
    selector's initial text in create_demo. The label therefore stays the
    same after the user picks a different motion. The old ``Selection:``
    prefix also repeated the "## Selection" header shown directly above it.

    Args:
        evt: Gradio event data for the hidden textbox's ``.change`` event.

    Returns:
        Markdown string such as ``'Pose Sequence: **Motion 2**'``.
    """
    # NOTE(review): relies on the private `EventData._data` attribute, which
    # presumably carries the textbox's new value for a `.change` event.
    # Confirm against the installed Gradio version.
    return f'Pose Sequence: **{evt._data}**'


def pose_gallery_callback(evt: gr.SelectData):
    """Translate a gallery click into its "Motion N" label.

    ``evt.index`` is the zero-based position of the clicked gallery item.
    The labels are one-based, so the index is shifted by one.
    """
    motion_number = evt.index + 1
    return "Motion {}".format(motion_number)