aningineer committed on
Commit
5c4b5eb
1 Parent(s): 6f5b8d4

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+compare3.png filter=lfs diff=lfs merge=lfs -text
+test_notebook.ipynb filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,46 @@
 ---
 title: ToDo
-emoji: 🏃
-colorFrom: yellow
-colorTo: pink
-sdk: gradio
-sdk_version: 4.19.2
 app_file: app.py
-pinned: false
-license: mit
+sdk: gradio
+sdk_version: 3.50.2
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# ImprovedTokenMerge
+![compare3.png](compare3.png)
+![GEuoFn1bMAABQqD](https://github.com/ethansmith2000/ImprovedTokenMerge/assets/98723285/82e03423-81e6-47da-afa4-9c1b2c1c4aeb)
+
+Twitter thread explanation: https://twitter.com/Ethan_smith_20/status/1750533558509433137
+
+Heavily inspired by https://github.com/dbolya/tomesd by @dbolya; a big thanks to the original authors.
+
+This project aims to address some of the shortcomings of Token Merging for Stable Diffusion: namely, consistently faster inference without quality loss.
+I found that with the original you had to use a high merging ratio to get any real speedup, and by then quality was noticeably degraded. Benchmarks here: https://github.com/dbolya/tomesd/issues/19#issuecomment-1507593483
+
+I propose two changes to the original to solve this.
+1. Merging Method
+- the original computes a similarity matrix over the input tokens and merges those with the highest similarity
+- the issue is that the similarity calculation is O(n^2) in time; in ViT, where token merging was proposed, it only had to be done a few times, so it was quite efficient
+- here it must be done at every step, and the computation ends up nearly as costly as attention itself
+- we can leverage a simple observation: nearby tokens tend to be similar to each other
+- therefore we can merge tokens via downsampling, which is very cheap and seems to be a good approximation
+- this is analogous to grid-based subsampling of an image with a nearest-neighbor downsample; it is similar to what DiNAT (dilated neighborhood attention) does, except that we still make use of global context
+2. Merge Targets
+- the original merges the input tokens to attention and then "unmerges" the resulting tokens back to the original size
+- this operation seems to be quite lossy
+- instead I propose simply downsampling the keys/values of the attention operation; both the QK calculation and QK @ V can still be drastically reduced from the typical O(n^2) scaling of attention, without needing to unmerge anything (a minimal sketch follows this diff)
+- queries are left fully intact; they just attend to the image more sparsely
+- attention for images, especially at larger resolutions, seems to be very sparse in general (the QK matrix is low rank), so we do not appear to lose much from this
+
+Putting this all together, we get tangible speedups of ~1.5x at typical sizes like 768-1024, and up to 3x and beyond in the 1536-2048 range, in combination with flash attention.
+
+# Setup 🛠
+```
+pip install -r requirements.txt
+```
+
+# Inference 🚀
+See the provided notebook, or the gradio demo, which you can run with `python app.py`.
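To make proposal 2 concrete, here is a minimal, self-contained sketch of downsample-based key/value merging (an illustration under simplified assumptions: single head, fixed channel width, square token grid; the real implementation is in merge.py below):

```python
import torch
import torch.nn.functional as F

def downsample_tokens(x, h, w, factor):
    """Merge a [B, h*w, C] token sequence by nearest-neighbor downsampling its h x w grid."""
    b, n, c = x.shape
    x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)               # [B, C, h, w]
    x = F.interpolate(x, size=(h // factor, w // factor), mode="nearest")
    return x.permute(0, 2, 3, 1).reshape(b, -1, c)              # [B, (h/f)*(w/f), C]

def todo_style_attention(x, h, w, factor=2):
    # Queries stay at full resolution; keys/values are downsampled,
    # shrinking the QK score matrix from n x n to n x (n / factor^2).
    kv = downsample_tokens(x, h, w, factor)
    return F.scaled_dot_product_attention(x, kv, kv)

h = w = 96                                   # e.g. a 768px image at the 1/8-resolution unet level
x = torch.randn(2, h * w, 320)
print(todo_style_attention(x, h, w).shape)   # torch.Size([2, 9216, 320])
```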
__pycache__/merge.cpython-310.pyc ADDED
Binary file (10.2 kB)

__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.66 kB)
app.py ADDED
@@ -0,0 +1,124 @@
+import time
+import gradio as gr
+import torch
+import diffusers
+from utils import patch_attention_proc
+import math
+import numpy as np
+from PIL import Image
+
+pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to("cuda", torch.float16)
+pipe.enable_xformers_memory_efficient_attention()
+pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
+pipe.safety_checker = None
+
+with gr.Blocks() as demo:
+    prompt = gr.Textbox(interactive=True, label="prompt")
+    negative_prompt = gr.Textbox(interactive=True, label="negative_prompt")
+    method = gr.Dropdown(["todo", "tome"], value="todo", label="method", info="Choose Your Desired Method (Default: todo)")
+    height_width = gr.Dropdown([1024, 1536, 2048], value=1024, label="height/width", info="Choose Your Desired Height/Width (Default: 1024)")
+    # height = gr.Number(label="height", value=1024, precision=0)
+    # width = gr.Number(label="width", value=1024, precision=0)
+    guidance_scale = gr.Number(label="guidance_scale", value=7.5, precision=1)
+    steps = gr.Number(label="steps", value=20, precision=0)
+    seed = gr.Number(label="seed", value=1, precision=0)
+    result = gr.Textbox(label="Result")
+
+    output_image = gr.Image(label="output_image", type="pil", interactive=False)
+
+    gen = gr.Button("generate")
+
+    def which_image(img, target_val=253, width=1024):
+        # find the column of the alpha-tagged pixel to report which half is the merged image
+        npimg = np.array(img)
+        loc = np.where(npimg[:, :, 3] == target_val)[1].item()
+        if loc > width:
+            print("Right Image is merged!")
+        else:
+            print("Left Image is merged!")
+
+    def generate(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
+        # reset any previous patch back to stock xformers attention before the baseline run
+        pipe.enable_xformers_memory_efficient_attention()
+
+        downsample_factor = 2
+        ratio = 0.38
+        merge_method = "downsample" if method == "todo" else "similarity"
+        merge_tokens = "keys/values" if method == "todo" else "all"
+
+        if height_width == 1024:
+            downsample_factor = 2
+            ratio = 0.75
+            downsample_factor_level_2 = 1
+            ratio_level_2 = 0.0
+        elif height_width == 1536:
+            downsample_factor = 3
+            ratio = 0.89
+            downsample_factor_level_2 = 1
+            ratio_level_2 = 0.0
+        elif height_width == 2048:
+            downsample_factor = 4
+            ratio = 0.9375
+            downsample_factor_level_2 = 2
+            ratio_level_2 = 0.75
+
+        token_merge_args = {"ratio": ratio,
+                            "merge_tokens": merge_tokens,
+                            "merge_method": merge_method,
+                            "downsample_method": "nearest",
+                            "downsample_factor": downsample_factor,
+                            "timestep_threshold_switch": 0.0,
+                            "timestep_threshold_stop": 0.0,
+                            "downsample_factor_level_2": downsample_factor_level_2,
+                            "ratio_level_2": ratio_level_2
+                            }
+
+        l_r = torch.rand(1).item()
+        torch.manual_seed(seed)
+        start_time_base = time.time()
+        base_img = pipe(prompt,
+                        num_inference_steps=steps, height=height_width, width=height_width,
+                        negative_prompt=negative_prompt,
+                        guidance_scale=guidance_scale).images[0]
+        end_time_base = time.time()
+
+        patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)
+
+        torch.manual_seed(seed)
+        start_time_merge = time.time()
+        merged_img = pipe(prompt,
+                          num_inference_steps=steps, height=height_width, width=height_width,
+                          negative_prompt=negative_prompt,
+                          guidance_scale=guidance_scale).images[0]
+        end_time_merge = time.time()
+
+        base_img = base_img.convert("RGBA")
+        merged_img = merged_img.convert("RGBA")
+        merged_img = np.array(merged_img)
+        halfh, halfw = height_width // 2, height_width // 2
+        merged_img[halfh, halfw, 3] = 253  # set the center pixel of the merged image ever so slightly below 255 in the alpha channel
+        merged_img = Image.fromarray(merged_img)
+        final_img = Image.new(size=(height_width * 2, height_width), mode="RGBA")
+
+        # randomize which side the merged image lands on
+        if l_r > 0.5:
+            left_img = base_img
+            right_img = merged_img
+        else:
+            left_img = merged_img
+            right_img = base_img
+
+        final_img.paste(left_img, (0, 0))
+        final_img.paste(right_img, (height_width, 0))
+
+        which_image(final_img, width=height_width)
+
+        result = f"Baseline image: {end_time_base-start_time_base:.2f} sec | {'ToDo' if method == 'todo' else 'ToMe'} image: {end_time_merge-start_time_merge:.2f} sec"
+
+        return final_img, result
+
+    gen.click(generate, inputs=[prompt, seed, steps, height_width, negative_prompt,
+                                guidance_scale, method], outputs=[output_image, result])
+
+demo.launch(share=True)
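For reference, applying the patch outside the demo should look roughly like the following. This is a hedged sketch reusing the 1024x1024 settings from `generate()` above, with a placeholder prompt; unlisted keys fall back to the defaults in `utils.patch_attention_proc`:

```python
import torch
import diffusers
from utils import patch_attention_proc

pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to("cuda", torch.float16)

# ToDo settings matching app.py's 1024x1024 branch
patch_attention_proc(pipe.unet, token_merge_args={
    "ratio": 0.75,
    "merge_tokens": "keys/values",
    "merge_method": "downsample",
    "downsample_method": "nearest",
    "downsample_factor": 2,
    "timestep_threshold_switch": 0.0,
    "timestep_threshold_stop": 0.0,
})

image = pipe("an example prompt", num_inference_steps=20, height=1024, width=1024).images[0]
```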
compare3.png ADDED

Git LFS Details

  • SHA256: 2eab6aa0b551d24f4f485e1d614635053a416c5220385b6740385ca21c335c88
  • Pointer size: 133 Bytes
  • Size of remote file: 16.7 MB
merge.py ADDED
@@ -0,0 +1,385 @@
+import torch
+from typing import Tuple, Callable, Optional
+from diffusers.models.attention_processor import Attention
+import math
+import torch.nn.functional as F
+from diffusers.utils import USE_PEFT_BACKEND
+from diffusers.utils.import_utils import is_xformers_available
+
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+    xformers_is_available = True
+else:
+    xformers_is_available = False
+
+
+if hasattr(F, "scaled_dot_product_attention"):
+    torch2_is_available = True
+else:
+    torch2_is_available = False
+
+
+def init_generator(device: torch.device, fallback: torch.Generator = None):
+    """
+    Forks the current default random generator given the device.
+    """
+    if device.type == "cpu":
+        return torch.Generator(device="cpu").set_state(torch.get_rng_state())
+    elif device.type == "cuda":
+        return torch.Generator(device=device).set_state(torch.cuda.get_rng_state())
+    else:
+        if fallback is None:
+            return init_generator(torch.device("cpu"))
+        else:
+            return fallback
+
+
+def do_nothing(x: torch.Tensor, mode: str = None):
+    return x
+
+
+def mps_gather_workaround(input, dim, index):
+    if input.shape[-1] == 1:
+        return torch.gather(
+            input.unsqueeze(-1),
+            dim - 1 if dim < 0 else dim,
+            index.unsqueeze(-1)
+        ).squeeze(-1)
+    else:
+        return torch.gather(input, dim, index)
+
+
+def up_or_downsample(item, cur_w, cur_h, new_w, new_h, method):
+    batch_size = item.shape[0]
+
+    item = item.reshape(batch_size, cur_h, cur_w, -1)
+    item = item.permute(0, 3, 1, 2)
+    df = cur_h // new_h
+    if method == "max_pool":
+        item = F.max_pool2d(item, kernel_size=df, stride=df, padding=0)
+    elif method == "avg_pool":
+        item = F.avg_pool2d(item, kernel_size=df, stride=df, padding=0)
+    else:
+        item = F.interpolate(item, size=(new_h, new_w), mode=method)
+    item = item.permute(0, 2, 3, 1)
+    item = item.reshape(batch_size, new_h * new_w, -1)
+
+    return item
+
+
+def compute_merge(x: torch.Tensor, tome_info):
+    original_h, original_w = tome_info["size"]
+    original_tokens = original_h * original_w
+    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
+    dim = x.shape[-1]
+    # level 1 (dim 320) and level 2 (dim 640) unet blocks get their own settings;
+    # deeper blocks are left unmerged
+    if dim == 320:
+        cur_level = "level_1"
+        downsample_factor = tome_info['args']['downsample_factor']
+        ratio = tome_info['args']['ratio']
+    elif dim == 640:
+        cur_level = "level_2"
+        downsample_factor = tome_info['args']['downsample_factor_level_2']
+        ratio = tome_info['args']['ratio_level_2']
+    else:
+        cur_level = "other"
+        downsample_factor = 1
+        ratio = 0.0
+
+    args = tome_info["args"]
+
+    cur_h, cur_w = original_h // downsample, original_w // downsample
+    new_h, new_w = cur_h // downsample_factor, cur_w // downsample_factor
+
+    if tome_info['timestep'] / 1000 > tome_info['args']['timestep_threshold_switch']:
+        merge_method = args["merge_method"]
+    else:
+        merge_method = args["secondary_merge_method"]
+
+    if cur_level != "other" and tome_info['timestep'] / 1000 > tome_info['args']['timestep_threshold_stop']:
+        if merge_method == "downsample" and downsample_factor > 1:
+            m = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h, args["downsample_method"])
+            u = lambda x: up_or_downsample(x, new_w, new_h, cur_w, cur_h, args["downsample_method"])
+        elif merge_method == "similarity" and ratio > 0.0:
+            w = int(math.ceil(original_w / downsample))
+            h = int(math.ceil(original_h / downsample))
+            r = int(x.shape[1] * ratio)
+
+            # Re-init the generator if it hasn't already been initialized or device has changed.
+            if args["generator"] is None:
+                args["generator"] = init_generator(x.device)
+            elif args["generator"].device != x.device:
+                args["generator"] = init_generator(x.device, fallback=args["generator"])
+
+            # If the batch size is odd, then it's not possible for prompted and unprompted images to be in the same
+            # batch, which causes artifacts with use_rand, so force it to be off.
+            use_rand = False if x.shape[0] % 2 == 1 else args["use_rand"]
+            m, u = bipartite_soft_matching_random2d(x, w, h, args["sx"], args["sy"], r,
+                                                    no_rand=not use_rand, generator=args["generator"])
+        else:
+            m, u = (do_nothing, do_nothing)
+    else:
+        m, u = (do_nothing, do_nothing)
+
+    merge_fn, unmerge_fn = (m, u)
+
+    return merge_fn, unmerge_fn
+
+
+def bipartite_soft_matching_random2d(metric: torch.Tensor,
+                                     w: int,
+                                     h: int,
+                                     sx: int,
+                                     sy: int,
+                                     r: int,
+                                     no_rand: bool = False,
+                                     generator: torch.Generator = None) -> Tuple[Callable, Callable]:
+    """
+    Partitions the tokens into src and dst and merges r tokens from src to dst.
+    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.
+
+    Args:
+     - metric [B, N, C]: metric to use for similarity
+     - w: image width in tokens
+     - h: image height in tokens
+     - sx: stride in the x dimension for dst, must divide w
+     - sy: stride in the y dimension for dst, must divide h
+     - r: number of tokens to remove (by merging)
+     - no_rand: if true, disable randomness (use top left corner only)
+     - generator: if no_rand is false, and if not None, the RNG to draw dst positions from
+    """
+    B, N, _ = metric.shape
+
+    if r <= 0:
+        return do_nothing, do_nothing
+
+    with torch.no_grad():
+        hsy, wsx = h // sy, w // sx
+
+        # For each sy by sx kernel, randomly assign one token to be dst and the rest src
+        if no_rand:
+            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
+        else:
+            rand_idx = torch.randint(sy * sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to(
+                metric.device)
+
+        # The image might not divide sx and sy, so we need to work on a view of the top left of the idx buffer instead
+        idx_buffer_view = torch.zeros(hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64)
+        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
+        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
+
+        # Image is not divisible by sx or sy, so we need to move it into a new buffer
+        if (hsy * sy) < h or (wsx * sx) < w:
+            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
+            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
+        else:
+            idx_buffer = idx_buffer_view
+
+        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
+        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
+
+        # We're finished with these
+        del idx_buffer, idx_buffer_view
+
+        # rand_idx is currently dst|src, so split them
+        num_dst = hsy * wsx
+        a_idx = rand_idx[:, num_dst:, :]  # src
+        b_idx = rand_idx[:, :num_dst, :]  # dst
+
+        def split(x):
+            C = x.shape[-1]
+            src = torch.gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
+            dst = torch.gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
+            return src, dst
+
+        # Cosine similarity between A and B
+        metric = metric / metric.norm(dim=-1, keepdim=True)
+        a, b = split(metric)
+        scores = a @ b.transpose(-1, -2)
+
+        # Can't reduce more than the # tokens in src
+        r = min(a.shape[1], r)
+
+        # Find the most similar greedily
+        node_max, node_idx = scores.max(dim=-1)
+        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
+
+        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
+        src_idx = edge_idx[..., :r, :]  # Merged Tokens
+        dst_idx = torch.gather(node_idx[..., None], dim=-2, index=src_idx)
+
+    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
+        src, dst = split(x)
+        n, t1, c = src.shape
+
+        unm = torch.gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
+        src = torch.gather(src, dim=-2, index=src_idx.expand(n, r, c))
+        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
+
+        return torch.cat([unm, dst], dim=1)
+
+    def unmerge(x: torch.Tensor) -> torch.Tensor:
+        unm_len = unm_idx.shape[1]
+        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
+        _, _, c = unm.shape
+
+        src = torch.gather(dst, dim=-2, index=dst_idx.expand(B, r, c))
+
+        # Combine back to the original shape
+        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
+        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
+        out.scatter_(dim=-2,
+                     index=torch.gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c),
+                     src=unm)
+        out.scatter_(dim=-2,
+                     index=torch.gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c),
+                     src=src)
+
+        return out
+
+    return merge, unmerge
+
+
+class TokenMergeAttentionProcessor:
+    def __init__(self):
+        # prioritize torch2's flash attention; if unavailable, fall back to xformers, then regular attention
+        if torch2_is_available:
+            self.attn_method = "torch2"
+        elif xformers_is_available:
+            self.attn_method = "xformers"
+        else:
+            self.attn_method = "regular"
+
+    def torch2_attention(self, attn, query, key, value, attention_mask, batch_size):
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+        return hidden_states
+
+    def xformers_attention(self, attn, query, key, value, attention_mask, batch_size):
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        if attention_mask is not None:
+            attention_mask = attention_mask.reshape(batch_size * attn.heads, -1, attention_mask.shape[-1])
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, scale=attn.scale
+        )
+
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        return hidden_states
+
+    def regular_attention(self, attn, query, key, value, attention_mask, batch_size):
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        if attention_mask is not None:
+            attention_mask = attention_mask.reshape(batch_size * attn.heads, -1, attention_mask.shape[-1])
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        return hidden_states
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
+        # "all": merge the attention inputs and unmerge the outputs (ToMe-style)
+        if self._tome_info['args']['merge_tokens'] == "all":
+            merge_fn, unmerge_fn = compute_merge(hidden_states, self._tome_info)
+            hidden_states = merge_fn(hidden_states)
+
+        query = attn.to_q(hidden_states, *args)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        # "keys/values": merge only what keys/values are computed from; queries stay intact (ToDo-style)
+        if self._tome_info['args']['merge_tokens'] == "keys/values":
+            merge_fn, _ = compute_merge(encoder_hidden_states, self._tome_info)
+            encoder_hidden_states = merge_fn(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)
+
+        if self.attn_method == "torch2":
+            hidden_states = self.torch2_attention(attn, query, key, value, attention_mask, batch_size)
+        elif self.attn_method == "xformers":
+            hidden_states = self.xformers_attention(attn, query, key, value, attention_mask, batch_size)
+        else:
+            hidden_states = self.regular_attention(attn, query, key, value, attention_mask, batch_size)
+
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states, *args)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if self._tome_info['args']['merge_tokens'] == "all":
+            hidden_states = unmerge_fn(hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
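As a rough sanity check on what the `dim == 320` branch buys, the snippet below counts QK score-matrix entries at the first attention level, assuming (as in standard SD UNets) that it runs at 1/8 of the pixel resolution; the factors are the per-resolution settings from app.py:

```python
# rough count of QK score-matrix entries at the first unet attention level
for px, factor in [(1024, 2), (1536, 3), (2048, 4)]:
    n = (px // 8) ** 2          # full-resolution query tokens
    m = n // factor ** 2        # key/value tokens after downsample merging
    print(f"{px}px: {n * n:.2e} -> {n * m:.2e} entries ({factor ** 2}x fewer)")
```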
requirements.txt ADDED
@@ -0,0 +1,3 @@
+diffusers
+transformers
+accelerate
test_notebook.ipynb ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db39c8f9f9eea913cade16c3cf45ce0d9a13cc050b5e2564896b100042cdc86b
+size 17164306
utils.py ADDED
@@ -0,0 +1,80 @@
+import torch
+from merge import TokenMergeAttentionProcessor
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor, AttnProcessor
+import torch.nn.functional as F
+
+if is_xformers_available():
+    xformers_is_available = True
+else:
+    xformers_is_available = False
+
+if hasattr(F, "scaled_dot_product_attention"):
+    torch2_is_available = True
+else:
+    torch2_is_available = False
+
+
+def hook_tome_model(model: torch.nn.Module):
+    """ Adds a forward pre hook to get the image size. This hook can be removed with remove_patch. """
+
+    def hook(module, args):
+        module._tome_info["size"] = (args[0].shape[2], args[0].shape[3])
+        module._tome_info["timestep"] = args[1].item()
+        return None
+
+    model._tome_info["hooks"].append(model.register_forward_pre_hook(hook))
+
+
+def patch_attention_proc(unet, token_merge_args={}):
+    unet._tome_info = {
+        "size": None,
+        "timestep": None,
+        "hooks": [],
+        "args": {
+            "ratio": token_merge_args.get("ratio", 0.5),  # ratio of tokens to merge
+            "sx": token_merge_args.get("sx", 2),  # stride x for sim calculation
+            "sy": token_merge_args.get("sy", 2),  # stride y for sim calculation
+            "use_rand": token_merge_args.get("use_rand", True),
+            "generator": None,
+
+            "merge_tokens": token_merge_args.get("merge_tokens", "keys/values"),  # ["all", "keys/values"]
+            "merge_method": token_merge_args.get("merge_method", "downsample"),  # ["none", "similarity", "downsample"]
+            "downsample_method": token_merge_args.get("downsample_method", "nearest-exact"),
+            # native torch interpolation methods ["nearest", "linear", "bilinear", "bicubic", "nearest-exact"]
+            "downsample_factor": token_merge_args.get("downsample_factor", 2),  # amount to downsample by
+            "timestep_threshold_switch": token_merge_args.get("timestep_threshold_switch", 0.2),
+            # timestep to switch to secondary method, 0.2 means 20% of steps remaining
+            "timestep_threshold_stop": token_merge_args.get("timestep_threshold_stop", 0.0),
+            # timestep to stop merging, 0.0 means stop at 0 steps remaining
+            "secondary_merge_method": token_merge_args.get("secondary_merge_method", "similarity"),
+            # ["none", "similarity", "downsample"]
+
+            "downsample_factor_level_2": token_merge_args.get("downsample_factor_level_2", 1),  # amount to downsample by at the 2nd down block of the unet
+            "ratio_level_2": token_merge_args.get("ratio_level_2", 0.5),  # ratio of tokens to merge at the 2nd down block of the unet
+        }
+    }
+    hook_tome_model(unet)
+    attn_modules = [module for name, module in unet.named_modules() if module.__class__.__name__ == 'BasicTransformerBlock']
+
+    for i, module in enumerate(attn_modules):
+        # patch only attn1 (self-attention); attn2 (cross-attention) is left untouched
+        module.attn1.processor = TokenMergeAttentionProcessor()
+        module.attn1.processor._tome_info = unet._tome_info
+
+
+def remove_patch(pipe: torch.nn.Module):
+    """ Removes a patch from a ToMe Diffusion module if it was already patched. """
+
+    # this will remove our custom class
+    if torch2_is_available:
+        for n, m in pipe.unet.named_modules():
+            if hasattr(m, "processor"):
+                m.processor = AttnProcessor2_0()
+
+    elif xformers_is_available:
+        pipe.enable_xformers_memory_efficient_attention()
+
+    else:
+        for n, m in pipe.unet.named_modules():
+            if hasattr(m, "processor"):
+                m.processor = AttnProcessor()
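Finally, a hedged sketch of the intended round trip with these helpers, assuming `pipe` is an already-loaded diffusers StableDiffusionPipeline and the prompt is a placeholder:

```python
from utils import patch_attention_proc, remove_patch

patch_attention_proc(pipe.unet, token_merge_args={"ratio": 0.75})
merged = pipe("an example prompt").images[0]     # runs with ToDo merging

remove_patch(pipe)                               # restore stock attention processors
baseline = pipe("an example prompt").images[0]
```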