# sd-kiwi / app.py
# jadechoghari's picture
# Update app.py
# 5066ef6 verified
import spaces
from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image
import torch
import os
import gradio as gr
import time
import math
from PIL import Image
import numpy as np
# Optional dependency: the app must keep starting when Intel Extension for
# PyTorch is missing or fails to initialise, so the import is best-effort.
try:
    import intel_extension_for_pytorch as ipex
except Exception:
    # `except Exception` (not a bare `except:`) keeps the best-effort
    # behaviour while letting KeyboardInterrupt / SystemExit propagate.
    ipex = None
# Runtime configuration, read once from the environment at import time.
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Device management based on available hardware: CUDA first, then Apple MPS,
# then CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
)
torch_device = device
# Compare on `device.type`: `device` is a torch.device object, and comparing
# it directly with the string "cuda" is never true, which silently selected
# float32 even when CUDA was available.
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32

print(f"Device: {device}")
print(f"Safety Checker: {SAFETY_CHECKER}")
print(f"Torch Compile: {TORCH_COMPILE}")

# Loading model pipelines. Both pipelines share the same checkpoint and the
# same loading options, so the options are built once.
_MODEL_ID = "stabilityai/sdxl-turbo"
_pipeline_kwargs = {
    "torch_dtype": torch_dtype,
    # Half precision loads the "fp16" weight files; full precision uses the
    # default (unsuffixed) weights.
    # NOTE(review): the previous code requested variant="fp32" here - confirm
    # the checkpoint repo does not publish such a variant.
    "variant": "fp16" if torch_dtype == torch.float16 else None,
}
if SAFETY_CHECKER != "True":
    # Only pass safety_checker=None when the checker is not explicitly enabled.
    _pipeline_kwargs["safety_checker"] = None

i2i_pipe = AutoPipelineForImage2Image.from_pretrained(_MODEL_ID, **_pipeline_kwargs)
t2i_pipe = AutoPipelineForText2Image.from_pretrained(_MODEL_ID, **_pipeline_kwargs)

# Place the weights on the selected device; previously the pipelines were
# loaded but never moved, so inference did not use the detected accelerator.
i2i_pipe.to(device)
t2i_pipe.to(device)
# Text-to-image entry point (the "Kiwi" path)
@spaces.GPU()
def kiwi_process(prompt, seed=123123, width=512, height=512):
    """Generate a single image from a text prompt with ``t2i_pipe``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        seed: Value given to ``torch.manual_seed``; the generator it returns
            is passed to the pipeline.
        width: Requested output width.
        height: Requested output height.

    Returns:
        The first entry of the pipeline result's ``images`` (``output_type``
        is ``"pil"``).
    """
    print(f"Generating Kiwi-style image for prompt: {prompt}")
    # NOTE(review): the checkpoint loaded in this file is sdxl-turbo, which is
    # documented for very few steps with guidance disabled - confirm that
    # 25 steps / guidance 7.5 is intentional.
    call_args = dict(
        prompt=prompt,
        generator=torch.manual_seed(seed),
        num_inference_steps=25,
        guidance_scale=7.5,
        width=width,
        height=height,
        output_type="pil",
    )
    output = t2i_pipe(**call_args)
    return output.images[0]
# Image pre-processing helper
def resize_crop(image, size=512):
    """Return an RGB copy of ``image`` scaled to ``size`` pixels wide.

    The height is scaled by the same factor (truncated to an integer), so the
    aspect ratio is kept. Despite the name, no cropping is performed.

    Args:
        image: Object exposing ``convert``, ``size`` and ``resize``
            (presumably a PIL image - verify against callers).
        size: Target width in pixels.

    Returns:
        The resized image, resampled with ``Image.BICUBIC``.
    """
    rgb = image.convert("RGB")
    src_w, src_h = rgb.size
    target = (size, int(size * (src_h / src_w)))
    return rgb.resize(target, Image.BICUBIC)
# Main prediction method
async def predict(init_image, prompt, strength, steps, seed=123123):
    """Produce an image for the UI from the current control values.

    With no input image the request is delegated to ``kiwi_process``
    (text-to-image). Otherwise the input image is resized with
    ``resize_crop`` and run through ``i2i_pipe`` (image-to-image).

    Args:
        init_image: Source image, or ``None`` to run text-to-image.
        prompt: Text prompt.
        strength: Image-to-image strength forwarded to the pipeline.
        steps: ``num_inference_steps`` for the image-to-image pipeline.
        seed: Value given to ``torch.manual_seed``.

    Returns:
        The generated image. The image-to-image branch used to compute the
        result but fall off the end of the function, handing ``None`` to the
        UI; it now returns the first generated image.
    """
    if init_image is None:
        return kiwi_process(prompt, seed)  # Using the Kiwi method for text-to-image

    init_image = resize_crop(init_image)
    generator = torch.manual_seed(seed)
    # NOTE(review): image-to-image pipelines typically need
    # steps * strength >= 1; very low strength values from the UI may be
    # rejected by the pipeline - confirm the desired handling.
    # NOTE(review): unlike kiwi_process, this call is not wrapped by
    # spaces.GPU() - confirm that is intended for the deployment target.
    results = i2i_pipe(
        prompt=prompt,
        image=init_image,
        generator=generator,
        num_inference_steps=steps,
        guidance_scale=0.0,
        strength=strength,
        width=512,
        height=512,
        output_type="pil",
    )
    return results.images[0]
# Gradio UI with a custom description for Kiwi
css = """
#container{
margin: 0 auto;
max-width: 80rem;
}
#intro{
max-width: 100%;
text-align: center;
margin: 0 auto;
}
"""
# UI definition: prompt row, input/output images, advanced sliders, and the
# event wiring that re-runs `predict` whenever a control changes.
# NOTE(review): the original nesting of the layout containers could not be
# recovered from this copy of the file - confirm the accordion's placement.
with gr.Blocks(css=css) as demo:
    init_image_state = gr.State()
    with gr.Column(elem_id="container"):
        gr.Markdown(
            "# Kiwi Image Generator Demo\n"
            "## Harnessing the Power of Kiwi AI\n"
            "This demo integrates the Kiwi AI model to generate high-quality images using cutting-edge techniques like quantization and pruning.\n",
            elem_id="intro",
        )
        with gr.Row():
            prompt = gr.Textbox(
                placeholder="Insert your prompt for Kiwi here:",
                scale=5,
                container=False,
            )
            generate_bt = gr.Button("Generate with Kiwi", scale=1)
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    sources=["upload", "webcam", "clipboard"],
                    label="Upload or Capture Image",
                    type="pil",
                )
            with gr.Column():
                image = gr.Image(type="filepath")
                with gr.Accordion("Advanced options", open=False):
                    strength = gr.Slider(
                        label="Strength",
                        value=0.7,
                        minimum=0.0,
                        maximum=1.0,
                        step=0.001,
                    )
                    steps = gr.Slider(
                        label="Steps", value=25, minimum=1, maximum=50, step=1
                    )
                    seed = gr.Slider(
                        randomize=True,
                        minimum=0,
                        maximum=12013012031030,
                        label="Seed",
                        step=1,
                    )

        inputs = [image_input, prompt, strength, steps, seed]
        # Every trigger runs the same handler with the same inputs and output.
        event_kwargs = dict(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        generate_bt.click(**event_kwargs)
        for control in (prompt, steps, seed, strength):
            control.change(**event_kwargs)

demo.queue()
demo.launch()