"""Gradio demo for text-to-image generation with ByteDance SDXL-Lightning."""

import os
from typing import List, Tuple

import gradio as gr
import spaces
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from PIL import Image, ImageFilter
from safetensors.torch import load_file

# The NSFW check is opt-in: it is enabled only when SAFETY_CHECKER=1 is set.
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"

# Constants
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"

# UI label -> [checkpoint file name in `repo`, number of inference steps]
checkpoints = {
    "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
    "2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
    "4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
    "8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
}

# UI label -> (width multiplier, height multiplier) in landscape orientation
aspect_ratios = {
    "21:9": (21, 9),
    "2:1": (2, 1),
    "16:9": (16, 9),
    "5:4": (5, 4),
    "4:3": (4, 3),
    "3:2": (3, 2),
    "1:1": (1, 1),
}


# Function to calculate resolution
def calculate_resolution(aspect_ratio, mode='landscape', total_pixels=1024*1024, divisibility=8):
    """Return a ``(width, height)`` pair for the requested aspect ratio.

    Both dimensions are rounded down to a multiple of ``divisibility`` and
    their product never exceeds ``total_pixels``.

    Args:
        aspect_ratio: A key of ``aspect_ratios`` such as ``"16:9"``.
        mode: ``'portrait'`` inverts the ratio; any other value keeps the
            landscape ratio stored in ``aspect_ratios``.
        total_pixels: Upper bound on ``width * height``.
        divisibility: Both dimensions are multiples of this value.

    Raises:
        ValueError: If ``aspect_ratio`` is not a key of ``aspect_ratios``.
    """
    if aspect_ratio not in aspect_ratios:
        raise ValueError(f"Invalid aspect ratio: {aspect_ratio}")

    width_multiplier, height_multiplier = aspect_ratios[aspect_ratio]
    ratio = width_multiplier / height_multiplier
    if mode == 'portrait':
        # Swap the ratio for portrait mode
        ratio = 1 / ratio

    height = int((total_pixels / ratio) ** 0.5)
    height -= height % divisibility
    width = int(height * ratio)
    width -= width % divisibility

    # Shrink one step at a time until the pixel budget is respected.
    while width * height > total_pixels:
        height -= divisibility
        width = int(height * ratio)
        width -= width % divisibility

    return width, height


# Example prompts with ckpt, aspect, and mode
examples = [
    {
        "prompt": "A futuristic cityscape at sunset",
        "negative_prompt": "Ugly",
        "ckpt": "4-Step",
        "aspect": "16:9",
        "mode": "landscape",
    },
    {
        "prompt": "pair of shoes made of dried fruit skins, 3d render, bright colours, clean composition, beautiful artwork, logo",
        "negative_prompt": "Ugly",
        "ckpt": "2-Step",
        "aspect": "1:1",
        "mode": "portrait",
    },
    {
        "prompt": "A portrait of a robot in the style of Renaissance art",
        "negative_prompt": "Ugly",
        "ckpt": "2-Step",
        "aspect": "1:1",
        "mode": "portrait",
    },
    {
        "prompt": "full body of alien shaped like woman, big golden eyes, mars planet, photo, digital art, fantasy",
        "negative_prompt": "Ugly",
        "ckpt": "4-Step",
        "aspect": "1:1",
        "mode": "portrait",
    },
    {
        "prompt": "A serene landscape with mountains and a river",
        "negative_prompt": "Ugly",
        "ckpt": "8-Step",
        "aspect": "3:2",
        "mode": "landscape",
    },
    {
        "prompt": "post-apocalyptic wasteland, the most delicate beautiful flower with green leaves growing from dust and rubble, vibrant colours, cinematic",
        "negative_prompt": "Ugly",
        "ckpt": "8-Step",
        "aspect": "16:9",
        "mode": "landscape",
    },
]


# Define a function to set the example inputs
def set_example(selected_prompt):
    """Return the UI field values of the example whose prompt was selected.

    Returns:
        A 5-tuple ``(prompt, negative_prompt, ckpt, aspect, mode)`` taken from
        the matching entry of ``examples``, or five ``None`` values when no
        entry matches ``selected_prompt``.
    """
    # Find the example that matches the selected prompt
    for example in examples:
        if example["prompt"] == selected_prompt:
            return (
                example["prompt"],
                example["negative_prompt"],
                example["ckpt"],
                example["aspect"],
                example["mode"],
            )
    return None, None, None, None, None  # Default values if not found


# Check if CUDA is available (GPU support), and set the appropriate device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pipeline for the specified device
# For GPU, use torch_dtype=torch.float16 for better performance
if device == "cuda":
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base, torch_dtype=torch.float16, variant="fp16"
    ).to(device)
else:
    pipe = StableDiffusionXLPipeline.from_pretrained(base).to(device)

if SAFETY_CHECKER:
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor

    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to(device)
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    def check_nsfw_images(
        images: List[Image.Image]
    ) -> Tuple[List[Image.Image], List[bool]]:
        """Run the safety checker on each image individually.

        Returns:
            The unchanged ``images`` list together with the flattened
            per-image results returned by ``safety_checker``.
        """
        # Use the device selected at start-up rather than a hard-coded
        # "cuda", so the check also works on CPU-only hosts where the
        # safety checker itself lives on the CPU.
        safety_checker_inputs = [
            feature_extractor(image, return_tensors="pt").to(device)
            for image in images
        ]

        # Get NSFW concepts for each image
        has_nsfw_concepts = [
            safety_checker(
                images=[image],
                clip_input=safety_checker_input.pixel_values,
            )
            for image, safety_checker_input in zip(images, safety_checker_inputs)
        ]

        # Each call returns a list; flatten them into one list.
        # NOTE(review): assumes safety_checker returns a flat list of flags
        # per call -- confirm against the local safety_checker module.
        has_nsfw_concepts = [item for sublist in has_nsfw_concepts for item in sublist]

        return images, has_nsfw_concepts


# Function
@spaces.GPU(enable_queue=True)
def generate_image(prompt, negative_prompt, ckpt, aspect_ratio, mode):
    """Generate one image with the selected SDXL-Lightning checkpoint.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        ckpt: A key of ``checkpoints`` (e.g. ``"4-Step"``).
        aspect_ratio: A key of ``aspect_ratios``.
        mode: ``'landscape'`` or ``'portrait'``.

    Returns:
        The first generated image. When the safety checker is enabled and
        flags any image, a Gaussian-blurred copy of the first image is
        returned instead.

    Raises:
        ValueError: If ``aspect_ratio`` is unknown.
        KeyError: If ``ckpt`` is not a key of ``checkpoints``.
    """
    # Calculate resolution based on the aspect ratio
    width, height = calculate_resolution(aspect_ratio, mode)
    checkpoint, num_inference_steps = checkpoints[ckpt]

    # The sampler always uses "trailing" timesteps. The 1-step checkpoint
    # needs the "sample" prediction type, every other checkpoint "epsilon".
    # The prediction type is set explicitly on every call because the new
    # scheduler is derived from the current scheduler's config: without the
    # explicit reset, a multi-step run following a 1-step run would inherit
    # "sample".
    prediction_type = "sample" if num_inference_steps == 1 else "epsilon"
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config,
        timestep_spacing="trailing",
        prediction_type=prediction_type,
    )

    pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device=device))

    results = pipe(
        prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=0,
        width=width,
        height=height,
    )

    if SAFETY_CHECKER:
        images, has_nsfw_concepts = check_nsfw_images(results.images)
        if any(has_nsfw_concepts):
            gr.Warning("NSFW content detected.")
            # Apply a blur filter to the first image in the results
            blurred_image = images[0].filter(ImageFilter.GaussianBlur(16))  # Adjust the radius as needed
            return blurred_image
        return images[0]
    return results.images[0]


# Gradio Interface
description = """
SDXL-Lightning ByteDance model demo.
Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
"""

with gr.Blocks(css="style.css") as demo:
    # NOTE(review): the heading text carries no HTML tags; it looks as if the
    # original markup was lost -- restore it if a styled heading is intended.
    gr.HTML("\n\nText-to-Image with SDXL-Lightning ⚡\n\n")
    gr.Markdown(description)
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter you image prompt:', scale=8)
        with gr.Row():
            negative_prompt = gr.Textbox(label='Optional negative prompt:', scale=8)
        with gr.Row():
            ckpt = gr.Dropdown(
                label='Select inference steps',
                choices=['1-Step', '2-Step', '4-Step', '8-Step'],
                value='4-Step',
                interactive=True,
            )
            aspect = gr.Dropdown(
                label='Aspect Ratio',
                choices=list(aspect_ratios.keys()),
                value='1:1',
                interactive=True,
            )
            mode = gr.Dropdown(label='Mode', choices=['landscape', 'portrait'], value='landscape')  # Mode as a dropdown
            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='SDXL-Lightning Generated Image')

    # Pressing Enter in the prompt box and clicking the button both generate.
    prompt.submit(
        fn=generate_image,
        inputs=[prompt, negative_prompt, ckpt, aspect, mode],
        outputs=img,
    )
    submit.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, ckpt, aspect, mode],
        outputs=img,
    )

    # Dropdown for selecting examples
    example_dropdown = gr.Dropdown(label='Select an Example', choices=[e["prompt"] for e in examples])
    example_dropdown.change(
        fn=set_example,
        inputs=example_dropdown,
        outputs=[prompt, negative_prompt, ckpt, aspect, mode],
    )

demo.queue().launch()