# Page header captured with this file (not part of the program):
# bokyeong1015's picture
# add nparams count
# 3a45ac7
# raw
# history blame
# 4.59 kB
import os
import subprocess
from pathlib import Path
import gradio as gr
import torch
from demo import SdmCompressionDemo
# Local destination of the compressed U-Net (BK-SDM-Small, iter 50000) files.
dest_path_config = Path('checkpoints/BK-SDM-Small_iter50000/unet/config.json')
dest_path_torch_ckpt = Path('checkpoints/BK-SDM-Small_iter50000/unet/diffusion_pytorch_model.bin')

# Download URLs are supplied through the environment; both are required.
BK_SDM_CONFIG_URL: str = os.getenv('BK_SDM_CONFIG_URL', None)
BK_SDM_TORCH_CKPT_URL: str = os.getenv('BK_SDM_TORCH_CKPT_URL', None)

# Explicit check instead of `assert`, which is stripped under `python -O`.
# An empty value is rejected as well, since wget could not use it.
_missing_urls = [
    name
    for name, value in (
        ('BK_SDM_CONFIG_URL', BK_SDM_CONFIG_URL),
        ('BK_SDM_TORCH_CKPT_URL', BK_SDM_TORCH_CKPT_URL),
    )
    if not value
]
if _missing_urls:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_urls)
    )


def _download(url, dest):
    """Download *url* to the local path *dest* using wget.

    The parent directory of ``dest`` is created first, because ``wget -O``
    does not create missing directories.

    Raises:
        subprocess.CalledProcessError: if wget exits with a non-zero status.
            Failing here is deliberate: ``wget -O`` may leave an empty or
            partial file behind, which would otherwise only surface later as
            a confusing model-loading error.
        FileNotFoundError: if the ``wget`` executable is not installed.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Argument list with the default shell=False: the URL reaches wget
    # verbatim, so characters such as '&' or ';' (common in signed URLs)
    # are never interpreted by a shell.
    # NOTE(review): --no-check-certificate disables TLS verification; the .bin
    # checkpoint is presumably deserialized by torch (pickle) later, so
    # consider dropping this flag if the download host has a valid cert.
    subprocess.run(
        ['wget', '--no-check-certificate', '-O', str(dest), url],
        check=True,
    )


_download(BK_SDM_CONFIG_URL, dest_path_config)
_download(BK_SDM_TORCH_CKPT_URL, dest_path_torch_ckpt)
def _build_output_column(servicer, title, image_label, pipe):
    """Build one model's result panel inside the currently open Blocks layout.

    Args:
        servicer: the demo service object; only its ``get_sdm_params``
            method is used here, to fill the parameter-count box.
        title: heading text rendered (as an HTML ``<h2>``) above the panel.
        image_label: label of the output ``gr.Image``.
        pipe: pipeline object passed to ``servicer.get_sdm_params``.

    Returns:
        ``[image, error_markdown, inference_time_textbox]``, in the order
        they are wired as event outputs for the inference handlers.
    """
    with gr.Column(variant='panel', scale=35):
        gr.Markdown(f'<h2 align="center">{title}</h2>')
        image = gr.Image(label=image_label)
        with gr.Row().style(equal_height=True):
            with gr.Column():
                test_time = gr.Textbox(value="", label="Inference Time (sec)")
                gr.Textbox(value=servicer.get_sdm_params(pipe), label="# Parameters")
            error = gr.Markdown()
    return [image, error, test_time]


if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    servicer = SdmCompressionDemo(device)
    example_list = servicer.get_example_list()

    with gr.Blocks(theme='nota-ai/theme') as demo:
        # Explicit encoding so non-ASCII markdown does not depend on the locale.
        gr.Markdown(Path('docs/header.md').read_text(encoding='utf-8'))
        gr.Markdown(Path('docs/description.md').read_text(encoding='utf-8'))

        with gr.Row():
            # Left panel: prompt, generate buttons, advanced settings, examples.
            with gr.Column(variant='panel', scale=30):
                text = gr.Textbox(label="Input Prompt", max_lines=5, placeholder="Enter your prompt")
                with gr.Row().style(equal_height=True):
                    generate_original_button = gr.Button(value="Generate with Original Model", variant="primary")
                    generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary")
                with gr.Accordion("Advanced Settings", open=False):
                    negative = gr.Textbox(label='Negative Prompt', placeholder='Enter aspects to remove (e.g., low quality)')
                    with gr.Row():
                        guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, minimum=4, maximum=11, step=0.5)
                        steps = gr.Slider(label="Denoising Steps", value=25, minimum=10, maximum=75, step=5)
                    seed = gr.Slider(0, 999999, label='Random Seed', value=1234, step=1)
                with gr.Tab("Example Prompts"):
                    gr.Examples(examples=example_list, inputs=[text])

            # Right panels: one per model, laid out identically.
            original_model_outputs = _build_output_column(
                servicer, "Original Stable Diffusion 1.4", "Original Model", servicer.pipe_original)
            compressed_model_outputs = _build_output_column(
                servicer, "Compressed Stable Diffusion (Ours)", "Compressed Model", servicer.pipe_compressed)

        # Both models take the same inputs. Pressing Enter in the prompt box
        # triggers both models; each button triggers only its own model.
        inputs = [text, negative, guidance_scale, steps, seed]
        for button, infer_fn, outputs in (
            (generate_original_button, servicer.infer_original_model, original_model_outputs),
            (generate_compressed_button, servicer.infer_compressed_model, compressed_model_outputs),
        ):
            text.submit(infer_fn, inputs=inputs, outputs=outputs)
            button.click(infer_fn, inputs=inputs, outputs=outputs)

        gr.Markdown(Path('docs/footer.md').read_text(encoding='utf-8'))

    # A single queue worker, so generation requests are processed one at a time.
    demo.queue(concurrency_count=1)
    demo.launch()