File size: 4,286 Bytes
1600949
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9824b10
1600949
 
 
b953745
1600949
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3bbdc9
 
1600949
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import base64
from io import BytesIO

import gradio as gr
import PIL.Image
import torch

from diffusers import StableDiffusionPipeline, AutoencoderKL, AutoencoderTiny
from peft import PeftModel

# Load-time placement for the pipeline and both image decoders. run() moves
# the pipeline again (device and dtype) on every request.
device = "cpu"   # Linux & Windows
# torch.float16 works as well, but pictures seem to be a bit worse.
weight_type = torch.float32

# Base SDXS-512 DreamShaper pipeline; its UNet is wrapped with the PEFT adapter
# weights from "IDKiro/sdxs-512-dreamshaper-anime".
pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper", torch_dtype=weight_type)
pipe.unet = PeftModel.from_pretrained(pipe.unet, "IDKiro/sdxs-512-dreamshaper-anime")
pipe.to(torch_device=device, torch_dtype=weight_type)

# Two decoders that run() can swap into pipe.vae ("tiny vae" / "large vae").
vae_tiny = AutoencoderTiny.from_pretrained("IDKiro/sdxs-512-dreamshaper", subfolder="vae")
vae_tiny.to(device, dtype=weight_type)

vae_large = AutoencoderKL.from_pretrained("IDKiro/sdxs-512-dreamshaper", subfolder="vae_large")
# Place and cast the large decoder itself (it is loaded without torch_dtype),
# so it matches the pipeline before it is ever swapped in.
vae_large.to(device, dtype=weight_type)

def pil_image_to_data_url(img, format="PNG"):
    """Serialize *img* into an inline ``data:`` URL.

    Args:
        img: Object with a PIL-style ``save(fp, format=...)`` method.
        format: Image format name handed to ``img.save``; its lowercase form
            becomes the MIME subtype (``"PNG"`` -> ``image/png``).

    Returns:
        ``"data:image/<format>;base64,<payload>"`` with the encoded bytes.
    """
    with BytesIO() as stream:
        img.save(stream, format=format)
        raw_bytes = stream.getvalue()
    payload = base64.b64encode(raw_bytes).decode()
    subtype = format.lower()
    return f"data:image/{subtype};base64,{payload}"


def run(
    prompt: str,
    device_type="GPU",
    vae_type=None,
    param_dtype='torch.float16',
) -> "tuple[PIL.Image.Image, str]":
    """Generate one image for *prompt* with the shared module-level pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        device_type: "CPU" runs on the CPU; any other value requests CUDA.
            If CUDA is not available, falls back to the CPU instead of failing.
        vae_type: "tiny vae" or "large vae" swaps that decoder into
            ``pipe.vae``; any other value (including None) keeps whichever
            decoder the pipeline currently has.
        param_dtype: 'torch.float16' selects half precision; any other value
            selects float32. Always float32 when running on the CPU.

    Returns:
        ``(image, data_url)``: the generated PIL image and the same image
        encoded as a PNG ``data:`` URL.
    """
    # Swapping mutates the shared pipeline, so the chosen decoder persists
    # into later calls that pass vae_type=None.
    if vae_type == "tiny vae":
        pipe.vae = vae_tiny
    elif vae_type == "large vae":
        pipe.vae = vae_large

    # Only use CUDA when it was requested AND is actually present; otherwise
    # pipe.to("cuda") would raise (e.g. gr.Examples calling run(prompt) with
    # the "GPU" default on a CPU-only host).
    if device_type != "CPU" and torch.cuda.is_available():
        device = "cuda"
        dtype = torch.float16 if param_dtype == 'torch.float16' else torch.float32
    else:
        device = "cpu"
        dtype = torch.float32  # CPU path always runs in float32

    # Re-place every component, including a decoder that was just swapped in.
    pipe.to(torch_device=device, torch_dtype=dtype)

    # Single denoising step without classifier-free guidance, presumably as
    # the one-step SDXS model expects.
    result = pipe(
        prompt=prompt,
        guidance_scale=0.0,
        num_inference_steps=1,
        output_type="pil",
    ).images[0]

    return result, pil_image_to_data_url(result)


# Prompts offered as clickable examples under the input panel.
examples = [
    "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
]

with gr.Blocks(css="style.css") as demo:
    gr.Markdown("# SDXS-512-DreamShaper-Anime (only CPU now)")
    with gr.Group():
        with gr.Row():
            with gr.Column(min_width=685):
                with gr.Row():
                    prompt = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your prompt",
                        container=False,
                    )
                    run_button = gr.Button("Run", scale=0)

                # Device selection is locked to CPU (interactive=False).
                device_choices = ['GPU', 'CPU']
                device_type = gr.Radio(device_choices, label='Device',
                                       value=device_choices[1],
                                       interactive=False,
                                       info='Only CPU now.')

                vae_choices = ['tiny vae', 'large vae']
                vae_type = gr.Radio(vae_choices, label='Image Decoder Type',
                                    value=vae_choices[0],
                                    interactive=True,
                                    info='To save GPU memory, use tiny vae. For better quality, use large vae.')

                dtype_choices = ['torch.float16', 'torch.float32']
                param_dtype = gr.Radio(dtype_choices, label='torch.weight_type',
                                       value=dtype_choices[0],
                                       interactive=True,
                                       info='To save GPU memory, use torch.float16. For better quality, use torch.float32.')

                # NOTE(review): no click handler is wired to this button in this
                # file; the Result image has its own download button.
                download_output = gr.Button("Download output", elem_id="download_output")
                # Hidden holder for the data URL that run() returns as its second
                # value. Previously the URL was sent to the button above, which
                # replaced the button's label with a long base64 string.
                result_url = gr.Textbox(visible=False)

            with gr.Column(min_width=512):
                result = gr.Image(label="Result", height=512, width=512, elem_id="output_image", show_label=False, show_download_button=True)

    inputs = [prompt, device_type, vae_type, param_dtype]
    # run() returns (image, data_url), so every caller needs both outputs.
    outputs = [result, result_url]

    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=outputs,
        fn=run,
    )

    demo.load(None, None, None)  # no-op load event

    prompt.submit(fn=run, inputs=inputs, outputs=outputs)
    run_button.click(fn=run, inputs=inputs, outputs=outputs)

if __name__ == "__main__":
    # Enable the request queue, then start the server in debug mode.
    queued_demo = demo.queue()
    queued_demo.launch(debug=True)