File size: 6,418 Bytes
4fb3c5e
 
 
 
 
 
 
 
 
 
 
70d5056
4fb3c5e
 
 
61d3740
 
 
6383bc4
 
 
 
 
 
 
 
 
61d3740
 
 
 
af2a8f5
7326df9
 
 
 
 
 
 
 
 
 
 
 
61d3740
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26f3e39
 
61d3740
 
 
 
 
 
 
fce8c40
c63e736
fce8c40
c63e736
 
 
fce8c40
 
 
4fb3c5e
af2a8f5
4fb3c5e
 
 
 
 
 
 
 
26f3e39
4fb3c5e
 
26f3e39
 
c63e736
26f3e39
af2a8f5
26f3e39
61d3740
4fb3c5e
61d3740
 
7326df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61d3740
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fb3c5e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#!/usr/bin/env python

from __future__ import annotations

import argparse

import gradio as gr

from model import Model

# Page heading rendered as Markdown at the top of the app.
TITLE = ('# Anime Face Generation with '
         '[Diffusers](https://github.com/huggingface/diffusers)')
# Shown at the top of the "Advanced Mode" tab.
DESCRIPTION = ('Expected execution time on Hugging Face Spaces: '
               '5s (DDIM, 20 steps), '
               '6s (PNDM, 20 steps), '
               '247s (DDPM, 1000 steps)')
# Raw HTML visitor-counter badge rendered below the tabs.
FOOTER = ('<img id="visitor-badge" '
          'src="https://visitor-badge.glitch.me/badge?page_id=hysts.diffusers-anime-faces" '
          'alt="visitor badge" />')


def create_simple_demo(model: Model) -> gr.Blocks:
    """Build the one-click "Simple Mode" tab.

    A single "Generate" button calls ``model.run_simple`` with no inputs and
    routes its two outputs to the super-resolved and raw image views, in
    that order.

    Args:
        model: Object providing ``run_simple``.

    Returns:
        The constructed ``gr.Blocks`` container.
    """
    with gr.Blocks() as demo:
        generate_button = gr.Button('Generate')
        with gr.Tabs():
            with gr.TabItem('Result (Superresolved)'):
                superresolved_image = gr.Image(show_label=False,
                                               elem_id='result-grid')
            with gr.TabItem('Result (Raw)'):
                raw_image = gr.Image(show_label=False,
                                     elem_id='result-grid-raw')
        displays = [superresolved_image, raw_image]
        generate_button.click(fn=model.run_simple,
                              inputs=None,
                              outputs=displays)
    return demo


def create_advanced_demo(model: Model) -> gr.Blocks:
    """Build the configurable "Advanced Mode" tab.

    The left column holds model, scheduler, step-count and seed controls.
    The right column holds three result tabs. Clicking "Run" calls
    ``model.run(model_name, scheduler, num_steps, randomize_seed, seed)`` and
    routes its four outputs to the super-resolved image, the raw image, the
    seed slider and the denoising-process video, in that order.

    Args:
        model: Object providing ``MODEL_NAMES`` and ``run``.

    Returns:
        The constructed ``gr.Blocks`` container.
    """
    def update_num_steps(name: str) -> dict:
        """Return a slider update for the newly selected scheduler *name*.

        The slider is hidden for DDPM, limited to 4-100 steps for PNDM, and
        to 1-200 steps otherwise. Its value is reset to 20 on every switch.
        """
        # NOTE(review): hiding the slider for DDPM presumably means
        # ``model.run`` uses a fixed step count for it (DESCRIPTION quotes
        # 1000 steps) — confirm against model.py.
        bounds = (4, 100) if name == 'PNDM' else (1, 200)
        return gr.Slider.update(visible=name != 'DDPM',
                                minimum=bounds[0],
                                maximum=bounds[1],
                                value=20)

    with gr.Blocks() as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            # Left column: generation controls.
            with gr.Column():
                with gr.Group():
                    model_choice = gr.Dropdown(model.MODEL_NAMES,
                                               value=model.MODEL_NAMES[0],
                                               label='Model',
                                               interactive=False)
                    scheduler = gr.Radio(choices=['DDPM', 'DDIM', 'PNDM'],
                                         value='DDIM',
                                         label='Scheduler')
                    # Initial range equals what update_num_steps gives DDIM,
                    # the default scheduler.
                    steps = gr.Slider(1, 200, step=1, value=20,
                                      label='Number of Steps')
                    randomize = gr.Checkbox(value=False,
                                            label='Randomize Seed')
                    seed = gr.Slider(0, 100000, step=1, value=1234,
                                     label='Seed')
                    run_button = gr.Button('Run')
            # Right column: generated outputs.
            with gr.Column():
                with gr.Tabs():
                    with gr.TabItem('Result (Superresolved)'):
                        image = gr.Image(show_label=False, elem_id='result')
                    with gr.TabItem('Result (Raw)'):
                        raw_image = gr.Image(show_label=False,
                                             elem_id='result-raw')
                    with gr.TabItem('Denoising Process'):
                        process_video = gr.Video(show_label=False,
                                                 elem_id='result-video')

        scheduler.change(fn=update_num_steps,
                         inputs=scheduler,
                         outputs=steps,
                         queue=False)
        controls = [model_choice, scheduler, steps, randomize, seed]
        # ``seed`` is both an input and an output. Presumably model.run
        # returns the seed it actually used (relevant when randomizing) —
        # confirm.
        displays = [image, raw_image, seed, process_video]
        run_button.click(fn=model.run, inputs=controls, outputs=displays)
    return demo


def create_sample_image_view_demo() -> gr.Blocks:
    """Build the "Sample Images" tab showing pre-generated sample grids.

    Each dropdown entry maps to one pre-rendered PNG in the Space's
    ``samples`` directory. Selecting an entry replaces the Markdown below the
    dropdown with that sample's settings and image.

    Returns:
        The constructed ``gr.Blocks`` container.
    """
    sample_image_dir = 'https://huggingface.co/spaces/hysts/diffusers-anime-faces/resolve/main/samples'
    # Dropdown label -> (scheduler, number of steps, file-name suffix).
    # A single table drives both the dropdown choices and the lookup, so the
    # two cannot drift apart when samples are added.
    samples = {
        'ddpm-128-exp000 (DDPM)': ('DDPM', 1000, ''),
        'ddpm-128-exp000 (DDIM, 20 steps)': ('DDIM', 20, '-ddim-20steps'),
    }

    def get_sample_image_markdown(name: str) -> str:
        """Return the Markdown describing and embedding sample *name*.

        Raises:
            ValueError: If *name* is not one of the known sample labels.
        """
        if name not in samples:
            raise ValueError(f'Unknown sample: {name!r}')
        scheduler, steps, suffix = samples[name]
        # The label's first word is the model name, e.g. 'ddpm-128-exp000'.
        model_name = name.split()[0]
        url = f'{sample_image_dir}/{model_name}{suffix}.png'
        # Emit unindented lines: Markdown parses lines indented by four or
        # more spaces as a code block, so the text must not depend on the
        # renderer stripping source indentation.
        return '\n'.join([
            '- size: 128x128',
            '- seed: 0-99',
            f'- scheduler: {scheduler}',
            f'- steps: {steps}',
            '',
            f'![sample images]({url})',
        ])

    choices = list(samples)
    with gr.Blocks() as demo:
        with gr.Row():
            model_name = gr.Dropdown(choices,
                                     value=choices[0],
                                     label='Model')
        with gr.Row():
            sample_images = gr.Markdown(get_sample_image_markdown(choices[0]))

        model_name.change(fn=get_sample_image_markdown,
                          inputs=model_name,
                          outputs=sample_images)
    return demo


def main():
    """Parse ``--device``, build the tabbed demo UI and launch it."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    options = parser.parse_args()
    model = Model(options.device)

    # (tab label, builder) pairs, rendered in this order inside gr.Tabs.
    tab_builders = [
        ('Simple Mode', lambda: create_simple_demo(model)),
        ('Advanced Mode', lambda: create_advanced_demo(model)),
        ('Sample Images', create_sample_image_view_demo),
    ]

    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(TITLE)
        with gr.Tabs():
            for label, build_tab in tab_builders:
                with gr.TabItem(label):
                    build_tab()
        gr.Markdown(FOOTER)
    demo.launch(enable_queue=True, share=False)


if __name__ == '__main__':
    main()