File size: 4,336 Bytes
0a9bdfb
3c0f460
a874577
 
0a9bdfb
 
3c0f460
0a9bdfb
 
3c0f460
902b0d6
3c0f460
 
 
 
 
 
 
 
699c0d5
 
0a9bdfb
7ecc5a8
 
2e12f20
 
 
 
7ecc5a8
ddc9ba1
 
eadd195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e12f20
eadd195
 
 
 
 
 
 
 
 
 
 
2e12f20
 
0a9bdfb
eadd195
eb4a36b
 
 
 
 
 
 
 
2e12f20
eb4a36b
 
 
 
 
 
7405324
0a9bdfb
 
902b0d6
 
 
1190e23
 
3c0f460
 
d0257e3
699c0d5
902b0d6
3c0f460
 
 
eb4a36b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
import argparse
import os

from musepose_inference import MusePoseInference
from pose_align import PoseAlignmentInference
from downloading_weights import download_models

class App:
    """Gradio web UI for the two-step MusePose pipeline.

    Step 1 aligns the poses of a dance video to a reference image; step 2
    runs MusePose inference on a reference image plus an aligned pose video.
    """

    def __init__(self, args):
        """Optionally download weights and build the inference helpers.

        Args:
            args: Parsed CLI namespace. Reads ``model_dir``, ``output_dir`` and
                ``disable_model_download_at_start`` here, and ``share`` in
                :meth:`launch`.
        """
        self.args = args
        # Fetch the weights first so they are already on disk in case either
        # inference helper inspects ``model_dir`` while being constructed.
        if not args.disable_model_download_at_start:
            download_models(model_dir=args.model_dir)
        self.pose_alignment_infer = PoseAlignmentInference(
            model_dir=args.model_dir,
            output_dir=args.output_dir
        )
        self.musepose_infer = MusePoseInference(
            model_dir=args.model_dir,
            output_dir=args.output_dir
        )

    @staticmethod
    def on_step1_complete(input_img: str, input_pose_vid: str):
        """Carry the step-1 image and aligned pose video over to step 2.

        Args:
            input_img: File path of the reference image used in step 1.
            input_pose_vid: The aligned pose video produced by step 1.

        Returns:
            list: New ``gr.Image`` / ``gr.Video`` instances whose values Gradio
            applies to the step-2 input components.
        """
        return [
            gr.Image(label="Input Image", value=input_img, type="filepath", scale=5),
            gr.Video(label="Input Aligned Pose Video", value=input_pose_vid, scale=5)
        ]

    def musepose_demo(self):
        """Build the Blocks layout and wire up its events.

        The two steps live in separate tabs whose labels match the
        instructions shown in :meth:`header`.

        Returns:
            gr.Blocks: The assembled, not yet launched, demo.
        """
        with gr.Blocks() as demo:
            self.header()

            with gr.Tabs():
                # Step 1: Pose Alignment.
                with gr.TabItem("1: Pose Alignment"):
                    with gr.Row():
                        with gr.Column():
                            img_pose_input = gr.Image(label="Input Image", type="filepath", scale=5)
                            vid_dance_input = gr.Video(label="Input Dance Video", max_length=10, scale=5)
                        with gr.Column():
                            vid_dance_output = gr.Video(label="Aligned Pose Output", scale=5, interactive=False)
                            vid_dance_output_demo = gr.Video(label="Aligned Pose Output Demo", scale=5)
                    btn_align_pose = gr.Button("ALIGN POSE", variant="primary")

                # Step 2: MusePose Inference.
                with gr.TabItem("2: MusePose Inference"):
                    with gr.Row():
                        with gr.Column():
                            img_musepose_input = gr.Image(label="Input Image", type="filepath", scale=5)
                            vid_pose_input = gr.Video(label="Input Aligned Pose Video", max_length=10, scale=5)
                        with gr.Column():
                            vid_output = gr.Video(label="MusePose Output", scale=5)
                            vid_output_demo = gr.Video(label="MusePose Output Demo", scale=5)
                    btn_generate = gr.Button("GENERATE", variant="primary")

            btn_align_pose.click(
                fn=self.pose_alignment_infer.align_pose,
                inputs=[vid_dance_input, img_pose_input],
                outputs=[vid_dance_output, vid_dance_output_demo]
            )

            btn_generate.click(
                fn=self.musepose_infer.infer_musepose,
                inputs=[img_musepose_input, vid_pose_input],
                outputs=[vid_output, vid_output_demo]
            )

            # Whenever step 1 produces a new aligned video, pre-fill step 2's
            # inputs so the user only has to switch tabs and press GENERATE.
            vid_dance_output.change(
                fn=self.on_step1_complete,
                inputs=[img_pose_input, vid_dance_output],
                outputs=[img_musepose_input, vid_pose_input]
            )

        return demo

    @staticmethod
    def header():
        """Render the page title, usage limits and step instructions.

        Returns:
            gr.HTML: The header component (created inside the current Blocks).
        """
        header = gr.HTML(
            """
            <h1 style="font-size: 23px;">
                <a href="https://github.com/jhj0517/MusePose-WebUI" target="_blank">MusePose WebUI</a>
            </h1>
            <p style="font-size: 18px;">
                <strong>Note</strong>: This space only allows video input up to <strong>10 seconds</strong> because ZeroGPU limits the function runtime to 2 minutes.<br>
                If you want longer video inputs, you have to run it locally. Click the link above and follow the README to try it locally.<br><br>
                When you have completed the <strong>1: Pose Alignment</strong> process, go to <strong>2: MusePose Inference</strong> and click the "GENERATE" button.
            </p>
            """
        )
        return header

    def launch(self):
        """Build the demo, enable the request queue and start the server."""
        demo = self.musepose_demo()
        demo.queue().launch(
            share=self.args.share
        )

def str2bool(value):
    """Convert a command-line flag value to ``bool``.

    ``type=bool`` is a trap with argparse: any non-empty string, including
    ``"False"``, is truthy. This converter accepts the usual spellings
    explicitly.

    Args:
        value: The raw option string, or an already-boolean value.

    Returns:
        bool: The parsed flag value.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default="pretrained_weights", help='Pretrained models directory for MusePose')
    parser.add_argument('--output_dir', type=str, default="outputs", help='Output directory for the result')
    # nargs='?' + const=True keeps the bare `--flag` form working, while
    # str2bool makes `--flag False` actually mean False.
    parser.add_argument('--disable_model_download_at_start', type=str2bool, default=False, nargs='?', const=True, help='Disable model download at start or not')
    parser.add_argument('--share', type=str2bool, default=False, nargs='?', const=True, help='Gradio makes sharable link if it is true')
    args = parser.parse_args()

    app = App(args=args)
    app.launch()