bokyeong1015's picture
remove unused file paths (because of the checkpoint release)
2c99f9a
raw history blame
No virus
3.97 kB
import os
import subprocess
from pathlib import Path
import gradio as gr
import torch
from demo import SdmCompressionDemo
def _read_doc(name):
    """Return the contents of ``docs/<name>`` as text.

    The encoding is pinned to UTF-8 so the Markdown renders identically on
    every platform instead of depending on the locale's default encoding.
    """
    return Path('docs', name).read_text(encoding='utf-8')


def _build_model_output_column(title, image_label, servicer, pipe):
    """Create one model's result panel inside the current Gradio layout context.

    Args:
        title: Heading text rendered (as an ``<h2>``) above the panel.
        image_label: Label of the generated-image component.
        servicer: Demo backend; ``servicer.get_sdm_params(pipe)`` supplies the
            value shown in the "# Parameters" box.
        pipe: The pipeline whose parameter count is displayed.

    Returns:
        ``[image, error, inference_time]`` -- the components in the order they
        are used as event outputs. A list (not a tuple) is returned on purpose:
        Gradio wraps any non-list ``outputs`` argument as a single component.
    """
    with gr.Column(variant='panel', scale=35):
        gr.Markdown(f'<h2 align="center">{title}</h2>')
        image = gr.Image(label=image_label)
        with gr.Row().style(equal_height=True):
            with gr.Column():
                inference_time = gr.Textbox(value="", label="Inference Time (sec)")
                gr.Textbox(value=servicer.get_sdm_params(pipe), label="# Parameters")
            error = gr.Markdown()
    return [image, error, inference_time]


if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    servicer = SdmCompressionDemo(device)
    example_list = servicer.get_example_list()

    # NOTE(review): `.style(equal_height=True)` and `queue(concurrency_count=...)`
    # are Gradio 3.x APIs that Gradio 4 removed; pin gradio<4 or migrate both.
    with gr.Blocks(theme='nota-ai/theme') as demo:
        gr.Markdown(_read_doc('header.md'))
        gr.Markdown(_read_doc('description.md'))

        with gr.Row():
            # Left panel: prompt entry, generate buttons, advanced settings, examples.
            with gr.Column(variant='panel', scale=30):
                text = gr.Textbox(label="Input Prompt", max_lines=5, placeholder="Enter your prompt")
                with gr.Row().style(equal_height=True):
                    generate_original_button = gr.Button(value="Generate with Original Model", variant="primary")
                    generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary")

                with gr.Accordion("Advanced Settings", open=False):
                    negative = gr.Textbox(label='Negative Prompt', placeholder='Enter aspects to remove (e.g., low quality)')
                    with gr.Row():
                        guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, minimum=4, maximum=11, step=0.5)
                        steps = gr.Slider(label="Denoising Steps", value=25, minimum=10, maximum=75, step=5)
                        seed = gr.Slider(0, 999999, label='Random Seed', value=1234, step=1)

                with gr.Tab("Example Prompts"):
                    gr.Examples(examples=example_list, inputs=[text])

            # Right side: one result panel per model, side by side.
            original_model_outputs = _build_model_output_column(
                "Original Stable Diffusion 1.4", "Original Model",
                servicer, servicer.pipe_original,
            )
            compressed_model_outputs = _build_model_output_column(
                "Compressed Stable Diffusion (Ours)", "Compressed Model",
                servicer, servicer.pipe_compressed,
            )

        inputs = [text, negative, guidance_scale, steps, seed]

        # Pressing Enter in the prompt box runs BOTH models; each button runs only
        # its own model. Keeping callback, button and outputs in one tuple avoids
        # wiring a button to the wrong panel.
        for infer_fn, button, outputs in (
            (servicer.infer_original_model, generate_original_button, original_model_outputs),
            (servicer.infer_compressed_model, generate_compressed_button, compressed_model_outputs),
        ):
            text.submit(infer_fn, inputs=inputs, outputs=outputs)
            button.click(infer_fn, inputs=inputs, outputs=outputs)

        gr.Markdown(_read_doc('footer.md'))

    # A single queue worker, so generation requests are processed one at a time.
    demo.queue(concurrency_count=1)
    demo.launch()