File size: 2,398 Bytes
cefcefa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
from argparse import ArgumentParser


def parse_args():
    """Declare the demo's command-line options and parse ``sys.argv``.

    Options (in declaration order):
        --model_path      path or URL of the model file (str)
        --model_type      model architecture name (str)
        --fix_model_path  flag: make the model-path textbox read-only
        --device          device string handed to the scorer (str)
        --share           flag: ask Gradio to share the demo

    Returns:
        The ``argparse.Namespace`` produced by ``ArgumentParser.parse_args``.
    """
    # (flag, keyword arguments for add_argument) -- order here is the order
    # the options appear in --help.
    option_specs = [
        ('--model_path', dict(type=str, default='./model/v3.pth', help='Path or url to the model file')),
        ('--model_type', dict(type=str, default='mlp', help='Type of the model')),
        ('--fix_model_path', dict(action='store_true', help='Fix the model path')),
        ('--device', dict(type=str, default='cuda', help='Device to use')),
        ('--share', dict(action='store_true', help='Share the demo')),
    ]

    cli = ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def ui(args):
    """Build the Gradio Blocks demo that scores an uploaded image.

    Args:
        args: Parsed command-line namespace (see ``parse_args``). This function
            reads ``model_path``, ``model_type``, ``device`` and
            ``fix_model_path`` from it.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` demo.
    """
    # Deferred import: waifu_scorer is only needed once the UI is built.
    from waifu_scorer.predict import WaifuScorer, load_model
    scorer = WaifuScorer(
        model_path=args.model_path,
        model_type=args.model_type,
        device=args.device,
    )

    def score_image(img):
        """Return the scorer's prediction for ``img``, or None when no image is set."""
        if img is None:
            return None
        # NOTE(review): the image is sent as a batch of two copies and only the
        # first result is kept -- presumably a workaround for single-sample
        # batches inside ``WaifuScorer.predict``; confirm before simplifying.
        return scorer.predict([img] * 2)[0]

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image = gr.Image(
                    label='Image',
                    type='pil',
                    height=512,
                    sources=['upload', 'clipboard'],
                )
            with gr.Column():
                with gr.Row():
                    model_path = gr.Textbox(
                        label='Model Path',
                        value=args.model_path,
                        placeholder='Path or URL to the model file',
                        # Read-only when the user passed --fix_model_path.
                        interactive=not args.fix_model_path,
                    )
                with gr.Row():
                    score = gr.Number(
                        label='Score',
                    )

        def change_model(new_model_path, img):
            """Load the model at ``new_model_path`` into the scorer and re-score ``img``.

            Returns the new score (or None when no image is loaded) so the
            displayed value reflects the newly loaded model instead of the old one.
            """
            # ``scorer`` is mutated, never rebound, so no ``nonlocal`` is required.
            scorer.mlp = load_model(new_model_path, model_type=args.model_type, device=args.device)
            print(f"Model changed to `{new_model_path}`")
            return score_image(img)

        model_path.submit(
            fn=change_model,
            inputs=[model_path, image],
            outputs=score,
        )

        image.change(
            fn=score_image,
            inputs=image,
            outputs=score,
        )

    return demo


def launch(args):
    """Build the demo from ``args`` and launch it.

    ``args.share`` is forwarded as Gradio's ``share`` option; the return value
    of ``launch`` is discarded and this function returns None.
    """
    ui(args).launch(share=args.share)