File size: 4,165 Bytes
0a9bdfb
3c0f460
a874577
 
0a9bdfb
 
3c0f460
0a9bdfb
 
3c0f460
902b0d6
3c0f460
 
 
 
 
 
 
 
699c0d5
 
0a9bdfb
7ecc5a8
 
 
 
 
ddc9ba1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a9bdfb
eb4a36b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7405324
0a9bdfb
 
902b0d6
 
 
1190e23
 
3c0f460
 
d0257e3
699c0d5
902b0d6
3c0f460
 
 
eb4a36b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
import argparse
import os

from musepose_inference import MusePoseInference
from pose_align import PoseAlignmentInference
from downloading_weights import download_models

class App:
    """Gradio front end for the MusePose pose-alignment and inference steps.

    An instance keeps the parsed command-line arguments together with one
    ``PoseAlignmentInference`` and one ``MusePoseInference`` object, and knows
    how to build and launch the two-tab Gradio ``Blocks`` interface.
    """

    def __init__(self, args):
        """Create the inference helpers and optionally call ``download_models``.

        Args:
            args: Parsed arguments object. This class reads ``model_dir``,
                ``output_dir`` and ``disable_model_download_at_start`` here,
                and ``share`` in :meth:`launch`.
        """
        self.args = args
        self.pose_alignment_infer = PoseAlignmentInference(
            model_dir=args.model_dir,
            output_dir=args.output_dir
        )
        self.musepose_infer = MusePoseInference(
            model_dir=args.model_dir,
            output_dir=args.output_dir
        )
        if not args.disable_model_download_at_start:
            download_models(model_dir=args.model_dir)

    @staticmethod
    def on_step1_complete(input_img: str, input_pose_vid: str):
        """Return image/video components pre-filled with the given file paths.

        Args:
            input_img: Path used as the value of the "Input Image" component.
            input_pose_vid: Path used as the value of the
                "Input Aligned Pose Video" component.

        Returns:
            A two-element list ``[gr.Image, gr.Video]``.

        NOTE(review): nothing in the visible layout wires an event to this
        handler; presumably it is meant to populate the inputs of tab 2 once
        tab 1 has finished - confirm when the event wiring is restored.
        """
        return [gr.Image(label="Input Image", value=input_img, type="filepath", scale=5),
                gr.Video(label="Input Aligned Pose Video", value=input_pose_vid, scale=5)]

    def musepose_demo(self):
        """Build the two-tab Gradio layout and return the ``gr.Blocks`` object.

        Tab 1 holds the pose-alignment inputs/outputs and tab 2 the MusePose
        inference inputs/outputs. Only the components are created here; no
        event handlers are attached in this method.
        """
        with gr.Blocks() as demo:
            # Creating the HTML component inside the Blocks context is what
            # places it on the page; the returned object is not needed here.
            self.header()
            with gr.Tabs():
                with gr.TabItem('1: Pose Alignment'):
                    with gr.Row():
                        with gr.Column(scale=3):
                            img_pose_input = gr.Image(label="Input Image", type="filepath", scale=5)
                            vid_dance_input = gr.Video(label="Input Dance Video", max_length=10, scale=5)
                        with gr.Column(scale=3):
                            vid_dance_output = gr.Video(label="Aligned Pose Output", scale=5, interactive=False)
                            vid_dance_output_demo = gr.Video(label="Aligned Pose Output Demo", scale=5)
                        # Rest of the column setup remains the same
                        with gr.Column(scale=3):
                            # Column settings remain the same
                            # NOTE(review): the controls for this column are
                            # missing from this file. A ``with`` block whose
                            # body is only a comment is an IndentationError,
                            # so ``pass`` keeps the module importable until
                            # the settings are restored.
                            pass

                with gr.TabItem('2: MusePose Inference'):
                    with gr.Row():
                        with gr.Column(scale=3):
                            img_musepose_input = gr.Image(label="Input Image", type="filepath", scale=5)
                            vid_pose_input = gr.Video(label="Input Aligned Pose Video", max_length=10, scale=5)
                        with gr.Column(scale=3):
                            vid_output = gr.Video(label="MusePose Output", scale=5)
                            vid_output_demo = gr.Video(label="MusePose Output Demo", scale=5)
                        # Rest of the settings remains the same
            return demo

    @staticmethod
    def header():
        """Create and return the ``gr.HTML`` banner shown above the tabs."""
        header = gr.HTML(
            """
            <h1 style="font-size: 23px;">
                <a href="https://github.com/jhj0517/MusePose-WebUI" target="_blank">MusePose WebUI</a>
            </h1>
            <p style="font-size: 18px;">
                <strong>Note</strong>: This space now allows video input up to <strong>10 seconds</strong> because ZeroGPU limits the function runtime to 2 minutes. <br>
                If you want longer video inputs, you have to run it locally. Click the link above and follow the README to try it locally.<br><br>
                When you have completed the <strong>1: Pose Alignment</strong> process, go to <strong>2: MusePose Inference</strong> and click the "GENERATE" button.
            </p>
            """
        )
        return header

    def launch(self):
        """Build the UI, enable queuing and start the Gradio server.

        ``self.args.share`` is forwarded to ``launch`` as the ``share`` flag.
        """
        demo = self.musepose_demo()
        demo.queue().launch(
            share=self.args.share
        )

def str_to_bool(value):
    """Convert a command-line token to ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    because the string is non-empty. This converter reads the text instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``/``"false"``, ``"1"``/``"0"``, ``"yes"``/``"no"``.
            The empty string is treated as ``False``, matching ``bool("")``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_arg_parser():
    """Return the ``ArgumentParser`` describing this script's options."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default="pretrained_weights",
                        help='Pretrained models directory for MusePose')
    parser.add_argument('--output_dir', type=str, default="outputs",
                        help='Output directory for the result')
    # nargs='?' + const=True keeps the bare-flag form (``--share``) working,
    # while an explicit value (``--share False``) is parsed by str_to_bool.
    parser.add_argument('--disable_model_download_at_start', type=str_to_bool, default=False,
                        nargs='?', const=True, help='Disable model download at start or not')
    parser.add_argument('--share', type=str_to_bool, default=False, nargs='?', const=True,
                        help='Gradio makes sharable link if it is true')
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    app = App(args=args)
    app.launch()