File size: 5,797 Bytes
5c4b5eb
6b566c6
5c4b5eb
 
 
26cfd4b
5c4b5eb
 
 
26cfd4b
5c4b5eb
aeccac3
 
 
 
 
 
 
 
 
5c4b5eb
 
 
 
26cfd4b
 
cedad44
26cfd4b
 
 
cedad44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26cfd4b
 
cedad44
 
 
 
 
 
 
 
 
 
 
 
 
 
26cfd4b
cedad44
 
 
 
 
 
542ea3b
26cfd4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cedad44
26cfd4b
 
 
 
 
 
 
 
 
 
 
 
cedad44
 
 
 
 
 
 
 
542ea3b
cedad44
26cfd4b
 
 
cedad44
aeccac3
 
 
 
5c4b5eb
 
e07cbb0
 
 
 
8bfad22
 
 
 
 
5c4b5eb
cedad44
26cfd4b
 
 
 
 
 
 
 
 
 
 
 
5c4b5eb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import time
import spaces
import gradio as gr
import torch
import diffusers
from utils import patch_attention_proc, remove_patch
import math
import numpy as np
from PIL import Image
from threading import Semaphore

# Globals
# CSS injected into the Gradio Blocks app: centers the page title.
css = """
h1 {
  text-align: center;
  display: block;
}
"""

# Pipeline
# Loaded once at module import (fp16 on CUDA); both generate functions share it.
pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to("cuda", torch.float16)
pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
# Disabled so benchmark timings measure generation only; also avoids blacked-out images.
pipe.safety_checker = None

semaphore = Semaphore() # for preventing collisions of two simultaneous button presses

@spaces.GPU
def generate_baseline(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
    """Generate an image with the unpatched (baseline) pipeline and time it.

    Args mirror the Gradio inputs. ``method`` is accepted only for signature
    parity with ``generate_merged``; it has no effect on the baseline run.

    Returns:
        tuple: (PIL.Image, str) — the generated image and a runtime summary.
    """
    # Context manager guarantees release even if the pipeline raises;
    # the old acquire()/release() pair deadlocked all later presses on error.
    with semaphore:
        torch.manual_seed(seed)
        start_time_base = time.time()
        # Remove any token-merging patch left over from a previous merged run,
        # so this measurement is a true unpatched baseline.
        remove_patch(pipe)
        base_img = pipe(prompt,
                        num_inference_steps=steps, height=height_width, width=height_width,
                        negative_prompt=negative_prompt,
                        guidance_scale=guidance_scale).images[0]
        end_time_base = time.time()

    result = f"Baseline Runtime: {end_time_base-start_time_base:.2f} sec"

    return base_img, result


@spaces.GPU
def generate_merged(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
    """Patch the UNet with token merging (ToDo or ToMe), generate, and time it.

    Args mirror the Gradio inputs; ``method`` is either "todo" (downsample
    keys/values) or anything else, treated as ToMe (similarity merge of all
    tokens). Returns (PIL.Image, str runtime summary).
    """
    # Context manager guarantees release even if patching or generation
    # raises; the old acquire()/release() pair deadlocked later presses.
    with semaphore:
        # Defaults also cover any height_width outside the known presets;
        # previously the *_level_2 names were unbound in that case (NameError).
        downsample_factor = 2
        ratio = 0.38
        downsample_factor_level_2 = 1
        ratio_level_2 = 0.0
        merge_method = "downsample" if method == "todo" else "similarity"
        merge_tokens = "keys/values" if method == "todo" else "all"

        # Higher resolutions tolerate (and need) more aggressive merging.
        if height_width == 1024:
            downsample_factor = 2
            ratio = 0.75
        elif height_width == 1536:
            downsample_factor = 3
            ratio = 0.89
        elif height_width == 2048:
            downsample_factor = 4
            ratio = 0.9375

        token_merge_args = {"ratio": ratio,
                    "merge_tokens": merge_tokens,
                    "merge_method": merge_method,
                    "downsample_method": "nearest",
                    "downsample_factor": downsample_factor,
                    "timestep_threshold_switch": 0.0,
                    "timestep_threshold_stop": 0.0,
                    "downsample_factor_level_2": downsample_factor_level_2,
                    "ratio_level_2": ratio_level_2
                    }

        patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)
        torch.manual_seed(seed)
        start_time_merge = time.time()
        merged_img = pipe(prompt,
                            num_inference_steps=steps, height=height_width, width=height_width,
                            negative_prompt=negative_prompt,
                            guidance_scale=guidance_scale).images[0]
        end_time_merge = time.time()

    result = f"{'ToDo' if method == 'todo' else 'ToMe'} Runtime: {end_time_merge-start_time_merge:.2f} sec"

    return merged_img, result



# UI: two side-by-side columns driving the same pipeline — baseline (left)
# vs. token-merged (right) — so runtimes can be compared directly.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# ToDo: Token Downsampling for Efficient Generation of High-Resolution Images")
    prompt = gr.Textbox(interactive=True, label="prompt")
    negative_prompt = gr.Textbox(interactive=True, label="negative_prompt")

    with gr.Row():
        method = gr.Dropdown(["todo", "tome"], value="todo", label="method", info="Choose Your Desired Method (Default: todo)")
        height_width = gr.Dropdown([1024, 1536, 2048], value=1024, label="height/width", info="Choose Your Desired Height/Width (Default: 1024)")

    with gr.Row():
        guidance_scale = gr.Number(label="guidance_scale", value=7.5, precision=1)
        steps = gr.Number(label="steps", value=20, precision=0)
        seed = gr.Number(label="seed", value=1, precision=0)

    # Both buttons take the same inputs; define the list once.
    gen_inputs = [prompt, seed, steps, height_width, negative_prompt, guidance_scale, method]

    with gr.Row():
        with gr.Column():
            base_result = gr.Textbox(label="Baseline Runtime")
            # Plain string labels: the original f-prefixes had no placeholders (F541).
            base_image = gr.Image(label="baseline_image", type="pil", interactive=False)
            # Distinct names per button; the original rebound `gen`, shadowing the first.
            gen_baseline = gr.Button("Generate Baseline")
            gen_baseline.click(generate_baseline, inputs=gen_inputs, outputs=[base_image, base_result])
        with gr.Column():
            output_result = gr.Textbox(label="Runtime")
            output_image = gr.Image(label="image", type="pil", interactive=False)
            gen_merged = gr.Button("Generate")
            gen_merged.click(generate_merged, inputs=gen_inputs, outputs=[output_image, output_result])

demo.launch(share=True)