File size: 5,239 Bytes
b0f3145
 
 
8425644
6619b46
 
 
b0f3145
 
0b120d5
b0f3145
84c6537
 
 
 
88dc089
b0f3145
 
 
8425644
 
 
 
b0f3145
 
23f3ac6
7c672bb
23f3ac6
 
7c672bb
23f3ac6
b0f3145
6619b46
 
 
 
 
edf024c
 
 
23f3ac6
b0f3145
75859e2
c811b57
7c672bb
b0f3145
edf024c
6619b46
0b120d5
 
1db955a
c811b57
7c672bb
f1e3c7d
 
8425644
4f9929e
1db955a
5db2f57
6619b46
 
 
 
 
 
 
0b120d5
6619b46
 
0b120d5
6619b46
b0f3145
7c672bb
7e06c4d
 
88dc089
7e06c4d
 
b0f3145
0ce0e61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0f3145
a364bc6
b0f3145
 
75859e2
b0f3145
 
 
 
75859e2
b0f3145
 
 
75859e2
 
 
 
bef93f3
 
 
75859e2
bef93f3
 
 
 
 
 
75859e2
 
 
 
 
bef93f3
 
bbd2321
bef93f3
b0f3145
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import gradio as gr
import torch
import spaces
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.image_processor import VaeImageProcessor
from transformers import CLIPImageProcessor
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image

# Fail fast with a clear message when no CUDA device is visible. An explicit
# raise is used instead of `assert` so the check is not stripped when Python
# runs with optimizations enabled (`python -O`).
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required to run this demo, but no CUDA device is available.")

# Device and dtype that the models and their inputs are moved to below.
device = "cuda"
dtype = torch.float16

# Hub repositories: `base` provides the SDXL pipeline, `repo` the Lightning
# UNet checkpoint files.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"

# Maps each "Inference steps" dropdown label to a pair of
# (checkpoint file name inside `repo`, number of inference steps).
opts = dict(
    [
        ("1 Step", ("sdxl_lightning_1step_unet_x0.safetensors", 1)),
        ("2 Steps", ("sdxl_lightning_2step_unet.safetensors", 2)),
        ("4 Steps", ("sdxl_lightning_4step_unet.safetensors", 4)),
        ("8 Steps", ("sdxl_lightning_8step_unet.safetensors", 8)),
    ]
)

# Start with the 4-step checkpoint in the UNet. `step_loaded` records which
# checkpoint is currently loaded.
step_loaded = 4
unet = UNet2DConditionModel.from_config(base, subfolder="unet")
unet.load_state_dict(
    load_file(
        hf_hub_download(repo, opts["4 Steps"][0]),
    ),
)

# Build the SDXL pipeline around that UNet and move it to the target device/dtype.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base,
    unet=unet,
    torch_dtype=dtype,
    variant="fp16",
)
pipe = pipe.to(device, dtype)

# Replace the pipeline's scheduler with an Euler scheduler derived from the
# existing scheduler config, using "trailing" timestep spacing.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)

# Safety checker.
# The checker model is moved to the same device/dtype as the pipeline; the
# CLIP image processor prepares its input.
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker"
)
safety_checker = safety_checker.to(device, dtype)
feature_extractor = CLIPImageProcessor.from_pretrained(
    "openai/clip-vit-base-patch32"
)
image_processor = VaeImageProcessor(vae_scale_factor=8)

# Load the prompt blocklist: one entry per line of filter.txt.
# - The encoding is pinned so the result does not depend on the host locale.
# - Blank AND whitespace-only lines are skipped: a stray " " entry would match
#   nearly every multi-word prompt, because entries are matched as substrings.
# - Entries themselves are kept verbatim (only the line terminator is removed).
with open("filter.txt", encoding="utf-8") as f:
    filter_words = {line.rstrip("\n") for line in f if line.strip()}

# Inference function.
@spaces.GPU(enable_queue=True)
def generate(prompt, option, progress=gr.Progress()):
    """Generate one image for ``prompt`` with the checkpoint selected by ``option``.

    Args:
        prompt: Text prompt entered by the user.
        option: Key of ``opts`` (e.g. ``"4 Steps"``) selecting the checkpoint
            file and the number of inference steps.
        progress: Gradio progress tracker; updated once per inference step.

    Returns:
        A PIL image: the generated image, or an all-black image when either
        the prompt filter or the image safety checker is triggered.

    Raises:
        KeyError: If ``option`` is not a key of ``opts``.
    """
    global step_loaded
    print(prompt, option)
    ckpt, step = opts[option]

    # Prompt filter: substring match against the blocklist. Compared
    # case-insensitively so capitalisation cannot be used to bypass it.
    prompt_lowered = prompt.lower()
    if any(word.lower() in prompt_lowered for word in filter_words):
        gr.Warning("Safety checker triggered. Image may contain violent or sexual content.")
        print(f"Safety checker triggered on prompt: {prompt}")
        return Image.new("RGB", (512, 512))

    progress((0, step))

    if step != step_loaded:
        print(f"Switching checkpoint from {step_loaded} to {step}")
        # Fetch the weights before touching the pipeline, so a failed download
        # leaves the currently loaded scheduler/UNet pair consistent.
        state_dict = load_file(hf_hub_download(repo, ckpt), device=device)
        # The 1-step checkpoint is run with prediction_type "sample"; all
        # other checkpoints use "epsilon".
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            prediction_type="sample" if step == 1 else "epsilon",
        )
        pipe.unet.load_state_dict(state_dict)
        step_loaded = step

    def inference_callback(p, i, t, kwargs):
        # Called by the pipeline after each step; `i` is the 0-based step index.
        progress((i + 1, step))
        return kwargs

    results = pipe(
        prompt,
        num_inference_steps=step,
        guidance_scale=0,
        callback_on_step_end=inference_callback,
        output_type="pil",
    )
    pil_images = results.images

    # Safety check.
    # The pipeline returned PIL images, which the CLIP image processor accepts
    # directly (VaeImageProcessor.postprocess only accepts torch tensors).
    clip_input = feature_extractor(pil_images, return_tensors="pt").pixel_values.to(device, dtype)
    # The checker blacks out flagged entries of `images` in place using
    # `.shape`, which PIL images lack, so it is handed a NumPy copy. Only the
    # flags are used; the blacked-out copy is discarded.
    _, has_nsfw_concept = safety_checker(
        images=image_processor.pil_to_numpy(pil_images),
        clip_input=clip_input,
    )
    if has_nsfw_concept[0]:
        gr.Warning("Safety checker triggered. Image may contain violent or sexual content.")
        print(f"Safety checker triggered on prompt: {prompt}")
        # Return a black image of the same size instead of the flagged one.
        return Image.new("RGB", pil_images[0].size)
    return pil_images[0]

# Gradio UI: header, prompt row (text box, step selector, submit button),
# output image, cached examples and footer.
with gr.Blocks(css="style.css") as demo:
    gr.HTML(
        "<h1><center>SDXL-Lightning</center></h1>"
        "<p><center>Lightning-fast text-to-image generation</center></p>"
        "<p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>"
    )

    with gr.Row():
        prompt = gr.Textbox(label="Text prompt", scale=8)
        option = gr.Dropdown(
            label="Inference steps",
            choices=["1 Step", "2 Steps", "4 Steps", "8 Steps"],
            value="4 Steps",
            interactive=True,
        )
        submit = gr.Button(scale=1, variant="primary")

    img = gr.Image(label="SDXL-Lightning Generated Image")

    # Pressing Enter in the prompt box and clicking the button both run
    # `generate` with the same inputs and output.
    for trigger in (prompt.submit, submit.click):
        trigger(fn=generate, inputs=[prompt, option], outputs=img)

    # Example (prompt, steps) pairs; cache_examples=True is passed so their
    # outputs are cached.
    gr.Examples(
        fn=generate,
        examples=[
            ["An owl perches quietly on a twisted branch deep within an ancient forest.", "1 Step"],
            ["A lion in the galaxy, octane render", "2 Steps"],
            ["A dolphin leaps through the waves, set against a backdrop of bright blues and teal hues.", "2 Steps"],
            ["A girl smiling", "4 Steps"],
            ["An astronaut riding a horse", "4 Steps"],
            ["A fish on a bicycle, colorful art", "4 Steps"],
            ["A close-up of an Asian lady with sunglasses.", "4 Steps"],
            ["Man portrait, ethereal", "8 Steps"],
            ["Rabbit portrait in a forest, fantasy", "8 Steps"],
            ["A panda swimming", "8 Steps"],
        ],
        inputs=[prompt, option],
        outputs=img,
        cache_examples=True,
    )

    gr.HTML("<p><small><center>This demo is built together by the community</center></small></p>")

# Enable the request queue, then start the app.
demo.queue()
demo.launch()