File size: 4,357 Bytes
0a9bdfb
3c0f460
a874577
 
0a9bdfb
 
3c0f460
0a9bdfb
a874577
0a9bdfb
3c0f460
902b0d6
3c0f460
 
 
 
 
 
 
 
699c0d5
 
0a9bdfb
7ecc5a8
 
 
 
 
eb4a36b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
538c8ef
 
 
eb4a36b
 
 
 
 
 
 
 
538c8ef
eb4a36b
538c8ef
 
 
eb4a36b
0a9bdfb
eb4a36b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7405324
0a9bdfb
 
902b0d6
 
 
1190e23
 
 
3c0f460
 
d0257e3
699c0d5
902b0d6
3c0f460
 
 
eb4a36b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import argparse
import os

from musepose_inference import MusePoseInference
from pose_align import PoseAlignmentInference
from downloading_weights import download_models


class App:
    """Gradio front end for the two-step MusePose workflow.

    Builds a tabbed UI with a "Pose Alignment" tab and a "MusePose Inference"
    tab, and owns the two inference helpers created from the parsed
    command-line arguments.

    Attributes read from ``args``: ``model_dir``, ``output_dir``,
    ``disable_model_download_at_start`` and ``share``.
    """

    # Upper bound (in seconds) applied to both video inputs and quoted in the
    # header note, so the enforced limit and its description cannot drift apart.
    MAX_VIDEO_SECONDS = 10

    def __init__(self, args):
        """Create the inference helpers and optionally download the models.

        Args:
            args: Parsed CLI namespace (see the attributes listed on the class).
        """
        self.args = args
        self.pose_alignment_infer = PoseAlignmentInference(
            model_dir=args.model_dir,
            output_dir=args.output_dir
        )
        self.musepose_infer = MusePoseInference(
            model_dir=args.model_dir,
            output_dir=args.output_dir
        )
        # NOTE(review): the download runs after the helpers are constructed;
        # presumably they load weights lazily -- confirm against their modules.
        if not args.disable_model_download_at_start:
            download_models(model_dir=args.model_dir)

    @staticmethod
    def on_step1_complete(input_img: str, input_pose_vid: str):
        """Return fresh input components pre-filled with the step-1 results.

        Args:
            input_img: File path of the reference image.
            input_pose_vid: File path of the aligned pose video.

        Returns:
            A two-element list: an image component and a video component
            carrying the given values.

        NOTE(review): no event handler in this file calls this yet; it looks
        intended to populate the inputs of tab 2 -- confirm when wiring it up.
        """
        return [gr.Image(label="Input Image", value=input_img, type="filepath", scale=5),
                gr.Video(label="Input Aligned Pose Video", value=input_pose_vid, scale=5)]

    def musepose_demo(self):
        """Assemble the Blocks layout and return it (not yet launched).

        Returns:
            The ``gr.Blocks`` instance containing the header and both tabs.
        """
        max_seconds = self.MAX_VIDEO_SECONDS
        with gr.Blocks() as demo:
            # Instantiating the component inside the Blocks context adds it to the layout.
            self.header()
            with gr.Tabs():
                with gr.TabItem('1: Pose Alignment'):
                    with gr.Row():
                        with gr.Column(scale=3):
                            img_pose_input = gr.Image(label="Input Image", type="filepath", scale=5)
                            vid_dance_input = gr.Video(label="Input Dance Video", max_length=max_seconds, scale=5)
                        with gr.Column(scale=3):
                            vid_dance_output = gr.Video(label="Aligned Pose Output", scale=5, interactive=False)
                            vid_dance_output_demo = gr.Video(label="Aligned Pose Output Demo", scale=5)
                        with gr.Column(scale=3):
                            # TODO: the controls for this column are not defined in this
                            # revision. `pass` keeps the module importable; a `with`
                            # block whose body is only a comment is a syntax error.
                            pass

                # TODO: the step-1 button and its event handlers are not defined in this revision.

                with gr.TabItem('2: MusePose Inference'):
                    with gr.Row():
                        with gr.Column(scale=3):
                            img_musepose_input = gr.Image(label="Input Image", type="filepath", scale=5)
                            vid_pose_input = gr.Video(label="Input Aligned Pose Video", max_length=max_seconds, scale=5)
                        with gr.Column(scale=3):
                            vid_output = gr.Video(label="MusePose Output", scale=5)
                            vid_output_demo = gr.Video(label="MusePose Output Demo", scale=5)

                        # TODO: the remaining step-2 settings are not defined in this revision.

            # TODO: the step-2 button and its event handlers are not defined in this revision.

        return demo

    @staticmethod
    def _header_html():
        """Return the header markup, quoting the current video length limit."""
        return f"""
            <h1 style="font-size: 23px;">
                <a href="https://github.com/jhj0517/MusePose-WebUI" target="_blank">MusePose WebUI</a>
            </h1>
            <p style="font-size: 18px;">
                <strong>Note</strong>: This space now allows video input up to <strong>{App.MAX_VIDEO_SECONDS} seconds</strong> because ZeroGPU limits the function runtime to 2 minutes. <br>
                If you want longer video inputs, you have to run it locally. Click the link above and follow the README to try it locally.<br><br>
                When you have completed the <strong>1: Pose Alignment</strong> process, go to <strong>2: MusePose Inference</strong> and click the "GENERATE" button.
            </p>
            """

    @staticmethod
    def header():
        """Create and return the HTML header component shown above the tabs."""
        header = gr.HTML(App._header_html())
        return header

    def launch(self):
        """Build the UI, enable queuing and start the server.

        A public share link is requested when ``args.share`` is true.
        """
        demo = self.musepose_demo()
        demo.queue().launch(
            share=self.args.share
        )


# Accepted spellings for boolean command-line values (compared case-insensitively).
# The empty string stays falsy, matching what ``bool("")`` returned before.
_TRUE_TOKENS = frozenset({"true", "t", "yes", "y", "on", "1"})
_FALSE_TOKENS = frozenset({"false", "f", "no", "n", "off", "0", ""})


def _str_to_bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` applies ``bool()`` to the raw string, which is True for every
    non-empty value -- so ``--share False`` used to enable sharing. This parses
    the text instead.

    Args:
        value: The raw token from the command line (or an actual bool).

    Returns:
        The boolean the token spells.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in _TRUE_TOKENS:
        return True
    if token in _FALSE_TOKENS:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the launcher's command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with ``model_dir``, ``output_dir``,
        ``disable_model_download_at_start`` and ``share``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default=os.path.join("pretrained_weights"), help='Pretrained models directory for MusePose')
    parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Output directory for the result')
    # nargs='?' with const=True keeps the bare `--flag` form working; an explicit
    # value such as `--flag false` is now interpreted rather than always being truthy.
    parser.add_argument('--disable_model_download_at_start', type=_str_to_bool, default=False, nargs='?', const=True, help='Disable model download at start or not')
    parser.add_argument('--share', type=_str_to_bool, default=False, nargs='?', const=True, help='Gradio makes sharable link if it is true')
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()

    app = App(args=args)
    app.launch()