SDXL-Lightning / app.py
ysharma's picture
ysharma HF staff
feature for selecting inference steps
f679e15 verified
raw history blame
No virus
2.68 kB
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
import spaces
# Hub repositories: the SDXL base pipeline and the distilled Lightning UNet weights.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"

# Dropdown label -> [UNet weights filename inside `repo`, number of inference steps].
# The 1-step file carries an "_x0" suffix; it is the one sampled with a different
# scheduler prediction type.
checkpoints = {
    f"{steps}-Step": [filename, steps]
    for steps, filename in (
        (1, "sdxl_lightning_1step_unet_x0.pth"),
        (2, "sdxl_lightning_2step_unet.pth"),
        (4, "sdxl_lightning_4step_unet.pth"),
        (8, "sdxl_lightning_8step_unet.pth"),
    )
}

# Load the fp16 base pipeline once, at import time, and move it to the GPU.
# Only done when CUDA is visible; otherwise `pipe` is left undefined.
if torch.cuda.is_available():
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")
# Function
# Key of `checkpoints` whose weights are currently loaded into `pipe.unet`.
# None until the first successful load, or after a load that did not complete.
_loaded_ckpt = None


@spaces.GPU(enable_queue=True)
def generate_image(prompt, ckpt):
    """Generate one image for *prompt* using the Lightning UNet selected by *ckpt*.

    The UNet weights and scheduler are only swapped when *ckpt* differs from the
    checkpoint loaded by the previous successful call.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        ckpt: A key of ``checkpoints`` (e.g. ``"4-Step"``). It selects both the
            UNet weights file and the number of inference steps.

    Returns:
        ``images[0]`` of the pipeline output (presumably a PIL image; confirm
        against the diffusers version in use).

    Raises:
        KeyError: if *ckpt* is not a key of ``checkpoints``.
    """
    global _loaded_ckpt
    checkpoint, num_inference_steps = checkpoints[ckpt]

    if ckpt != _loaded_ckpt:
        # Clear the marker first. If the download or load below fails, the next
        # call then retries it instead of treating a half-updated pipeline as ready.
        _loaded_ckpt = None

        # Set prediction_type explicitly in BOTH branches. from_config() copies the
        # *current* scheduler config, so after a 1-step run the value "sample" would
        # otherwise stick and the 2/4/8-step UNets would be sampled wrongly.
        # "epsilon" is the base SDXL scheduler's setting (per its
        # scheduler_config.json; re-check if `base` changes).
        prediction_type = "sample" if num_inference_steps == 1 else "epsilon"
        # Every Lightning checkpoint is sampled with "trailing" timestep spacing.
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            prediction_type=prediction_type,
        )

        # NOTE(review): torch.load unpickles the downloaded .pth file. Consider
        # weights_only=True, or a .safetensors checkpoint if the repo provides one.
        state_dict = torch.load(hf_hub_download(repo, checkpoint), map_location="cuda")
        pipe.unet.load_state_dict(state_dict)
        _loaded_ckpt = ckpt

    # guidance_scale=0 means no classifier-free guidance for the distilled models.
    image = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0).images[0]
    return image
# Gradio Interface
description = """
This demo utilizes the SDXL-Lightning model by ByteDance, which is a fast text-to-image generative model capable of producing high-quality images in 4 steps.
As a community effort, this demo was put together by AngryPenguin. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
"""

# Layout: title and description, then a row holding the prompt box, the step
# selector and the run button, then the output image.
with gr.Blocks(css="style.css") as demo:
    gr.HTML("<h1><center>Text-to-Image with SDXL Lightning ⚡</center></h1>")
    gr.Markdown(description)
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label="Enter your image prompt:", scale=8)
            # Choices come from `checkpoints`, so the dropdown only offers keys
            # that generate_image can look up, in the same order.
            ckpt = gr.Dropdown(
                label="Select Inference Steps",
                choices=list(checkpoints),
                value="4-Step",
                interactive=True,
            )
            submit = gr.Button(scale=1, variant="primary")
    img = gr.Image(label="SDXL-Lightning Generated Image")

    # Pressing Enter in the prompt box and clicking the button run the same job.
    for trigger in (prompt.submit, submit.click):
        trigger(fn=generate_image, inputs=[prompt, ckpt], outputs=img)

demo.queue().launch()