aningineer committed
Commit 26cfd4b (1 parent: ab5735d)

Upload folder using huggingface_hub

Files changed (4)
  1. README.md +2 -1
  2. __pycache__/utils.cpython-310.pyc +0 -0
  3. app.py +70 -15
  4. utils.py +9 -0
README.md CHANGED
@@ -8,8 +8,9 @@ sdk_version: 4.19.2
 # ToDo: Token Downsampling for Efficient Generation of High-Resolution Images
 ---
 
-This is a demo for our recently proposed method, ["ToDo: Token Downsampling for Efficient Generation of High-Resolution Images"](https://arxiv.org/abs/2402.13573), compared against a popular token merging method, ToMe.
+We provide a [HuggingFace Spaces demo](https://huggingface.co/spaces/aningineer/ToDo) for our recently proposed method, ["ToDo: Token Downsampling for Efficient Generation of High-Resolution Images"](https://arxiv.org/abs/2402.13573), and compare it against a popular token merging method called ToMe.
 
+If you find our research helpful, please consider citing us:
 ```
 @misc{smith2024todo,
 title={ToDo: Token Downsampling for Efficient Generation of High-Resolution Images},
__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
 
app.py CHANGED
@@ -3,10 +3,11 @@ import spaces
 import gradio as gr
 import torch
 import diffusers
-from utils import patch_attention_proc
+from utils import patch_attention_proc, remove_patch
 import math
 import numpy as np
 from PIL import Image
+from threading import Semaphore
 
 # Globals
 css = """
@@ -21,8 +22,12 @@ pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to
 pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
 pipe.safety_checker = None
 
+semaphore = Semaphore()  # for preventing collisions of two simultaneous button presses
+
 @spaces.GPU
-def generate(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
+def generate_baseline(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
+
+    semaphore.acquire()
 
     downsample_factor = 2
     ratio = 0.38
@@ -42,8 +47,8 @@ def generate(prompt, seed, steps, height_width, negative_prompt, guidance_scale,
     elif height_width == 2048:
         downsample_factor = 4
         ratio = 0.9375
-        downsample_factor_level_2 = 2
-        ratio_level_2 = 0.75
+        downsample_factor_level_2 = 1
+        ratio_level_2 = 0.0
 
     token_merge_args = {"ratio": ratio,
                         "merge_tokens": merge_tokens,
@@ -56,17 +61,60 @@ def generate(prompt, seed, steps, height_width, negative_prompt, guidance_scale,
                         "ratio_level_2": ratio_level_2
                         }
 
-    l_r = torch.rand(1).item()
     torch.manual_seed(seed)
     start_time_base = time.time()
+    remove_patch(pipe)
     base_img = pipe(prompt,
                     num_inference_steps=steps, height=height_width, width=height_width,
                     negative_prompt=negative_prompt,
                     guidance_scale=guidance_scale).images[0]
    end_time_base = time.time()
 
-    patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)
+    result = f"Baseline image: {end_time_base-start_time_base:.2f} sec"
+
+    semaphore.release()
+
+    return base_img, result
+
+
+@spaces.GPU
+def generate_merged(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
+
+    semaphore.acquire()
+
+    downsample_factor = 2
+    ratio = 0.38
+    merge_method = "downsample" if method == "todo" else "similarity"
+    merge_tokens = "keys/values" if method == "todo" else "all"
+
+    if height_width == 1024:
+        downsample_factor = 2
+        ratio = 0.75
+        downsample_factor_level_2 = 1
+        ratio_level_2 = 0.0
+    elif height_width == 1536:
+        downsample_factor = 3
+        ratio = 0.89
+        downsample_factor_level_2 = 1
+        ratio_level_2 = 0.0
+    elif height_width == 2048:
+        downsample_factor = 4
+        ratio = 0.9375
+        downsample_factor_level_2 = 1
+        ratio_level_2 = 0.0
 
+    token_merge_args = {"ratio": ratio,
+                        "merge_tokens": merge_tokens,
+                        "merge_method": merge_method,
+                        "downsample_method": "nearest",
+                        "downsample_factor": downsample_factor,
+                        "timestep_threshold_switch": 0.0,
+                        "timestep_threshold_stop": 0.0,
+                        "downsample_factor_level_2": downsample_factor_level_2,
+                        "ratio_level_2": ratio_level_2
+                        }
+
+    patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)
     torch.manual_seed(seed)
     start_time_merge = time.time()
     merged_img = pipe(prompt,
@@ -75,9 +123,11 @@ def generate(prompt, seed, steps, height_width, negative_prompt, guidance_scale,
                       guidance_scale=guidance_scale).images[0]
     end_time_merge = time.time()
 
-    result = f"Baseline image: {end_time_base-start_time_base:.2f} sec | {'ToDo' if method == 'todo' else 'ToMe'} image: {end_time_merge-start_time_merge:.2f} sec"
+    result = f"{'ToDo' if method == 'todo' else 'ToMe'} image: {end_time_merge-start_time_merge:.2f} sec"
 
-    return base_img, merged_img, result
+    semaphore.release()
+
+    return merged_img, result
 
 
 
@@ -95,13 +145,18 @@ with gr.Blocks(css=css) as demo:
     steps = gr.Number(label="steps", value=20, precision=0)
     seed = gr.Number(label="seed", value=1, precision=0)
 
-    result = gr.Textbox(label="Result")
     with gr.Row():
-        base_image = gr.Image(label=f"baseline_image", type="pil", interactive=False)
-        output_image = gr.Image(label=f"output_image", type="pil", interactive=False)
-
-    gen = gr.Button("generate")
-    gen.click(generate, inputs=[prompt, seed, steps, height_width, negative_prompt,
-                                guidance_scale, method], outputs=[base_image, output_image, result])
+        with gr.Column():
+            base_result = gr.Textbox(label="Baseline Runtime")
+            base_image = gr.Image(label=f"baseline_image", type="pil", interactive=False)
+            gen = gr.Button("Generate Baseline")
+            gen.click(generate_baseline, inputs=[prompt, seed, steps, height_width, negative_prompt,
+                                                 guidance_scale, method], outputs=[base_image, base_result])
+        with gr.Column():
+            output_result = gr.Textbox(label="Runtime")
+            output_image = gr.Image(label=f"image", type="pil", interactive=False)
+            gen = gr.Button("Generate")
+            gen.click(generate_merged, inputs=[prompt, seed, steps, height_width, negative_prompt,
+                                               guidance_scale, method], outputs=[output_image, output_result])
 
 demo.launch(share=True)
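
Two details of the new `generate_merged` path are worth noting. First, the per-resolution settings appear to follow `ratio = 1 - 1/downsample_factor**2`: a downsample factor of 2 keeps 1/4 of the tokens (ratio 0.75), 3 keeps 1/9 (0.89 after rounding), and 4 keeps 1/16 (0.9375). Second, the GPU work is bracketed by manual `semaphore.acquire()`/`semaphore.release()` calls; if `pipe(...)` raises in between, the semaphore is never released and every later button press blocks. A minimal sketch of the same serialization in the exception-safe context-manager form (it reuses `pipe`, `remove_patch`, `torch`, `time`, and `spaces` from app.py above; an illustration, not part of the commit):

```python
# Sketch only: exception-safe variant of the locking added in this commit.
# Reuses the module-level pipe, remove_patch, torch, time and spaces from app.py.
from threading import Semaphore

semaphore = Semaphore()  # serializes access to the single shared pipeline

@spaces.GPU
def generate_baseline(prompt, seed, steps, height_width,
                      negative_prompt, guidance_scale, method):
    with semaphore:  # acquired on entry, released on normal exit or exception
        remove_patch(pipe)   # make sure no ToDo/ToMe patch is still active
        torch.manual_seed(seed)
        start = time.time()
        base_img = pipe(prompt, num_inference_steps=steps,
                        height=height_width, width=height_width,
                        negative_prompt=negative_prompt,
                        guidance_scale=guidance_scale).images[0]
        return base_img, f"Baseline image: {time.time() - start:.2f} sec"
```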
utils.py CHANGED
@@ -25,6 +25,15 @@ def hook_tome_model(model: torch.nn.Module):
 
     model._tome_info["hooks"].append(model.register_forward_pre_hook(hook))
 
+def remove_patch(pipe: torch.nn.Module):
+    """ Removes a patch from a ToMe Diffusion module if it was already patched. """
+
+    if hasattr(pipe.unet, "_tome_info"):
+        del pipe.unet._tome_info
+
+    for n, m in pipe.unet.named_modules():
+        if hasattr(m, "processor"):
+            m.processor = AttnProcessor2_0()
 
 def patch_attention_proc(unet, token_merge_args={}):
     unet._tome_info = {
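
`remove_patch` undoes `patch_attention_proc`: it drops the `_tome_info` bookkeeping and resets every attention module to diffusers' stock `AttnProcessor2_0` (referenced unqualified in the hunk, so utils.py presumably already imports it from `diffusers.models.attention_processor`). A hedged sketch of the patch/unpatch round trip that app.py now performs, with illustrative assertions mirroring what `remove_patch` does; the prompt and `token_merge_args` setup are assumptions matching app.py above:

```python
# Sketch: patch, generate, unpatch, then check the reset took effect.
# Assumes `pipe` and `token_merge_args` are built as in app.py above.
from diffusers.models.attention_processor import AttnProcessor2_0
from utils import patch_attention_proc, remove_patch

patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)  # ToDo/ToMe attention on
merged = pipe("a castle at dusk", num_inference_steps=20,
              height=1024, width=1024).images[0]

remove_patch(pipe)  # back to stock attention everywhere
assert not hasattr(pipe.unet, "_tome_info")
assert all(isinstance(m.processor, AttnProcessor2_0)
           for _, m in pipe.unet.named_modules() if hasattr(m, "processor"))
base = pipe("a castle at dusk", num_inference_steps=20,
            height=1024, width=1024).images[0]
```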