# Hugging Face Spaces status at capture time: Runtime error
import io
import os
import uuid

# `spaces` (Hugging Face ZeroGPU helper) is kept ahead of torch/diffusers:
# ZeroGPU expects it to be imported before CUDA is initialized.
import spaces
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionXLPipeline
from PIL import Image
from pydantic import BaseModel
# Every checkpoint below is loaded in half precision from the "fp16" weight
# variant stored as safetensors.
_FP16_OPTIONS = dict(torch_dtype=torch.float16, use_safetensors=True, variant="fp16")

# SDXL base model: first stage of generate_image_with_refinement. Its second
# text encoder and VAE are handed to the refiner below so they are shared
# rather than loaded twice.
base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    **_FP16_OPTIONS,
)
base.to("cuda:0")

# Segmind SSD-1B: the single-stage model used by generate_and_save_image.
pipe = StableDiffusionXLPipeline.from_pretrained("segmind/SSD-1B", **_FP16_OPTIONS)
pipe.to("cuda:0")

# SDXL refiner: second stage of generate_image_with_refinement.
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    **_FP16_OPTIONS,
)
refiner.to("cuda:0")
# NOTE(review): torch.compile is lazy, so compilation cost lands on the first
# refiner call. Confirm that compilation (and CUDA-graph "reduce-overhead"
# mode) is supported on the hardware this Space runs on, e.g. ZeroGPU.
refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
# Apply the GPU decorator | |
def generate_and_save_image(prompt, negative_prompt=''): | |
# Generate image using the provided prompts | |
image = pipe(prompt=prompt, negative_prompt=negative_prompt).images[0] | |
# Generate a unique UUID for the filename | |
unique_id = str(uuid.uuid4()) | |
image_path = f"generated_images/{unique_id}.jpeg" | |
# Save generated image locally | |
os.makedirs('generated_images', exist_ok=True) | |
image.save(image_path, format='JPEG') | |
# Return the path of the saved image to display in Gradio interface | |
return image_path | |
def generate_image_with_refinement(prompt): | |
n_steps = 40 | |
high_noise_frac = 0.8 | |
# run both experts | |
image = base( | |
prompt=prompt, | |
num_inference_steps=n_steps, | |
denoising_end=high_noise_frac, | |
output_type="latent", | |
).images | |
image = refiner( | |
prompt=prompt, | |
num_inference_steps=n_steps, | |
denoising_start=high_noise_frac, | |
image=image, | |
).images[0] | |
# Save the image as before | |
unique_id = str(uuid.uuid4()) | |
image_path = f"generated_images_refined/{unique_id}.jpeg" | |
os.makedirs('generated_images_refined', exist_ok=True) | |
image.save(image_path, format='JPEG') | |
return image_path | |
# Start of the Gradio Blocks interface | |
with gr.Blocks() as demo: | |
with gr.Column(): | |
gr.Markdown("# Image Generation with SSD-1B") | |
gr.Markdown("Enter a prompt and (optionally) a negative prompt to generate an image.") | |
# Input fields for positive and negative prompts | |
with gr.Row(): | |
prompt1 = gr.Textbox(label="Enter prompt") | |
negative_prompt = gr.Textbox(label="Enter negative prompt (optional)") | |
# Button for generating the image | |
generate_button1 = gr.Button("Generate Image") | |
# Output image display, set to a larger default size | |
output_image1 = gr.Image(label="Generated Image") | |
# Click event for the generate button | |
generate_button1.click( | |
generate_and_save_image, | |
inputs=[prompt1, negative_prompt], | |
outputs=output_image1 | |
) | |
with gr.Column(): | |
gr.Markdown("## Refined Image Generation") | |
gr.Markdown("Enter a prompt to generate a refined image.") | |
# Input field for the prompt | |
prompt2 = gr.Textbox(label="Enter prompt for refined generation") | |
# Button for generating the refined image | |
generate_button2 = gr.Button("Generate Refined Image") | |
# Output refined image display, set to a larger default size | |
output_image2 = gr.Image(label="Generated Refined Image") | |
# Click event for the generate button | |
generate_button2.click( | |
generate_image_with_refinement, | |
inputs=[prompt2], | |
outputs=output_image2 | |
) | |
# Set the image display to be the largest element for both SSD-1B and refined generation | |
demo.update( | |
output_image1.style(width='100%', height='auto', min_height='400px'), | |
output_image2.style(width='100%', height='auto', min_height='400px') | |
) | |
# Launch the combined Gradio app | |
demo.launch() | |