# AbdalrhmanRi's picture
# Update app.py
# 2343433 verified
# raw
# history blame
# 3.07 kB
import gradio as gr
import torch
from diffusers import (
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLPipeline,
)
from PIL import Image
# Initialize the pipelines.
# "stabilityai/sdxl-turbo" is an SDXL checkpoint (two text encoders and an
# SDXL UNet with extra conditioning inputs), so it must be loaded with the
# SDXL pipeline classes; the SD 1.x/2.x pipelines cannot drive its UNet.
MODEL_ID = "stabilityai/sdxl-turbo"
LORA_REPO = "AbdalrhmanRi/SDXL-Turbo-With-AppleVisionPro"
LORA_WEIGHT_NAME = "pytorch_lora_weights.safetensors"

text2img_pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    # Select the fp16 weight files (the supported replacement for the old
    # `revision="fp16"` convention).
    variant="fp16",
)
text2img_pipe.to("cuda")

# Build the img2img pipeline from the *same* modules instead of loading the
# checkpoint a second time, so the weights occupy GPU memory only once.
# NOTE(review): the scheduler object is shared too; this is fine while calls
# run one at a time (Gradio's default concurrency) — confirm if that changes.
img2img_pipe = StableDiffusionXLImg2ImgPipeline(**text2img_pipe.components)
img2img_pipe.to("cuda")


def load_lora_weights(pipe, model_name, weight_name):
    """Load LoRA weights into ``pipe``, skipping pipelines without LoRA support.

    Args:
        pipe: A diffusers pipeline; expected to expose ``load_lora_weights``.
        model_name: Hub repo id (or local path) holding the LoRA weights.
        weight_name: File name of the weights inside ``model_name``.

    A pipeline lacking ``load_lora_weights`` raises AttributeError, which is
    reported and skipped; any other loading error propagates.
    """
    try:
        pipe.load_lora_weights(model_name, weight_name=weight_name)
    except AttributeError:
        print(f"LoRA weights cannot be loaded for {model_name}. Skipping...")


# Load the LoRA once: both pipelines share the same UNet and text encoders, so
# loading it into each pipeline would inject the adapter into them twice.
load_lora_weights(text2img_pipe, LORA_REPO, LORA_WEIGHT_NAME)
def generate_image(prompt, init_image):
    """Generate an image from ``prompt`` and then refine it, streaming both.

    When ``init_image`` is None, text-to-image is used. Otherwise
    ``init_image`` is resized to 512x512 and transformed with
    image-to-image. In both modes the first result is then passed through a
    second image-to-image pass (``guidance_scale=1.0``) to produce the
    "refined" image.

    Args:
        prompt: Text prompt from the Gradio textbox.
        init_image: Optional PIL image from the Gradio image input, or None.

    Yields:
        ``(generated, None)`` as soon as the first pass finishes, then
        ``(generated, refined)`` after the refinement pass, so the UI can show
        the first result early.
    """
    if init_image is None:
        # Text-to-Image generation. Width is passed explicitly so the output
        # is 512x512 regardless of the pipeline's default width.
        output_image = text2img_pipe(
            prompt=prompt,
            num_inference_steps=40,
            guidance_scale=7.5,
            height=512,
            width=512,
        ).images[0]
    else:
        # Image-to-Image generation. Normalize to RGB so the VAE always
        # receives 3 channels (e.g. if an RGBA or grayscale image arrives).
        init_image = init_image.convert("RGB").resize((512, 512), Image.LANCZOS)
        # Img2img pipelines take their output size from `image`; they do not
        # accept a `height` argument.
        output_image = img2img_pipe(
            prompt=prompt,
            image=init_image,
            num_inference_steps=40,
            strength=0.75,
            guidance_scale=7.5,
        ).images[0]
    yield output_image, None

    # Refining step: always an img2img pass over the first result. The
    # text-to-image pipeline has no `image` parameter, so it cannot refine.
    output_refiner_image = img2img_pipe(
        prompt=prompt,
        image=output_image,
        num_inference_steps=40,
        strength=0.75,
        guidance_scale=1.0,
    ).images[0]
    yield output_image, output_refiner_image
# Define the Gradio interface.
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt here..."),
        gr.Image(type="pil", label="Initial Image (Optional)", height=360),
    ],
    outputs=[
        gr.Image(type="pil", label="Generated Image"),
        gr.Image(type="pil", label="Refined Image"),
    ],
    # Re-run when inputs change, for real-time updates.
    # NOTE(review): each run of generate_image performs two diffusion passes;
    # confirm this load is acceptable while the user is still typing.
    live=True,
    title="Generate Image Using Generative AI",
    theme=gr.themes.Default(primary_hue="green"),
    description="Text-to-Image or Image-to-Image Generation with SDXL-Turbo.",
)

# Launch the interface. generate_image is a generator (it yields the first
# image before refinement), and streaming generator outputs needs Gradio's
# queue; enable it explicitly (it is already on by default in Gradio 4+).
interface.queue()
interface.launch()