File size: 1,738 Bytes
7f31f24
 
 
 
 
130343e
 
3e53005
130343e
 
8edd85f
9062426
 
f0911ab
98b968b
ac8eaac
8edd85f
7f31f24
5312c2f
ac8eaac
 
b7d30a0
 
5312c2f
473ded3
5312c2f
8edd85f
473ded3
27633e4
b7d30a0
 
5312c2f
27633e4
 
5312c2f
7f31f24
 
 
ac8eaac
27633e4
7f31f24
 
 
 
 
9a2c62c
96a1c07
7f31f24
397605b
7f31f24
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import spaces
import gradio as gr
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

# Stable Cascade runs in two stages: the prior pipeline turns the text prompt
# into image embeddings, and the decoder pipeline turns those embeddings into
# an image. Both checkpoints are loaded in their bf16 variants.
prior_pipeline = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    variant="bf16",
    torch_dtype=torch.bfloat16,
)
decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    variant="bf16",
    torch_dtype=torch.bfloat16,
)

# Model CPU offload keeps sub-models on the CPU and moves each one to the GPU
# only while it is executing, rather than placing the whole pipeline on CUDA
# up front with an explicit `.to("cuda")`.
prior_pipeline.enable_model_cpu_offload()
decoder_pipeline.enable_model_cpu_offload()

# Short module-level aliases used by `generate`.
prior = prior_pipeline
decoder = decoder_pipeline

@spaces.GPU
def generate(prompt, negative_prompt, steps):
    """Generate one 1024x1024 image from a text prompt with Stable Cascade.

    Args:
        prompt: Text describing the desired image; passed to both stages.
        negative_prompt: Text describing what to avoid; passed to both stages.
        steps: Number of inference steps for the prior stage. This comes from
            a Gradio slider and may arrive as a float (e.g. ``20.0``), so it
            is coerced to ``int`` before use.

    Returns:
        The first image produced by the decoder (requested as ``output_type="pil"``).
    """
    # The scheduler builds its timestep schedule from this count, so it must
    # be an integer even when the UI delivers a float value.
    num_steps = int(steps)

    prior_output = prior(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=1024,
        height=1024,
        guidance_scale=4.0,
        num_images_per_prompt=1,
        num_inference_steps=num_steps,
    )

    # Cast the prior's embeddings to bfloat16 before handing them to the decoder.
    # NOTE(review): with guidance_scale=0.0 the decoder presumably skips
    # classifier-free guidance, so negative_prompt likely has no effect in
    # this stage — confirm against the diffusers version in use.
    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings.to(torch.bfloat16),
        prompt=prompt,
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=10,
        negative_prompt=negative_prompt,
    ).images[0]
    return decoder_output

with gr.Blocks() as demo:
    # Top row: prompt entry next to the trigger button.
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", value="A perfectly red apple, 32K HDR, studio lighting")
        generate_btn = gr.Button("Generate")

    with gr.Row():
        output = gr.Image(label="Output")

    # Less commonly changed settings, collapsed by default.
    with gr.Accordion("Advanced", open=False):
        negative_prompt = gr.Textbox(label="Negative Prompt", value="ugly, low quality")
        steps = gr.Slider(minimum=4, maximum=50, step=1, value=20, label="Steps")

    # The inputs list order must match generate(prompt, negative_prompt, steps).
    generate_btn.click(generate, inputs=[prompt, negative_prompt, steps], outputs=output)

# Start the server only when executed as a script; importing the module
# (e.g. from Gradio's reload mode or a test) then builds `demo` without
# launching it.
if __name__ == "__main__":
    demo.launch()