# Spaces: Running on Zero
import math
import os
from typing import List, Tuple

import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
from PIL import Image, ImageFilter
# Opt-in NSFW filtering: enabled only when the SAFETY_CHECKER env var is exactly "1".
SAFETY_CHECKER = os.getenv("SAFETY_CHECKER", "0") == "1"

# Constants
# Base SDXL pipeline, and the Hub repo hosting the Lightning UNet checkpoints.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"

# UI label -> [UNet checkpoint filename inside `repo`, number of inference steps].
checkpoints = {
    label: [filename, steps]
    for label, filename, steps in (
        ("1-Step", "sdxl_lightning_1step_unet_x0.safetensors", 1),
        ("2-Step", "sdxl_lightning_2step_unet.safetensors", 2),
        ("4-Step", "sdxl_lightning_4step_unet.safetensors", 4),
        ("8-Step", "sdxl_lightning_8step_unet.safetensors", 8),
    )
}

# "W:H" label -> (width multiplier, height multiplier). Insertion order is kept
# because the UI lists the keys in this order.
aspect_ratios = {
    f"{w}:{h}": (w, h)
    for w, h in ((21, 9), (2, 1), (16, 9), (5, 4), (4, 3), (3, 2), (1, 1))
}
# Function to calculate resolution
def calculate_resolution(aspect_ratio, mode='landscape', total_pixels=1024*1024, divisibility=8):
    """Return a ``(width, height)`` pair for *aspect_ratio* within a pixel budget.

    Both sides are rounded down to a multiple of *divisibility*, and
    ``width * height`` never exceeds *total_pixels*.

    Args:
        aspect_ratio: A key of ``aspect_ratios`` (e.g. ``"16:9"``), or an explicit
            ``(width_multiplier, height_multiplier)`` tuple of positive ints.
        mode: ``'portrait'`` swaps the ratio so the image is taller than wide;
            any other value is treated as landscape.
        total_pixels: Upper bound on ``width * height`` (truncated to an int).
        divisibility: Both sides are rounded down to a multiple of this.

    Returns:
        ``(width, height)``.

    Raises:
        ValueError: Unknown aspect-ratio key, non-positive/non-int multipliers,
            or non-positive ``total_pixels`` / ``divisibility``.
    """
    if isinstance(aspect_ratio, tuple):
        width_multiplier, height_multiplier = aspect_ratio
    elif aspect_ratio in aspect_ratios:
        width_multiplier, height_multiplier = aspect_ratios[aspect_ratio]
    else:
        raise ValueError(f"Invalid aspect ratio: {aspect_ratio}")
    if not all(isinstance(m, int) and m > 0 for m in (width_multiplier, height_multiplier)):
        raise ValueError(f"Invalid aspect ratio: {aspect_ratio}")
    if total_pixels <= 0 or divisibility <= 0:
        # divisibility == 0 used to surface as ZeroDivisionError.
        raise ValueError("total_pixels and divisibility must be positive")

    if mode == 'portrait':
        # Swap the ratio for portrait mode
        width_multiplier, height_multiplier = height_multiplier, width_multiplier

    # Exact integer arithmetic: height = floor(sqrt(budget * h / w)) and
    # width = floor(height * w / h). The previous float version could land one
    # unit low when rounding turned an exact integer into x.999 (e.g.
    # 1460 * 1.4 == 2043.9999999999998), silently losing a divisibility step.
    # width * height is an integer, so flooring the budget changes nothing.
    budget = int(total_pixels)
    height = math.isqrt(budget * height_multiplier // width_multiplier)
    height -= height % divisibility
    width = height * width_multiplier // height_multiplier
    width -= width % divisibility
    # No shrink loop is needed. height**2 <= budget * h / w and
    # width <= height * w / h together give width * height <= budget.
    return width, height
# Example prompts with ckpt, aspect, and mode.
# Each row is (prompt, negative_prompt, ckpt, aspect, mode); ckpt and aspect
# name keys of `checkpoints` and `aspect_ratios` respectively.
examples = [
    dict(prompt=p, negative_prompt=n, ckpt=c, aspect=a, mode=m)
    for p, n, c, a, m in (
        ("A futuristic cityscape at sunset", "Ugly", "4-Step", "16:9", "landscape"),
        ("pair of shoes made of dried fruit skins, 3d render, bright colours, clean composition, beautiful artwork, logo",
         "Ugly", "2-Step", "1:1", "portrait"),
        ("A portrait of a robot in the style of Renaissance art", "Ugly", "2-Step", "1:1", "portrait"),
        ("full body of alien shaped like woman, big golden eyes, mars planet, photo, digital art, fantasy",
         "Ugly", "4-Step", "1:1", "portrait"),
        ("A serene landscape with mountains and a river", "Ugly", "8-Step", "3:2", "landscape"),
        ("post-apocalyptic wasteland, the most delicate beautiful flower with green leaves growing from dust and rubble, vibrant colours, cinematic",
         "Ugly", "8-Step", "16:9", "landscape"),
    )
]
# Define a function to set the example inputs
def set_example(selected_prompt):
    """Look up the example whose prompt equals *selected_prompt*.

    Returns ``(prompt, negative_prompt, ckpt, aspect, mode)`` from the first
    matching entry of ``examples``, or five ``None`` values when none matches.
    """
    match = next((ex for ex in examples if ex["prompt"] == selected_prompt), None)
    if match is None:
        return (None,) * 5  # Default values if not found
    return tuple(match[field] for field in ("prompt", "negative_prompt", "ckpt", "aspect", "mode"))
# Use the GPU when CUDA is visible; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the SDXL base pipeline onto that device. On CUDA, fetch the fp16 weight
# variant and run in half precision; on CPU, use from_pretrained's defaults.
if device != "cuda":
    pipe = StableDiffusionXLPipeline.from_pretrained(base).to(device)
else:
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to(device)
if SAFETY_CHECKER:
    # Deferred imports: only needed when the safety checker is enabled.
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor

    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to(device)
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    def check_nsfw_images(
        images: List[Image.Image]
    ) -> Tuple[List[Image.Image], List[bool]]:
        """Run the safety checker over *images*, one image per call.

        Returns ``(images, has_nsfw_concepts)``. *images* is returned unchanged;
        ``has_nsfw_concepts`` is every per-call result from ``safety_checker``
        flattened into one list.
        """
        has_nsfw_concepts = []
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for image in images:
                # Inputs go to the same `device` the checker was moved to above.
                # They were previously hard-wired to "cuda", which crashed CPU-only runs.
                checker_input = feature_extractor(image, return_tensors="pt").to(device)
                result = safety_checker(
                    images=[image],
                    clip_input=checker_input.pixel_values,
                )
                # NOTE(review): assumes the local safety_checker returns an iterable of
                # per-image flags (hence the flatten). Confirm against safety_checker.py.
                has_nsfw_concepts.extend(result)
        return images, has_nsfw_concepts
# Function
# `spaces` is imported for this decorator: on ZeroGPU ("Running on Zero") a GPU
# is attached only while a @spaces.GPU function runs. It is documented as
# effect-free outside ZeroGPU.
@spaces.GPU
def generate_image(prompt, negative_prompt, ckpt, aspect_ratio, mode):
    """Generate one image with the SDXL-Lightning UNet selected by *ckpt*.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative prompt (may be empty).
        ckpt: Key of ``checkpoints`` (e.g. ``"4-Step"``). Selects the UNet
            weights file and the number of inference steps.
        aspect_ratio: Key of ``aspect_ratios``, passed to calculate_resolution.
        mode: ``'landscape'`` or ``'portrait'``.

    Returns:
        The first generated PIL image. When SAFETY_CHECKER is on and any NSFW
        concept is flagged, a Gaussian-blurred copy is returned instead.
    """
    width, height = calculate_resolution(aspect_ratio, mode)
    checkpoint, num_inference_steps = checkpoints[ckpt]

    # from_config copies the *current* scheduler's config, so every override must
    # be explicit on every call. Otherwise one 1-Step request would leave
    # prediction_type="sample" in place for later 2/4/8-step requests. Those
    # UNets use the base SDXL scheduler's "epsilon" prediction.
    prediction_type = "sample" if num_inference_steps == 1 else "epsilon"
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config,
        timestep_spacing="trailing",
        prediction_type=prediction_type,
    )
    pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device=device))
    results = pipe(
        prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=0,
        width=width,
        height=height,
    )
    if SAFETY_CHECKER:
        images, has_nsfw_concepts = check_nsfw_images(results.images)
        if any(has_nsfw_concepts):
            gr.Warning("NSFW content detected.")
            # Blur rather than drop the image; adjust the radius as needed.
            return images[0].filter(ImageFilter.GaussianBlur(16))
        return images[0]
    return results.images[0]
# Gradio Interface
description = """
SDXL-Lightning ByteDance model demo. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
"""

with gr.Blocks(css="style.css") as demo:
    gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
    gr.Markdown(description)
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your image prompt:', scale=8)
        with gr.Row():
            negative_prompt = gr.Textbox(label='Optional negative prompt:', scale=8)
        with gr.Row():
            # Choices come from the same tables generate_image looks up, so the
            # UI cannot offer a key the handler would reject.
            ckpt = gr.Dropdown(label='Select inference steps', choices=list(checkpoints), value='4-Step', interactive=True)
            aspect = gr.Dropdown(label='Aspect Ratio', choices=list(aspect_ratios), value='1:1', interactive=True)
            mode = gr.Dropdown(label='Mode', choices=['landscape', 'portrait'], value='landscape')  # Mode as a dropdown
            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='SDXL-Lightning Generated Image')

    # The generator's inputs, in generate_image's parameter order. They are
    # also the components that set_example fills in.
    image_inputs = [prompt, negative_prompt, ckpt, aspect, mode]

    # Pressing Enter in the prompt box and clicking the button both generate.
    prompt.submit(fn=generate_image, inputs=image_inputs, outputs=img)
    submit.click(fn=generate_image, inputs=image_inputs, outputs=img)

    # Dropdown for selecting examples
    example_dropdown = gr.Dropdown(label='Select an Example', choices=[e["prompt"] for e in examples])
    example_dropdown.change(fn=set_example, inputs=example_dropdown, outputs=image_inputs)

demo.queue().launch()