# Spaces: Running
import os

import gradio as gr
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

# Stage 1 (prior): loaded once at import time in bfloat16 and moved to the GPU.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Stage 2 (decoder): loaded once at import time in float16 and moved to the GPU.
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16,
).to("cuda")
def generate_images(
    prompt="a photo of a girl",
    negative_prompt="bad,ugly,deformed",
    height=1024,
    width=1024,
    guidance_scale=4.0,
    seed=42,
    num_images_per_prompt=1,
    prior_inference_steps=20,
    decoder_inference_steps=10,
):
    """Generate images for a prompt with the Stable Cascade prior and decoder.

    The module-level ``prior`` pipeline turns the prompt into image embeddings,
    which the module-level ``decoder`` pipeline then renders as PIL images.
    The random generator is created on the CUDA device.

    Parameters:
    - prompt (str): The prompt to generate images for.
    - negative_prompt (str): The negative prompt to guide generation away from.
    - height (int): The height passed to the prior pipeline.
    - width (int): The width passed to the prior pipeline.
    - guidance_scale (float): The guidance scale used by the prior pipeline.
    - seed (int): Seed for the random generator shared by both stages.
    - num_images_per_prompt (int): Number of images requested from the prior.
    - prior_inference_steps (int): The number of inference steps for the prior.
    - decoder_inference_steps (int): The number of inference steps for the decoder.

    Integer-valued parameters are coerced with ``int()`` so that float values
    (such as ``1024.0``) are accepted.

    Returns:
    - The ``images`` attribute of the decoder output (requested as PIL images).
    """
    # UI widgets may deliver whole numbers as floats; coerce every
    # integer-valued argument, as the seed already is.
    height = int(height)
    width = int(width)
    num_images_per_prompt = int(num_images_per_prompt)
    prior_inference_steps = int(prior_inference_steps)
    decoder_inference_steps = int(decoder_inference_steps)

    # One seeded generator is shared by both stages.
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    # Generate image embeddings using the prior model.
    prior_output = prior(
        prompt=prompt,
        generator=generator,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=prior_inference_steps,
    )

    # Generate images using the decoder model and the embeddings from the prior.
    # The embeddings are cast to half precision before being handed over.
    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=prompt,
        generator=generator,
        negative_prompt=negative_prompt,
        # Guidance scale typically set to 0 for decoder as guidance is
        # applied in the prior.
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=decoder_inference_steps,
    ).images
    return decoder_output
def web_demo():
    """Build the text-to-image Gradio interface and return it.

    The layout has an input column (prompt, negative prompt, seed, sliders and
    a "Generate Image" button) and an output column holding a gallery.
    Clicking the button calls ``generate_images`` with the widget values and
    shows the result in the gallery.

    Returns:
    - gr.Blocks: The constructed container, so a caller can launch or mount it.
    """
    # NOTE(review): the nesting of the slider groups below was reconstructed
    # from a flattened source — confirm against the intended layout.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                text2image_prompt = gr.Textbox(
                    lines=1,
                    placeholder="Prompt",
                    show_label=False,
                )
                text2image_negative_prompt = gr.Textbox(
                    lines=1,
                    placeholder="Negative Prompt",
                    show_label=False,
                )
                text2image_seed = gr.Number(
                    value=42,
                    label="Seed",
                )

                with gr.Row():
                    with gr.Column():
                        text2image_num_images_per_prompt = gr.Slider(
                            minimum=1,
                            maximum=2,
                            step=1,
                            value=1,
                            label="Number Image",
                        )
                        text2image_height = gr.Slider(
                            minimum=128,
                            maximum=1024,
                            step=32,
                            value=1024,
                            label="Image Height",
                        )
                        text2image_width = gr.Slider(
                            minimum=128,
                            maximum=1024,
                            step=32,
                            value=1024,
                            label="Image Width",
                        )

                    with gr.Row():
                        with gr.Column():
                            text2image_guidance_scale = gr.Slider(
                                minimum=0.1,
                                maximum=15,
                                step=0.1,
                                value=4.0,
                                label="Guidance Scale",
                            )
                            text2image_prior_inference_step = gr.Slider(
                                minimum=1,
                                maximum=50,
                                step=1,
                                value=20,
                                label="Prior Inference Step",
                            )
                            text2image_decoder_inference_step = gr.Slider(
                                minimum=1,
                                maximum=50,
                                step=1,
                                value=10,
                                label="Decoder Inference Step",
                            )

                text2image_predict = gr.Button(value="Generate Image")

            with gr.Column():
                # NOTE(review): `.style()` is deprecated in later Gradio 3.x
                # releases and removed in 4.x, where the grid/height settings
                # are constructor arguments — confirm the pinned Gradio version.
                output_image = gr.Gallery(
                    label="Generated images",
                    show_label=False,
                    elem_id="gallery",
                ).style(grid=(1, 2), height=300)

        # The input order must match the positional parameter order of
        # generate_images.
        text2image_predict.click(
            fn=generate_images,
            inputs=[
                text2image_prompt,
                text2image_negative_prompt,
                text2image_height,
                text2image_width,
                text2image_guidance_scale,
                text2image_seed,
                text2image_num_images_per_prompt,
                text2image_prior_inference_step,
                text2image_decoder_inference_step,
            ],
            outputs=output_image,
        )

    return demo