Stable-Cascade / app.py
ehristoforu's picture
Update app.py
d6d99c2 verified
raw
history blame
No virus
2.29 kB
import torch
import gradio as gr
from PIL import Image
import spaces
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
# Both pipelines are moved onto this device immediately after loading.
device = "cuda"
# How many images gen() requests from the prior pipeline per prompt.
num_images_per_prompt = 1

# Prior pipeline, loaded in bfloat16; gen() reads its image_embeddings output.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
)
prior = prior.to(device)

# Decoder pipeline, loaded in float16; gen() feeds it the prior's embeddings.
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16,
)
decoder = decoder.to(device)
# Custom stylesheet handed to gr.Blocks(css=css) when the UI is built.
# - footer: hides the Gradio page footer.
# - #generate_button, #save_button, #settings_header: style the components
#   whose elem_id matches the selector. #save_button and #settings_header have
#   no matching elem_id in the UI built in this file, so those two rules
#   currently have no visible effect.
css = """
footer {
visibility: hidden
}
#generate_button {
color: white;
border-color: #007bff;
background: #2563eb;
}
#save_button {
color: white;
border-color: #028b40;
background: #01b97c;
width: 200px;
}
#settings_header {
background: rgb(245, 105, 105);
}
"""
@spaces.GPU
def gen(
    prompt,
    negative,
    width,
    height,
    prior_steps=20,
    decoder_steps=10,
    prior_guidance_scale=4.0,
    decoder_guidance_scale=0.0,
):
    """Generate images for ``prompt`` using the prior + decoder pipelines.

    The prior pipeline turns the text prompt into image embeddings; the
    decoder pipeline turns those embeddings into PIL images.

    Args:
        prompt: Text prompt from the UI textbox.
        negative: Negative prompt (may be an empty string).
        width: Requested output width in pixels (from the UI slider).
        height: Requested output height in pixels (from the UI slider).
        prior_steps: Inference steps for the prior pipeline (default 20,
            the previously hard-coded value).
        decoder_steps: Inference steps for the decoder pipeline (default 10,
            the previously hard-coded value).
        prior_guidance_scale: ``guidance_scale`` passed to the prior
            (default 4.0, the previously hard-coded value).
        decoder_guidance_scale: ``guidance_scale`` passed to the decoder
            (default 0.0, the previously hard-coded value).

    Returns:
        The ``.images`` of the decoder output, requested as PIL images.
    """
    # Slider values may arrive as floats; normalize to integer pixel sizes.
    width = int(width)
    height = int(height)

    prior_output = prior(
        prompt=prompt,
        height=height,
        width=width,
        negative_prompt=negative,
        guidance_scale=prior_guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=prior_steps,
    )

    # The two pipelines are loaded in different dtypes at module setup.
    # Cast the embeddings to the decoder's actual dtype instead of a
    # hard-coded .half(), so this stays correct if that choice changes.
    image_embeddings = prior_output.image_embeddings.to(decoder.dtype)

    return decoder(
        image_embeddings=image_embeddings,
        prompt=prompt,
        negative_prompt=negative,
        guidance_scale=decoder_guidance_scale,
        output_type="pil",
        num_inference_steps=decoder_steps,
    ).images
# Build the Gradio UI: prompt + Generate button, collapsible advanced options
# (negative prompt, width, height), and a single-image gallery for results.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Stable Cascade ```DEMO```")
    with gr.Row():
        prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt", max_lines=3, lines=1, interactive=True, scale=20)
        # elem_id connects this button to the #generate_button CSS rule;
        # without it that rule matched nothing and the styling never applied.
        button = gr.Button(value="Generate", scale=1, elem_id="generate_button")
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            negative = gr.Textbox(show_label=False, placeholder="Enter a negative", max_lines=2, lines=1, interactive=True)
        with gr.Row():
            width = gr.Slider(label="Width", minimum=1024, maximum=2048, step=8, value=1024, interactive=True)
            height = gr.Slider(label="Height", minimum=1024, maximum=2048, step=8, value=1024, interactive=True)
    with gr.Row():
        gallery = gr.Gallery(show_label=False, rows=1, columns=1, allow_preview=True, preview=True)
    # Event listeners must be registered inside the Blocks context.
    button.click(gen, inputs=[prompt, negative, width, height], outputs=gallery)
demo.launch(show_api=False)