File size: 3,838 Bytes
b0f3145
 
 
8425644
b0f3145
 
 
84c6537
 
 
 
88dc089
b0f3145
 
 
8425644
 
 
 
b0f3145
 
23f3ac6
7c672bb
23f3ac6
 
7c672bb
23f3ac6
b0f3145
23f3ac6
b0f3145
75859e2
c811b57
7c672bb
b0f3145
1db955a
c811b57
7c672bb
f1e3c7d
 
8425644
4f9929e
1db955a
5db2f57
4f9929e
b0f3145
7c672bb
7e06c4d
 
88dc089
7e06c4d
 
b0f3145
0ce0e61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0f3145
7c672bb
b0f3145
 
75859e2
b0f3145
 
 
 
75859e2
b0f3145
 
 
75859e2
 
 
 
bef93f3
 
 
75859e2
bef93f3
 
 
 
 
 
75859e2
 
 
 
 
bef93f3
 
bbd2321
bef93f3
b0f3145
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import gradio as gr
import torch
import spaces
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Fail fast with a clear message: the pipeline below is moved to CUDA and run in float16.
# (An `assert` here would be stripped under `python -O`.)
if not torch.cuda.is_available():
    raise RuntimeError("This demo requires a CUDA device, but torch.cuda.is_available() is False.")

device = "cuda"
dtype = torch.float16

# `base` supplies the UNet config and the rest of the SDXL pipeline;
# `repo` supplies the SDXL-Lightning UNet checkpoints listed in `opts`.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
# UI label -> (UNet checkpoint filename inside `repo`, number of inference steps).
# The 1-step checkpoint is paired with prediction_type="sample"; all others with "epsilon".
opts = {
    "1 Step"  : ("sdxl_lightning_1step_unet_x0.safetensors", 1),
    "2 Steps" : ("sdxl_lightning_2step_unet.safetensors", 2),
    "4 Steps" : ("sdxl_lightning_4step_unet.safetensors", 4),
    "8 Steps" : ("sdxl_lightning_8step_unet.safetensors", 8),
}

# Default to load 4-step model. `step_loaded` records which checkpoint the shared
# UNet currently holds so generate() only swaps weights when the step count changes.
default_ckpt, step_loaded = opts["4 Steps"]
unet = UNet2DConditionModel.from_config(base, subfolder="unet")
unet.load_state_dict(load_file(hf_hub_download(repo, default_ckpt)))
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16").to(device, dtype)
# Same scheduler settings generate() applies on a checkpoint switch, so the initial
# state is consistent with whichever default checkpoint is chosen above.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
    prediction_type="sample" if step_loaded == 1 else "epsilon",
)

# Inference function.
@spaces.GPU(enable_queue=True)
def generate(prompt, option, progress=gr.Progress()):
    """Generate one image for ``prompt`` with the SDXL-Lightning variant chosen by ``option``.

    If the requested step count differs from the checkpoint currently loaded into the
    shared ``pipe``, the matching UNet weights are downloaded and loaded, and the
    scheduler is rebuilt for that checkpoint before sampling.

    Args:
        prompt: Text prompt passed to the pipeline.
        option: Key into ``opts`` (e.g. "4 Steps") selecting checkpoint and step count.
        progress: Gradio progress tracker; set to (0, step) before sampling and
            advanced after every denoising step. The ``gr.Progress()`` default is how
            Gradio detects that the handler reports progress; do not change it.

    Returns:
        The first image of the pipeline output (``.images[0]``).

    Raises:
        KeyError: if ``option`` is not a key of ``opts``.
    """
    global step_loaded
    print(prompt, option)
    ckpt, step = opts[option]
    progress((0, step))
    if step != step_loaded:
        print(f"Switching checkpoint from {step_loaded} to {step}")
        # Fetch and load the new weights *before* replacing the scheduler. If the
        # download or load raises, the scheduler keeps matching the UNet that is
        # still loaded, instead of e.g. an x0 ("sample") scheduler driving an
        # epsilon-prediction UNet on every later request.
        state_dict = load_file(hf_hub_download(repo, ckpt), device=device)
        pipe.unet.load_state_dict(state_dict)
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            prediction_type="sample" if step == 1 else "epsilon",
        )
        step_loaded = step

    def inference_callback(p, i, t, kwargs):
        # Step-end callback: report progress and hand the callback kwargs back unchanged.
        progress((i + 1, step))
        return kwargs

    # guidance_scale=0 disables classifier-free guidance in the diffusers pipeline.
    return pipe(
        prompt,
        num_inference_steps=step,
        guidance_scale=0,
        callback_on_step_end=inference_callback,
    ).images[0]

# Gradio UI: prompt + step selector + button, output image, cached examples.
with gr.Blocks(css="style.css") as demo:
    gr.HTML(
        "<h1><center>SDXL-Lightning</center></h1>" +
        "<p><center>Lightning-fast text-to-image generation</center></p>" +
        "<p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>"
    )

    with gr.Row():
        prompt = gr.Textbox(
            label="Text prompt",
            scale=8
        )
        option = gr.Dropdown(
            label="Inference steps",
            # Derived from `opts` (insertion order) so the UI can only offer options
            # that generate() can look up.
            choices=list(opts),
            value="4 Steps",
            interactive=True
        )
        submit = gr.Button(
            scale=1,
            variant="primary"
        )

    img = gr.Image(label="SDXL-Lightning Generated Image")

    # Pressing Enter in the textbox and clicking the button run the same handler.
    prompt.submit(
        fn=generate,
        inputs=[prompt, option],
        outputs=img,
    )
    submit.click(
        fn=generate,
        inputs=[prompt, option],
        outputs=img,
    )

    # cache_examples=True runs generate() on every example when the app is built.
    gr.Examples(
        fn=generate,
        examples=[
            ["An owl perches quietly on a twisted branch deep within an ancient forest.", "1 Step"],
            ["A lion in the galaxy, octane render", "2 Steps"],
            ["A dolphin leaps through the waves, set against a backdrop of bright blues and teal hues.", "2 Steps"],
            ["A girl smiling", "4 Steps"],
            ["An astronaut riding a horse", "4 Steps"],
            ["A fish on a bicycle, colorful art", "4 Steps"],
            ["A close-up of an Asian lady with sunglasses.", "4 Steps"],
            ["Man portrait, ethereal", "8 Steps"],
            ["Rabbit portrait in a forest, fantasy", "8 Steps"],
            ["A panda swimming", "8 Steps"],
        ],
        inputs=[prompt, option],
        outputs=img,
        cache_examples=True,
    )

    gr.HTML(
        "<p><small><center>This demo is built together by the community</center></small></p>"
    )

demo.queue().launch()