# Hugging Face Spaces status at capture time: Runtime error
import gradio as gr
import torch
from diffusers import (
    AutoPipelineForImage2Image,
    AutoPipelineForText2Image,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
)
from PIL import Image
# Initialize the pipelines.
# SDXL-Turbo is an SDXL checkpoint (two text encoders plus size/crop
# conditioning), so it must be loaded with the SDXL pipelines. The SD 1.x/2.x
# classes StableDiffusionPipeline / StableDiffusionImg2ImgPipeline cannot drive
# its UNet. The AutoPipeline classes pick the matching SDXL text-to-image /
# image-to-image pipeline class from the checkpoint's metadata.
# NOTE(review): this keeps two independent copies of the weights on the GPU.
# AutoPipelineForImage2Image.from_pipe(text2img_pipe) would share them, but
# then the LoRA must be loaded into only one of the two pipelines.
MODEL_ID = "stabilityai/sdxl-turbo"

text2img_pipe = AutoPipelineForText2Image.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    variant="fp16",  # fetch the half-precision weight files (replaces revision="fp16")
)
text2img_pipe.to("cuda")

img2img_pipe = AutoPipelineForImage2Image.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    variant="fp16",
)
img2img_pipe.to("cuda")
# Helper function to load LoRA weights (best-effort: failures are reported, then skipped)
def load_lora_weights(pipe, model_name, weight_name):
    """Load a LoRA checkpoint into ``pipe``, skipping it when that is not possible.

    Args:
        pipe: A diffusers pipeline, or any object that may expose
            ``load_lora_weights(model_name, weight_name=...)``.
        model_name: Hub repo id or local path holding the LoRA weights.
        weight_name: File name of the weights inside ``model_name``.

    Returns:
        True if the weights were loaded, False if they were skipped.
    """
    try:
        pipe.load_lora_weights(model_name, weight_name=weight_name)
    except (AttributeError, OSError, ValueError) as err:
        # AttributeError: the pipeline object has no LoRA loader.
        # OSError: typically the repo/file could not be found or downloaded.
        # ValueError: typically an incompatible checkpoint or missing PEFT backend.
        # The app can still serve images without the LoRA, so report and continue.
        print(f"LoRA weights cannot be loaded for {model_name} ({err}). Skipping...")
        return False
    return True
# Apply the same LoRA checkpoint to both pipelines, text-to-image first,
# via the load_lora_weights helper.
LORA_REPO_ID = "AbdalrhmanRi/SDXL-Turbo-With-AppleVisionPro"
LORA_WEIGHT_NAME = "pytorch_lora_weights.safetensors"
for _target_pipe in (text2img_pipe, img2img_pipe):
    load_lora_weights(_target_pipe, LORA_REPO_ID, LORA_WEIGHT_NAME)
# SDXL-Turbo sampling settings. The model is distilled to produce images in
# 1-4 steps without classifier-free guidance (guidance_scale=0.0 disables it),
# and it prefers 512x512 output. For image-to-image, diffusers runs roughly
# int(num_inference_steps * strength) denoising steps, so that product must be
# >= 1 or nothing is denoised.
IMAGE_SIZE = 512
GUIDANCE_SCALE = 0.0
TEXT2IMG_STEPS = 1
IMG2IMG_STEPS = 4
IMG2IMG_STRENGTH = 0.75  # how strongly to transform the user's image
REFINE_STEPS = 2
REFINE_STRENGTH = 0.5  # lighter pass over the first result


def generate_image(prompt, init_image):
    """Generate an image with SDXL-Turbo, then yield a refined version of it.

    This is a generator, so the UI can show the first result before the
    refinement pass has finished.

    Args:
        prompt: Text prompt; ``None`` is treated as an empty string.
        init_image: Optional PIL image. When given, image-to-image is used
            (the image is converted to RGB and resized to
            IMAGE_SIZE x IMAGE_SIZE first); otherwise text-to-image.

    Yields:
        ``(output_image, None)`` once the first pass finishes, then
        ``(output_image, refined_image)``. If there is neither a non-blank
        prompt nor an image, yields ``(None, None)`` once and stops.
    """
    prompt = prompt or ""
    if not prompt.strip() and init_image is None:
        # With a live UI this runs on every edit, including when the textbox is
        # cleared; skip the GPU work instead of sampling an unconditioned image.
        yield None, None
        return

    if init_image is None:
        # Text-to-Image generation. Pass both height and width: with only
        # height set, width falls back to the pipeline's default size.
        output_image = text2img_pipe(
            prompt=prompt,
            num_inference_steps=TEXT2IMG_STEPS,
            guidance_scale=GUIDANCE_SCALE,
            height=IMAGE_SIZE,
            width=IMAGE_SIZE,
        ).images[0]
    else:
        # Image-to-Image generation. Uploads may be RGBA/palette PNGs, while the
        # VAE expects 3 channels. The output size follows the input image, so
        # no height/width is passed here.
        init_image = init_image.convert("RGB").resize(
            (IMAGE_SIZE, IMAGE_SIZE), Image.LANCZOS
        )
        output_image = img2img_pipe(
            prompt=prompt,
            image=init_image,
            num_inference_steps=IMG2IMG_STEPS,
            strength=IMG2IMG_STRENGTH,
            guidance_scale=GUIDANCE_SCALE,
        ).images[0]
    yield output_image, None

    # Refining step: always an image-to-image pass over the first result. The
    # text-to-image pipeline takes no init image, so it cannot refine anything.
    refined_image = img2img_pipe(
        prompt=prompt,
        image=output_image,
        num_inference_steps=REFINE_STEPS,
        strength=REFINE_STRENGTH,
        guidance_scale=GUIDANCE_SCALE,
    ).images[0]
    yield output_image, refined_image
# Define the Gradio interface.
# generate_image is a generator (first image, then the refined one), so its
# results stream into the two output slots in order.
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt here..."),
        gr.Image(type="pil", label="Initial Image (Optional)", height=360),
    ],
    outputs=[
        gr.Image(type="pil", label="Generated Image"),
        gr.Image(type="pil", label="Refined Image"),
    ],
    live=True,  # re-run automatically whenever an input changes
    title="Generate Image Using Generative AI",
    theme=gr.themes.Default(primary_hue="green"),
    description="Text-to-Image or Image-to-Image Generation with SDXL-Turbo.",
)

if __name__ == "__main__":
    # Streaming generator output needs the request queue. It is on by default
    # in Gradio 4+, but must be enabled explicitly on Gradio 3.x.
    interface.queue().launch()