# Provenance (copied from the hosting page): author bokyeong1015,
# commit 3a45ac7 ("add nparams count"), file size 4.59 kB, scan result: No virus.
import os
import subprocess
from pathlib import Path
import gradio as gr
import torch
from demo import SdmCompressionDemo
def _require_env(name: str) -> str:
    """Return the value of environment variable *name*, failing loudly if unset.

    An explicit ``raise`` is used instead of ``assert`` because assertions are
    stripped when Python runs with ``-O``.

    Raises:
        RuntimeError: if the variable is unset or empty.
    """
    value = os.getenv(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name} must be set to the download URL "
            "of the BK-SDM checkpoint file."
        )
    return value


def _download(url: str, dest: Path) -> None:
    """Download *url* to *dest* with ``wget``, raising if the download fails.

    Raises:
        subprocess.CalledProcessError: if ``wget`` exits with a non-zero status.
            Failing here is deliberate: ``wget -O`` truncates *dest* before
            fetching, so continuing would leave an empty/partial file behind.
    """
    # `wget -O` cannot create missing parent directories itself.
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Arguments are passed as a list (no shell), so characters inside the URL
    # are never interpreted by a shell.
    # NOTE(review): --no-check-certificate disables TLS certificate
    # verification; kept for compatibility, but consider removing it.
    subprocess.run(
        ["wget", "--no-check-certificate", "-O", str(dest), url],
        check=True,
    )


# Local destinations for the compressed UNet's config and weights.
dest_path_config = Path('checkpoints/BK-SDM-Small_iter50000/unet/config.json')
dest_path_torch_ckpt = Path('checkpoints/BK-SDM-Small_iter50000/unet/diffusion_pytorch_model.bin')

# Both URLs are validated before any download starts.
BK_SDM_CONFIG_URL: str = _require_env('BK_SDM_CONFIG_URL')
BK_SDM_TORCH_CKPT_URL: str = _require_env('BK_SDM_TORCH_CKPT_URL')

_download(BK_SDM_CONFIG_URL, dest_path_config)
_download(BK_SDM_TORCH_CKPT_URL, dest_path_torch_ckpt)
if __name__ == "__main__":
device = 'cuda' if torch.cuda.is_available() else 'cpu'
servicer = SdmCompressionDemo(device)
example_list = servicer.get_example_list()
with gr.Blocks(theme='nota-ai/theme') as demo:
gr.Markdown(Path('docs/header.md').read_text())
gr.Markdown(Path('docs/description.md').read_text())
with gr.Row():
with gr.Column(variant='panel', scale=30):
text = gr.Textbox(label="Input Prompt", max_lines=5, placeholder="Enter your prompt")
with gr.Row().style(equal_height=True):
generate_original_button = gr.Button(value="Generate with Original Model", variant="primary")
generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary")
with gr.Accordion("Advanced Settings", open=False):
negative = gr.Textbox(label=f'Negative Prompt', placeholder=f'Enter aspects to remove (e.g., {"low quality"})')
with gr.Row():
guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, minimum=4, maximum=11, step=0.5)
steps = gr.Slider(label="Denoising Steps", value=25, minimum=10, maximum=75, step=5)
seed = gr.Slider(0, 999999, label='Random Seed', value=1234, step=1)
with gr.Tab("Example Prompts"):
examples = gr.Examples(examples=example_list, inputs=[text])
with gr.Column(variant='panel',scale=35):
# Define original model output components
gr.Markdown('<h2 align="center">Original Stable Diffusion 1.4</h2>')
original_model_output = gr.Image(label="Original Model")
with gr.Row().style(equal_height=True):
with gr.Column():
original_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
original_model_params = gr.Textbox(value=servicer.get_sdm_params(servicer.pipe_original), label="# Parameters")
original_model_error = gr.Markdown()
with gr.Column(variant='panel',scale=35):
# Define compressed model output components
gr.Markdown('<h2 align="center">Compressed Stable Diffusion (Ours)</h2>')
compressed_model_output = gr.Image(label="Compressed Model")
with gr.Row().style(equal_height=True):
with gr.Column():
compressed_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
compressed_model_params = gr.Textbox(value=servicer.get_sdm_params(servicer.pipe_compressed), label="# Parameters")
compressed_model_error = gr.Markdown()
inputs = [text, negative, guidance_scale, steps, seed]
# Click the generate button for original model
original_model_outputs = [original_model_output, original_model_error, original_model_test_time]
text.submit(servicer.infer_original_model, inputs=inputs, outputs=original_model_outputs)
generate_original_button.click(servicer.infer_original_model, inputs=inputs, outputs=original_model_outputs)
# Click the generate button for compressed model
compressed_model_outputs = [compressed_model_output, compressed_model_error, compressed_model_test_time]
text.submit(servicer.infer_compressed_model, inputs=inputs, outputs=compressed_model_outputs)
generate_compressed_button.click(servicer.infer_compressed_model, inputs=inputs, outputs=compressed_model_outputs)
gr.Markdown(Path('docs/footer.md').read_text())
demo.queue(concurrency_count=1)
# demo.launch()
demo.launch()