"""Gradio demo for Stable Cascade text-to-image generation.

Runs the two-stage Stable Cascade pipeline from ``diffusers`` (a prior that
produces image embeddings, then a decoder that renders them) and exposes it
through a small Gradio Blocks UI.
"""

import torch
import gradio as gr
from PIL import Image  # NOTE(review): not referenced in this file; kept in case something relies on it.
import spaces
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

device = "cuda"
# Number of images generated per click; forwarded to the prior pipeline.
num_images_per_prompt = 1

# The prior is loaded in bfloat16 and the decoder in float16; `gen` casts the
# prior's embeddings to half precision before handing them to the decoder.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to(device)
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.float16
).to(device)

css = """
footer {
    visibility: hidden
}

#generate_button {
    color: white;
    border-color: #007bff;
    background: #2563eb;
}

#save_button {
    color: white;
    border-color: #028b40;
    background: #01b97c;
    width: 200px;
}

#settings_header {
    background: rgb(245, 105, 105);
}
"""


@spaces.GPU
def gen(prompt, negative, width, height):
    """Generate images for *prompt* with the two-stage Stable Cascade pipeline.

    Args:
        prompt: Text prompt describing the desired image.
        negative: Negative prompt; passed to both the prior and the decoder.
        width: Output width in pixels (value of the Width slider).
        height: Output height in pixels (value of the Height slider).

    Returns:
        The decoder output's ``.images`` list (PIL images, because
        ``output_type="pil"``), used directly as the Gallery value.
    """
    # Slider values may arrive as floats; the pipelines expect integer sizes.
    width, height = int(width), int(height)

    prior_output = prior(
        prompt=prompt,
        height=height,
        width=width,
        negative_prompt=negative,
        guidance_scale=4.0,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=20,
    )
    # Cast embeddings to float16 to match the decoder's dtype (prior is bfloat16).
    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=prompt,
        negative_prompt=negative,
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=10,
    ).images
    return decoder_output


with gr.Blocks(css=css) as demo:
    gr.Markdown("# Stable Cascade ```DEMO```")
    with gr.Row():
        prompt = gr.Textbox(
            show_label=False,
            placeholder="Enter your prompt",
            max_lines=3,
            lines=1,
            interactive=True,
            scale=20,
        )
        # elem_id ties this button to the #generate_button rule in `css`;
        # without it that rule never matched any element.
        button = gr.Button(value="Generate", scale=1, elem_id="generate_button")
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            negative = gr.Textbox(
                show_label=False,
                placeholder="Enter a negative",
                max_lines=2,
                lines=1,
                interactive=True,
            )
        with gr.Row():
            width = gr.Slider(label="Width", minimum=1024, maximum=2048, step=8, value=1024, interactive=True)
            height = gr.Slider(label="Height", minimum=1024, maximum=2048, step=8, value=1024, interactive=True)
    with gr.Row():
        gallery = gr.Gallery(show_label=False, rows=1, columns=1, allow_preview=True, preview=True)
    # Event listeners must be registered inside the Blocks context.
    button.click(gen, inputs=[prompt, negative, width, height], outputs=gallery)

if __name__ == "__main__":
    demo.launch(show_api=False)