"""Gradio demo: few-step SDXL text-to-image generation with PCM LoRA weights."""

import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.schedulers import TCDScheduler
import spaces
from PIL import Image

SAFETY_CHECKER = True

# Constants
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
# Dropdown label -> [LoRA weight file, number of inference steps, guidance scale].
checkpoints = {
    "2-Step": ["pcm_sdxl_smallcfg_2step_converted.safetensors", 2, 0.0],
    "4-Step": ["pcm_sdxl_smallcfg_4step_converted.safetensors", 4, 0.0],
    "8-Step": ["pcm_sdxl_smallcfg_8step_converted.safetensors", 8, 0.0],
    "16-Step": ["pcm_sdxl_smallcfg_16step_converted.safetensors", 16, 0.0],
    "Normal CFG 4-Step": ["pcm_sdxl_normalcfg_4step_converted.safetensors", 4, 7.5],
    "Normal CFG 8-Step": ["pcm_sdxl_normalcfg_8step_converted.safetensors", 8, 7.5],
    "Normal CFG 16-Step": ["pcm_sdxl_normalcfg_16step_converted.safetensors", 16, 7.5],
    "LCM-Like LoRA": ["pcm_sdxl_lcmlike_lora_converted.safetensors", 16, 0.0],
}
# Weight file name of the LoRA currently loaded into `pipe`; None until the
# first generation request has loaded one.
loaded = None

# The pipeline (and the optional safety checker) are only created when CUDA is
# available; `generate_image` relies on these module-level names existing.
if torch.cuda.is_available():
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base, torch_dtype=torch.float16, variant="fp16"
    ).to("cuda")

    if SAFETY_CHECKER:
        from safety_checker import StableDiffusionSafetyChecker
        from transformers import CLIPFeatureExtractor

        safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker"
        ).to("cuda")
        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            "openai/clip-vit-base-patch32"
        )

        def check_nsfw_images(
            images: list[Image.Image],
        ) -> tuple[list[Image.Image], list[bool]]:
            """Run the safety checker over *images*.

            Returns the input images unchanged together with whatever the
            safety checker returns for them (consumed by the caller via
            ``any(...)``).
            """
            safety_checker_input = feature_extractor(
                images, return_tensors="pt"
            ).to("cuda")
            # NOTE(review): `images` is passed wrapped in an outer list; confirm
            # this matches the local `safety_checker` module's forward signature.
            has_nsfw_concepts = safety_checker(
                images=[images],
                clip_input=safety_checker_input.pixel_values.to("cuda"),
            )
            return images, has_nsfw_concepts


# Function
@spaces.GPU(enable_queue=True)
def generate_image(prompt, ckpt):
    """Generate one image for *prompt* using the preset named *ckpt*.

    Args:
        prompt: Text prompt passed to the pipeline.
        ckpt: A key of ``checkpoints`` selecting the LoRA weight file, the
            number of inference steps and the guidance scale.

    Returns:
        The first generated PIL image, or a black 512x512 placeholder when the
        safety checker flags the result.

    Raises:
        KeyError: If *ckpt* is not a key of ``checkpoints``.
    """
    global loaded
    print(prompt, ckpt)

    checkpoint, num_inference_steps, guidance_scale = checkpoints[ckpt]

    # Key the cache on the weight file rather than the step count: several
    # presets share a step count (e.g. "4-Step" and "Normal CFG 4-Step") but
    # use different LoRA files.
    if loaded != checkpoint:
        pipe.scheduler = TCDScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            timestep_spacing="trailing",
        )
        if loaded is not None:
            # Reset the marker first so a failure below cannot leave it
            # pointing at weights that are no longer loaded.
            loaded = None
            # Remove the previous LoRA so the new one is not applied on top.
            pipe.unload_lora_weights()
        pipe.load_lora_weights(
            "wangfuyun/PCM_Weights", weight_name=checkpoint, subfolder="sdxl"
        )
        loaded = checkpoint

    results = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    )

    if SAFETY_CHECKER:
        images, has_nsfw_concepts = check_nsfw_images(results.images)
        if any(has_nsfw_concepts):
            gr.Warning("NSFW content detected.")
            return Image.new("RGB", (512, 512))
        return images[0]
    return results.images[0]


# Gradio Interface
css = """
.gradio-container {
    max-width: 60rem !important;
}
"""

with gr.Blocks(css=css) as demo:
    # NOTE(review): the HTML markup in the two headers below was reconstructed
    # around the surviving text (the tags were missing and the literals were
    # not valid Python); confirm against the intended page layout.
    gr.HTML("<h1><center>SDXL-Lightning ⚡</center></h1>")
    gr.HTML(
        "<p><center>Lightning-fast text-to-image generation</center></p>"
        "<p><center>"
        "<a href='https://huggingface.co/ByteDance/SDXL-Lightning'>"
        "https://huggingface.co/ByteDance/SDXL-Lightning</a>"
        "</center></p>"
    )

    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label="Enter your prompt (English)", scale=8)
            ckpt = gr.Dropdown(
                label="Select inference steps",
                choices=list(checkpoints.keys()),
                value="4-Step",
                interactive=True,
            )
            submit = gr.Button(scale=1, variant="primary")

    img = gr.Image(label="SDXL-Lightning Generated Image")

    # Pressing Enter in the textbox and clicking the button both trigger
    # generation with the same inputs and output.
    prompt.submit(
        fn=generate_image,
        inputs=[prompt, ckpt],
        outputs=img,
    )
    submit.click(
        fn=generate_image,
        inputs=[prompt, ckpt],
        outputs=img,
    )

demo.queue().launch()