File size: 5,356 Bytes
c2c42ca
d827a95
81435cb
61bc6a3
 
6da6b11
9d41bd5
a1f66f7
 
c2c42ca
aa5a24b
 
 
 
 
91dd651
c2c42ca
61bc6a3
 
 
 
c2c42ca
aa5a24b
 
 
 
 
 
 
 
 
 
 
 
 
61bc6a3
 
c2c42ca
61bc6a3
 
 
143f063
aa5a24b
 
 
 
 
 
 
 
 
 
 
 
61bc6a3
 
aa5a24b
91dd651
6da6b11
09f0b4e
61bc6a3
09f0b4e
61bc6a3
09f0b4e
61bc6a3
d6bdfdf
143f063
b634b72
ddbaa70
 
 
 
 
 
 
 
61bc6a3
dc81866
d452942
c2c42ca
 
 
dc81866
 
 
 
 
 
 
 
 
ddbaa70
bc87ae3
 
09f0b4e
bc87ae3
09f0b4e
ddbaa70
 
 
09f0b4e
ddbaa70
4835fe3
09f0b4e
 
bc87ae3
5e64d98
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, LCMScheduler, AutoencoderKL
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces

# Single VAE instance (the madebyollin fp16-fix checkpoint, loaded in float16)
# shared by all three pipelines below so only one copy is held in memory.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

### SDXL Turbo ###
pipe_turbo = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe_turbo.to("cuda")

# SDXL base repo: supplies the UNet architecture config and the remaining
# pipeline components for the Lightning and Hyper-SD pipelines.
base = "stabilityai/stable-diffusion-xl-base-1.0"


def _load_distilled_unet(repo_id, filename):
    """Build an SDXL-base-shaped float16 UNet and load a distilled checkpoint into it.

    Args:
        repo_id: Hugging Face Hub repository holding the UNet checkpoint.
        filename: ``.safetensors`` file inside ``repo_id`` with the UNet state dict.

    Returns:
        A ``UNet2DConditionModel`` carrying the checkpoint's weights.
    """
    # Passing a repo id directly to ``from_config`` is deprecated in diffusers;
    # load the config dict first, then instantiate the model from it.
    config = UNet2DConditionModel.load_config(base, subfolder="unet")
    unet = UNet2DConditionModel.from_config(config).to(torch.float16)
    unet.load_state_dict(load_file(hf_hub_download(repo_id, filename)))
    return unet


def _build_base_pipeline(unet):
    """Assemble an SDXL pipeline from ``base`` around *unet*.

    The text encoders and tokenizers are taken from ``pipe_turbo`` and the VAE
    is the shared ``vae``, so those components are not loaded a second time.

    Args:
        unet: The UNet to plug into the pipeline.

    Returns:
        A ``StableDiffusionXLPipeline`` in float16 (not yet moved to CUDA).
    """
    return StableDiffusionXLPipeline.from_pretrained(
        base,
        unet=unet,
        vae=vae,
        text_encoder=pipe_turbo.text_encoder,
        text_encoder_2=pipe_turbo.text_encoder_2,
        tokenizer=pipe_turbo.tokenizer,
        tokenizer_2=pipe_turbo.tokenizer_2,
        torch_dtype=torch.float16,
        variant="fp16",
    )


### SDXL Lightning ###
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_1step_unet_x0.safetensors"

pipe_lightning = _build_base_pipeline(_load_distilled_unet(repo, ckpt))
# Euler with trailing timestep spacing, treating the UNet output as the
# denoised sample (prediction_type="sample") for this x0 checkpoint.
pipe_lightning.scheduler = EulerDiscreteScheduler.from_config(
    pipe_lightning.scheduler.config,
    timestep_spacing="trailing",
    prediction_type="sample",
)
pipe_lightning.to("cuda")

### Hyper SDXL ###
repo_name = "ByteDance/Hyper-SD"
ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"

pipe_hyper = _build_base_pipeline(_load_distilled_unet(repo_name, ckpt_name))
pipe_hyper.scheduler = LCMScheduler.from_config(pipe_hyper.scheduler.config)
pipe_hyper.to("cuda")

@spaces.GPU
def run_comparison(prompt, progress=gr.Progress(track_tqdm=True)):
    """Render *prompt* with each one-step model in turn, streaming partial results.

    After each model finishes, yields a 3-tuple ``(turbo, lightning, hyper)``.
    Slots for models that have not run yet are ``None``. ``progress`` is never
    read here; its ``track_tqdm=True`` default asks Gradio to surface the
    pipelines' tqdm progress in the UI.
    """
    shared_kwargs = {"prompt": prompt, "num_inference_steps": 1, "guidance_scale": 0}
    stages = (
        (pipe_turbo, {}),
        (pipe_lightning, {}),
        # Hyper-SD's one-step UNet is sampled at an explicit single timestep.
        (pipe_hyper, {"timesteps": [800]}),
    )
    outputs = [None] * len(stages)
    for slot, (pipeline, extra_kwargs) in enumerate(stages):
        outputs[slot] = pipeline(**shared_kwargs, **extra_kwargs).images[0]
        yield tuple(outputs)

# Example prompts offered under the prompt box; clicking one runs the comparison.
examples = [
    "A dignified beaver wearing glasses, a vest, and colorful neck tie.",
    "The spirit of a tamagotchi wandering in the city of Barcelona",
    "an ornate, high-backed mahogany chair with a red cushion",
    "a sketch of a camel next to a stream",
    "a delicate porcelain teacup sits on a saucer, its surface adorned with intricate blue patterns",
    "a baby swan graffiti",
    "A bald eagle made of chocolate powder, mango, and whipped cream",
]

with gr.Blocks() as demo:
    # Page header and prompt controls.
    gr.Markdown("## One step SDXL comparison 🦶")
    gr.Markdown('Compare SDXL variants and distillations able to generate images in a single diffusion step')
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")

    # One column per model: its output image, then a link to its model card.
    # The order must match the tuple positions yielded by run_comparison.
    model_columns = (
        ("SDXL Turbo", "## [SDXL Turbo](https://huggingface.co/stabilityai/sdxl-turbo)"),
        ("SDXL Lightning", "## [SDXL Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)"),
        ("Hyper SDXL", "## [Hyper SDXL](https://huggingface.co/ByteDance/Hyper-SD)"),
    )
    image_outputs = []
    with gr.Row():
        for image_label, card_markdown in model_columns:
            with gr.Column():
                image_outputs.append(gr.Image(label=image_label))
                gr.Markdown(card_markdown)
    image_turbo, image_lightning, image_hyper = image_outputs

    # Both the event listener and the examples drive the same generator.
    wiring = {"fn": run_comparison, "inputs": prompt, "outputs": image_outputs}
    gr.on(triggers=[prompt.submit, run.click], **wiring)
    gr.Examples(examples=examples, cache_examples=False, run_on_click=True, **wiring)
demo.launch()