import spaces
import base64
from io import BytesIO

import gradio as gr
import PIL.Image
import torch

from diffusers import StableDiffusionPipeline, AutoencoderKL, AutoencoderTiny
from peft import PeftModel

device = "cuda"
weight_type = torch.float16  # default precision; run() switches to fp32 on CPU

# Load the one-step SDXS pipeline and merge the anime LoRA adapter into its UNet.
pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper")
pipe.unet = PeftModel.from_pretrained(pipe.unet, "IDKiro/sdxs-512-dreamshaper-anime")
pipe.unet = pipe.unet.merge_and_unload()  # fold the LoRA weights into the base UNet
pipe.to(device, dtype=weight_type)

# Two interchangeable latent decoders: a distilled tiny VAE (less GPU memory)
# and the full-size KL VAE (better quality).
vae_tiny = AutoencoderTiny.from_pretrained(
    "IDKiro/sdxs-512-dreamshaper", subfolder="vae"
)
vae_tiny.to(device, dtype=weight_type)

vae_large = AutoencoderKL.from_pretrained(
    "IDKiro/sdxs-512-dreamshaper", subfolder="vae_large"
)
vae_large.to(device, dtype=weight_type)


def pil_image_to_data_url(img, format="PNG"):
    """Encode a PIL image as a base64 data URL (consumed by the download button)."""
    buffered = BytesIO()
    img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{img_str}"
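
# Hypothetical usage: the returned string can serve directly as an
# <a download> href in the browser:
#   url = pil_image_to_data_url(PIL.Image.new("RGB", (1, 1)))
#   assert url.startswith("data:image/png;base64,")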


@spaces.GPU
def run(
    prompt: str,
    device_type="GPU",
    vae_type=None,
    param_dtype="torch.float16",
) -> tuple[PIL.Image.Image, str]:
    # Swap in the requested latent decoder.
    if vae_type == "tiny vae":
        pipe.vae = vae_tiny
    elif vae_type == "large vae":
        pipe.vae = vae_large

    if device_type == "CPU":
        device = "cpu"
        param_dtype = "torch.float32"  # fp16 is slow or unsupported on most CPUs
    else:
        device = "cuda"

    pipe.to(
        device,
        torch.float16 if param_dtype == "torch.float16" else torch.float32,
    )

    # SDXS is distilled for guidance-free, single-step sampling.
    result = pipe(
        prompt=prompt,
        guidance_scale=0.0,
        num_inference_steps=1,
        output_type="pil",
    ).images[0]

    result_url = pil_image_to_data_url(result)

    return (result, result_url)
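
# Hypothetical local smoke test (outside the Gradio UI, assuming a CUDA GPU):
#   image, data_url = run("an anime mountain lake at sunset")
#   image.save("out.png")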


examples = [
    "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
]

with gr.Blocks(css="style.css") as demo:
    gr.Markdown("# SDXS-512-DreamShaper-Anime")
    gr.Markdown("[SDXS: Real-Time One-Step Latent Diffusion Models with Image Conditions](https://arxiv.org/abs/2403.16627) | [GitHub](https://github.com/IDKiro/sdxs)")
    with gr.Group():
        with gr.Row():
            with gr.Column(min_width=685):
                with gr.Row():
                    prompt = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your prompt",
                        container=False,
                    )
                    run_button = gr.Button("Run", scale=0)

                device_choices = ["GPU", "CPU"]
                device_type = gr.Radio(
                    device_choices,
                    label="Device",
                    value=device_choices[0],
                    interactive=True,
                    info="Thanks to the community for the GPU!",
                )

                vae_choices = ["tiny vae", "large vae"]
                vae_type = gr.Radio(
                    vae_choices,
                    label="Image Decoder Type",
                    value=vae_choices[0],
                    interactive=True,
                    info="To save GPU memory, use tiny vae. For better quality, use large vae.",
                )

                dtype_choices = ["torch.float16", "torch.float32"]
                param_dtype = gr.Radio(
                    dtype_choices,
                    label="Parameter dtype",
                    value=dtype_choices[0],
                    interactive=True,
                    info="To save GPU memory, use torch.float16. For better quality, use torch.float32.",
                )

                download_output = gr.Button(
                    "Download output", elem_id="download_output"
                )

            with gr.Column(min_width=512):
                result = gr.Image(
                    label="Result",
                    height=512,
                    width=512,
                    elem_id="output_image",
                    show_label=False,
                    show_download_button=True,
                )

    gr.Examples(examples=examples, inputs=prompt, outputs=[result, download_output], fn=run)

    inputs = [prompt, device_type, vae_type, param_dtype]
    outputs = [result, download_output]
    prompt.submit(fn=run, inputs=inputs, outputs=outputs)
    run_button.click(fn=run, inputs=inputs, outputs=outputs)

if __name__ == "__main__":
    demo.queue().launch(debug=True)