File size: 2,915 Bytes
b0f3145
 
 
8425644
b0f3145
 
 
84c6537
 
 
 
88dc089
b0f3145
 
 
8425644
 
 
 
b0f3145
 
23f3ac6
7c672bb
23f3ac6
 
7c672bb
23f3ac6
b0f3145
23f3ac6
b0f3145
4f9929e
c811b57
7c672bb
b0f3145
4f9929e
c811b57
7c672bb
f1e3c7d
 
8425644
4f9929e
 
5db2f57
4f9929e
b0f3145
7c672bb
7e06c4d
 
88dc089
7e06c4d
 
b0f3145
0ce0e61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0f3145
7c672bb
b0f3145
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
import torch
import spaces
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# The pipeline is moved to "cuda" below, so a CUDA device is mandatory.
# Raise explicitly instead of using `assert`, which is stripped under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required to run this demo, but no CUDA device is available.")

device = "cuda"
dtype = torch.float16

# `base` supplies the SDXL configs/components loaded by from_pretrained;
# `repo` holds the SDXL-Lightning UNet checkpoints listed in `opts`.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"

# UI option label -> (checkpoint filename inside `repo`, number of inference steps).
opts = {
    "1 Step"  : ("sdxl_lightning_1step_unet_x0.safetensors", 1),
    "2 Steps" : ("sdxl_lightning_2step_unet.safetensors", 2),
    "4 Steps" : ("sdxl_lightning_4step_unet.safetensors", 4),
    "8 Steps" : ("sdxl_lightning_8step_unet.safetensors", 8),
}

# Default to load the 4-step model. The filename and step count are both taken
# from `opts` so they cannot drift apart; `step_loaded` is the global that
# generate_image compares against to decide whether to swap checkpoints.
default_option = "4 Steps"
default_ckpt, step_loaded = opts[default_option]

# Build the UNet from the base config and fill it with the Lightning weights
# (load_file defaults to CPU; the pipeline's .to(device, dtype) moves everything).
unet = UNet2DConditionModel.from_config(base, subfolder="unet")
unet.load_state_dict(load_file(hf_hub_download(repo, default_ckpt)))
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16").to(device, dtype)

# Use "trailing" timestep spacing, and pick the prediction type the same way
# generate_image does when it switches checkpoints (the 1-step x0 checkpoint
# uses "sample", the others "epsilon").
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
    prediction_type="sample" if step_loaded == 1 else "epsilon",
)

# Inference function.
@spaces.GPU(enable_queue=True)
def generate_image(prompt, option, progress=gr.Progress()):
    """Generate one image for `prompt` using the checkpoint selected by `option`.

    Args:
        prompt: Text prompt passed straight to the pipeline.
        option: A key of `opts` (e.g. "4 Steps"); selects the UNet checkpoint
            and the number of inference steps.
        progress: Gradio progress tracker; advanced once per denoising step.

    Returns:
        The first entry of the pipeline output's `images`.

    Side effects:
        When `option` needs a different checkpoint than the one currently
        loaded, replaces `pipe.unet`'s weights and `pipe.scheduler`, and
        updates the module-level `step_loaded`.
    """
    global step_loaded
    print(prompt, option)
    ckpt, step = opts[option]
    progress(0, total=step)
    if step != step_loaded:
        print(f"Switching checkpoint from {step_loaded} to {step}")
        # Fetch the new weights BEFORE touching the pipeline. If the download or
        # file read fails, the UNet, scheduler and step_loaded all remain
        # consistent with the previously loaded checkpoint instead of leaving
        # a scheduler configured for a checkpoint that never got loaded.
        state_dict = load_file(hf_hub_download(repo, ckpt), device=device)
        pipe.unet.load_state_dict(state_dict)
        # The 1-step checkpoint is run with prediction_type="sample"; all
        # others with "epsilon".
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            prediction_type="sample" if step == 1 else "epsilon",
        )
        step_loaded = step

    def inference_callback(p, i, t, kwargs):
        # callback_on_step_end hook: report step i+1 of `step` to the UI and
        # hand the callback kwargs back unchanged.
        progress(i + 1, total=step)
        return kwargs

    # guidance_scale=0 disables classifier-free guidance.
    return pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback).images[0]

with gr.Blocks(css="style.css") as demo:
    gr.HTML(
        "<h1><center>SDXL-Lightning</center></h1>"
        "<p><center>Lightning-fast text-to-image generation</center></p>"
        "<p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>"
    )

    with gr.Row():
        prompt = gr.Textbox(
            label="Text prompt",
            scale=8
        )
        option = gr.Dropdown(
            label="Inference steps",
            # Derived from `opts` so the menu always lists exactly the options
            # generate_image can resolve.
            choices=list(opts),
            # Matches the checkpoint loaded at startup.
            value="4 Steps",
            interactive=True
        )
        submit = gr.Button(
            scale=1,
            variant="primary"
        )

    img = gr.Image(label="SDXL-Lightning Generated Image")

    # Pressing Enter in the prompt box and clicking the button run the same handler.
    for trigger in (prompt.submit, submit.click):
        trigger(
            fn=generate_image,
            inputs=[prompt, option],
            outputs=img,
        )

demo.queue().launch()