File size: 4,589 Bytes
e8ce90b
 
5c762ce
e8ce90b
 
9480716
e8ce90b
5c762ce
 
e8ce90b
 
690f7ad
 
d39bc83
 
db4b904
d39bc83
 
 
 
 
 
 
 
e8ce90b
5c762ce
33a2e0a
 
5c762ce
 
 
 
 
 
c3014da
5c762ce
 
c3014da
5c762ce
 
 
 
 
c3014da
 
5c762ce
 
 
 
 
 
c3014da
3a45ac7
5c762ce
 
 
 
3a45ac7
 
 
5c762ce
3a45ac7
5c762ce
3a45ac7
5c762ce
 
3a45ac7
5c762ce
3a45ac7
 
 
5c762ce
 
c3014da
 
 
 
 
5c762ce
c3014da
5c762ce
 
c3014da
5c762ce
 
 
 
 
 
db4b904
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import subprocess
from pathlib import Path

import gradio as gr
import torch

from demo import SdmCompressionDemo

# --- Fetch the compressed-model checkpoint (config + weights) at import time. ---
# The checkpoint location is supplied via environment variables so the URLs are
# not hard-coded into the repository.
dest_path_config = Path('checkpoints/BK-SDM-Small_iter50000/unet/config.json')
dest_path_torch_ckpt = Path('checkpoints/BK-SDM-Small_iter50000/unet/diffusion_pytorch_model.bin')

BK_SDM_CONFIG_URL = os.getenv('BK_SDM_CONFIG_URL')
BK_SDM_TORCH_CKPT_URL = os.getenv('BK_SDM_TORCH_CKPT_URL')
# Fail fast with an explicit error instead of `assert`, which is stripped under `python -O`.
if BK_SDM_CONFIG_URL is None or BK_SDM_TORCH_CKPT_URL is None:
    raise RuntimeError(
        "Environment variables BK_SDM_CONFIG_URL and BK_SDM_TORCH_CKPT_URL must both be set."
    )

for url, dest in (
    (BK_SDM_CONFIG_URL, dest_path_config),
    (BK_SDM_TORCH_CKPT_URL, dest_path_torch_ckpt),
):
    # `wget -O` does not create missing directories; make sure the target exists.
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Pass an argument list with the default shell=False: the URL comes from the
    # environment, so interpolating it into a shell string (as the original did)
    # would allow shell injection and break on URLs containing metacharacters.
    subprocess.run(
        ["wget", "--no-check-certificate", "-O", str(dest), url],
        check=True,  # raise immediately on a failed download instead of continuing silently
    )

if __name__ == "__main__":
    # Run the demo on GPU when available; the servicer loads both pipelines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    servicer = SdmCompressionDemo(device)
    example_list = servicer.get_example_list()

    # Layout: prompt/controls column on the left (scale 30); original-model and
    # compressed-model output columns side by side on the right (scale 35 each).
    with gr.Blocks(theme='nota-ai/theme') as demo:
        gr.Markdown(Path('docs/header.md').read_text())
        gr.Markdown(Path('docs/description.md').read_text())
        with gr.Row():
            with gr.Column(variant='panel', scale=30):

                text = gr.Textbox(label="Input Prompt", max_lines=5, placeholder="Enter your prompt")

                with gr.Row().style(equal_height=True):
                    generate_original_button = gr.Button(value="Generate with Original Model", variant="primary")
                    generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary")

                with gr.Accordion("Advanced Settings", open=False):
                    # Plain literal: the original f-string `{"low quality"}` was
                    # evaluated by the f-string machinery and silently dropped the
                    # quotation marks from the rendered placeholder text.
                    negative = gr.Textbox(label='Negative Prompt', placeholder='Enter aspects to remove (e.g., "low quality")')
                    with gr.Row():
                        guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, minimum=4, maximum=11, step=0.5)
                        steps = gr.Slider(label="Denoising Steps", value=25, minimum=10, maximum=75, step=5)
                        seed = gr.Slider(0, 999999, label='Random Seed', value=1234, step=1)

                with gr.Tab("Example Prompts"):
                    examples = gr.Examples(examples=example_list, inputs=[text])

            with gr.Column(variant='panel', scale=35):
                # Original-model output components.
                gr.Markdown('<h2 align="center">Original Stable Diffusion 1.4</h2>')
                original_model_output = gr.Image(label="Original Model")
                with gr.Row().style(equal_height=True):
                    with gr.Column():
                        original_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
                        original_model_params = gr.Textbox(value=servicer.get_sdm_params(servicer.pipe_original), label="# Parameters")
                    original_model_error = gr.Markdown()

            with gr.Column(variant='panel', scale=35):
                # Compressed-model output components.
                gr.Markdown('<h2 align="center">Compressed Stable Diffusion (Ours)</h2>')
                compressed_model_output = gr.Image(label="Compressed Model")
                with gr.Row().style(equal_height=True):
                    with gr.Column():
                        compressed_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
                        compressed_model_params = gr.Textbox(value=servicer.get_sdm_params(servicer.pipe_compressed), label="# Parameters")
                    compressed_model_error = gr.Markdown()

        # Both models share the same five inputs.
        inputs = [text, negative, guidance_scale, steps, seed]

        # Wire the original model to its button and to Enter in the prompt box.
        original_model_outputs = [original_model_output, original_model_error, original_model_test_time]
        text.submit(servicer.infer_original_model, inputs=inputs, outputs=original_model_outputs)
        generate_original_button.click(servicer.infer_original_model, inputs=inputs, outputs=original_model_outputs)

        # Wire the compressed model likewise; note Enter in the prompt box
        # triggers BOTH models (two text.submit bindings).
        compressed_model_outputs = [compressed_model_output, compressed_model_error, compressed_model_test_time]
        text.submit(servicer.infer_compressed_model, inputs=inputs, outputs=compressed_model_outputs)
        generate_compressed_button.click(servicer.infer_compressed_model, inputs=inputs, outputs=compressed_model_outputs)

        gr.Markdown(Path('docs/footer.md').read_text())

    # Serialize requests: one generation at a time (single model pair in memory).
    demo.queue(concurrency_count=1)
    demo.launch()