import gradio as gr
import torch
import spaces
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.image_processor import VaeImageProcessor
from transformers import CLIPImageProcessor
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Run on the GPU in half precision.
device = "cuda"
dtype = torch.float16

base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
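# Map each dropdown choice to its (checkpoint filename, inference step count).
# The 1-step checkpoint is the x0-prediction variant; see the scheduler setup below.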
opts = {
    "1 Step"  : ("sdxl_lightning_1step_unet_x0.safetensors", 1),
    "2 Steps" : ("sdxl_lightning_2step_unet.safetensors", 2),
    "4 Steps" : ("sdxl_lightning_4step_unet.safetensors", 4),
    "8 Steps" : ("sdxl_lightning_8step_unet.safetensors", 8),
}

# Inference function. The pipeline is built inside the call so that it is created on the
# GPU that ZeroGPU attaches for the duration of the request.
@spaces.GPU()
def generate(prompt, option, progress=gr.Progress()):
    print(prompt, option)
    ckpt, step = opts[option]

    progress(0, desc="Initializing the model")

    # Main pipeline: instantiate an SDXL UNet from the base config, then load the
    # distilled Lightning weights for the selected step count into it.
    unet = UNet2DConditionModel.from_config(base, subfolder="unet")
    pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16").to(device, dtype)
    pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))

    # Lightning requires trailing timestep spacing. The 1-step checkpoint predicts x0
    # ("sample"); the multi-step checkpoints predict noise ("epsilon").
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")

    # Safety checker. SDXL pipelines do not bundle one, so run it manually on the output.
    safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device, dtype)
    feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
    image_processor = VaeImageProcessor(vae_scale_factor=8)

    # Report per-step progress to the Gradio progress bar.
    def inference_callback(p, i, t, kwargs):
        progress((i + 1, step))
        return kwargs
    
    # Inference loop. Lightning models are distilled for CFG-free sampling (guidance_scale=0);
    # output_type="pt" keeps the images as tensors for the safety check below.
    progress((0, step))
    results = pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback, output_type="pt")

    # Safety check. The checker blacks out any image flagged as NSFW.
    feature_extractor_input = image_processor.postprocess(results.images, output_type="pil")
    safety_checker_input = feature_extractor(feature_extractor_input, return_tensors="pt")
    pixel_values = safety_checker_input.pixel_values.to(device, dtype)
    images, has_nsfw_concept = safety_checker(
        images=results.images, clip_input=pixel_values
    )
    if has_nsfw_concept[0]:
        print(f"Safety checker triggered on prompt: {prompt}")
    # gr.Image does not accept torch tensors, so convert the result to PIL first.
    return image_processor.numpy_to_pil(image_processor.pt_to_numpy(images))[0]

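# Gradio UI: prompt box, step-count dropdown, generate button, and output image.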
with gr.Blocks(css="style.css") as demo:
    gr.HTML(
        "<h1><center>SDXL-Lightning</center></h1>" +
        "<p><center>Lightning-fast text-to-image generation</center></p>" +
        "<p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>"
    )
    
    with gr.Row():
        prompt = gr.Textbox(
            label="Text prompt",
            scale=8
        )
        option = gr.Dropdown(
            label="Inference steps",
            choices=["1 Step", "2 Steps", "4 Steps", "8 Steps"],
            value="4 Steps",
            interactive=True
        )
        submit = gr.Button(
            scale=1,
            variant="primary"
        )
    
    img = gr.Image(label="SDXL-Lightning Generated Image")

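    # Trigger generation on Enter in the prompt box or on button click.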
    prompt.submit(
        fn=generate,
        inputs=[prompt, option],
        outputs=img,
    )
    submit.click(
        fn=generate,
        inputs=[prompt, option],
        outputs=img,
    )

    gr.Examples(
        fn=generate,
        examples=[
            ["An owl perches quietly on a twisted branch deep within an ancient forest.", "1 Step"],
            ["A lion in the galaxy, octane render", "2 Steps"],
            ["A dolphin leaps through the waves, set against a backdrop of bright blues and teal hues.", "2 Steps"],
            ["A girl smiling", "4 Steps"],
            ["An astronaut riding a horse", "4 Steps"],
            ["A fish on a bicycle, colorful art", "4 Steps"],
            ["A close-up of an Asian lady with sunglasses.", "4 Steps"],
            ["Rabbit portrait in a forest, fantasy", "4 Steps"],
            ["A panda swimming", "4 Steps"],
            ["Man portrait, ethereal", "8 Steps"],
        ],
        inputs=[prompt, option],
        outputs=img,
        cache_examples=False,
    )

    gr.HTML(
        "<p><small><center>This demo is built together by the community</center></small></p>"
    )
    
demo.queue().launch()