Improved memory efficiency
- bounded_attention.py +44 -42
- injection_utils.py +1 -1
bounded_attention.py
CHANGED
@@ -44,7 +44,7 @@ class BoundedAttention(injection_utils.AttentionBase):
         pca_rank=None,
         num_clusters=None,
         num_clusters_per_box=3,
-        max_resolution=
+        max_resolution=None,
         map_dir=None,
         debug=False,
         delta_debug_attention_steps=20,
@@ -95,8 +95,10 @@ class BoundedAttention(injection_utils.AttentionBase):
         self.self_foreground_values = []
         self.cross_background_values = []
         self.self_background_values = []
-        self.
-        self.
+        self.mean_cross_map = 0
+        self.num_cross_maps = 0
+        self.mean_self_map = 0
+        self.num_self_maps = 0
         self.self_masks = None

     def clear_values(self, include_maps=False):
@@ -107,16 +109,15 @@ class BoundedAttention(injection_utils.AttentionBase):
             self.self_background_values,
         )

-        if include_maps:
-            lists = (
-                *all_values,
-                self.cross_maps,
-                self.self_maps,
-            )
-
         for values in lists:
             values.clear()

+        if include_maps:
+            self.mean_cross_map = 0
+            self.num_cross_maps = 0
+            self.mean_self_map = 0
+            self.num_self_maps = 0
+
     def before_step(self):
         self.clear_values()
         if self.cur_step == 0:
@@ -137,37 +138,31 @@ class BoundedAttention(injection_utils.AttentionBase):
         else:
             masks = self._hide_other_subjects_from_subjects(batch_size // 2, n, dtype, device)

-        …
-        self._debug_hook(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
-        out = torch.bmm(attn, v)
-
-        out = einops.rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
-        return out
+        resolution = int(n ** 0.5)
+        if (self.max_resolution is not None) and (resolution > self.max_resolution):
+            return super().forward(q, k, v, is_cross, place_in_unet, num_heads, mask=masks)
+
+        sim = torch.einsum('b i d, b j d -> b i j', q, k) * kwargs['scale']
+        attn = sim.softmax(-1)
+        self._display_attention_maps(attn, is_cross, num_heads)
+        sim = sim.reshape(batch_size, num_heads, n, d) + masks
+        attn = sim.reshape(-1, n, d).softmax(-1)
+        self._save(attn, is_cross, num_heads)
+        self._display_attention_maps(attn, is_cross, num_heads, prefix='masked')
+        self._debug_hook(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+        out = torch.bmm(attn, v)
+        return einops.rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)

     def update_loss(self, forward_pass, latents, i):
         if i >= self.max_guidance_iter:
             return latents

         step_size = self.start_step_size + self.step_size_coef * i
-        updated_latents = latents

         self.optimized = True
         normalized_loss = torch.tensor(10000)
         with torch.enable_grad():
-            latents =
+            latents = latents.clone().detach().requires_grad_(True)
             for guidance_iter in range(self.max_guidance_iter_per_step):
                 if normalized_loss < self.loss_stopping_value:
                     break
@@ -178,8 +173,8 @@ class BoundedAttention(injection_utils.AttentionBase):
                 self.cur_step = cur_step

                 loss, normalized_loss = self._compute_loss()
-                grad_cond = torch.autograd.grad(loss, [
-                latents =
+                grad_cond = torch.autograd.grad(loss, [latents])[0]
+                latents = latents - step_size * grad_cond
                 if self.debug:
                     print(f'Loss at step={i}, iter={guidance_iter}: {normalized_loss}')
                 grad_norms = grad_cond.flatten(start_dim=2).norm(dim=1)
@@ -301,13 +296,21 @@ class BoundedAttention(injection_utils.AttentionBase):

         if is_cross:
             attn = attn[..., self.leading_token_indices]
-            …
+            mean_map = self.mean_cross_map
+            num_maps = self.num_cross_maps
         else:
-            …
+            mean_map = self.mean_self_map
+            num_maps = self.num_self_maps

-        …
+        num_maps += 1
+        attn = attn.mean(dim=1) # mean over heads
+        mean_map = ((num_maps - 1) / num_maps) * mean_map + (1 / num_maps) * attn
+        if is_cross:
+            self.mean_cross_map = mean_map
+            self.num_cross_maps = num_maps
+        else:
+            self.mean_self_map = mean_map
+            self.num_self_maps = num_maps

     def _save_loss_values(self, attn, is_cross):
         if (
@@ -404,7 +407,7 @@ class BoundedAttention(injection_utils.AttentionBase):
         return self_masks.flatten(start_dim=2).bool()

     def _cluster_self_maps(self): # b s n
-        self_maps = self.
+        self_maps = self._compute_maps(self.mean_self_map) # b n m
         if self.pca_rank is not None:
             dtype = self_maps.dtype
             _, _, eigen_vectors = torch.pca_lowrank(self_maps.float(), self.pca_rank)
@@ -442,7 +445,7 @@ class BoundedAttention(injection_utils.AttentionBase):
         return self_masks

     def _obtain_cross_masks(self, resolution, scale=10):
-        maps = self.
+        maps = self._compute_maps(self.mean_cross_map, resolution=resolution) # b n k
         maps = F.sigmoid(scale * (maps - self.cross_mask_threshold))
         maps = self._normalize_maps(maps, reduce_min=True)
         maps = maps.transpose(1, 2) # b k n
@@ -466,8 +469,7 @@ class BoundedAttention(injection_utils.AttentionBase):

         return masks

-    def
-        maps = torch.stack(maps).mean(0) # mean over layers
+    def _compute_maps(self, maps, resolution=None): # b n k
         if resolution is not None:
             b, n, k = maps.shape
             original_resolution = int(n ** 0.5)
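The memory saving comes from the _save hunk above: instead of appending each layer's attention map to self.cross_maps / self.self_maps and averaging them later with torch.stack(maps).mean(0), the updated code folds every map into a running mean (mean_cross_map / num_cross_maps and mean_self_map / num_self_maps), so only one averaged map per type is held at a time. A minimal standalone sketch of that incremental-mean bookkeeping (RunningMeanMap is an illustrative helper, not part of the Space's code):

import torch

class RunningMeanMap:
    # hypothetical helper mirroring the mean_*_map / num_*_maps pair in the diff
    def __init__(self):
        self.mean = 0    # running mean; becomes a tensor after the first update
        self.count = 0   # number of maps folded in so far

    def update(self, attn_map):
        self.count += 1
        # incremental mean: new_mean = ((n - 1) / n) * old_mean + (1 / n) * x,
        # equal (up to rounding) to stacking all maps and taking mean(0)
        self.mean = ((self.count - 1) / self.count) * self.mean + (1 / self.count) * attn_map

tracker = RunningMeanMap()
for _ in range(4):
    tracker.update(torch.rand(2, 64, 77))  # e.g. (batch, image tokens, text tokens)
print(tracker.mean.shape)  # torch.Size([2, 64, 77]); no list of four maps is ever kept

The trade-off is that individual per-layer maps can no longer be revisited, which matches the removal of the torch.stack(maps).mean(0) aggregation from the old helper in favor of the new _compute_maps, which receives the already-averaged map.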
injection_utils.py
CHANGED
@@ -35,7 +35,7 @@ class AttentionBase:

         return out

-    def forward(self, q, k, v,
+    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
         batch_size = q.size(0) // num_heads
         n = q.size(1)
         d = k.size(1)
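The one-line change in injection_utils.py widens AttentionBase.forward to the full (q, k, v, is_cross, place_in_unet, num_heads, **kwargs) argument list, presumably so the max_resolution early exit added above can hand the masked attention straight to the base class via super().forward(...). A simplified sketch of that delegation pattern (stand-in classes, shapes, and the max_resolution value are assumptions, not the Space's actual implementation):

import torch

class AttentionBaseSketch:
    # stand-in for injection_utils.AttentionBase with the widened signature
    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, mask=None, **kwargs):
        sim = torch.einsum('b i d, b j d -> b i j', q, k) * q.size(-1) ** -0.5
        if mask is not None:
            sim = sim + mask  # additive bias with the same shape as sim
        return torch.bmm(sim.softmax(-1), v)

class BoundedAttentionSketch(AttentionBaseSketch):
    max_resolution = 16  # assumed value; the new kwarg defaults to None in the diff

    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        n = q.size(1)
        resolution = int(n ** 0.5)
        masks = torch.zeros(q.size(0), n, k.size(1))
        if self.max_resolution is not None and resolution > self.max_resolution:
            # high-resolution layers skip all map bookkeeping and reuse the base path
            return super().forward(q, k, v, is_cross, place_in_unet, num_heads, mask=masks)
        # ... full bounded-attention path (map saving, loss terms) would go here ...
        return super().forward(q, k, v, is_cross, place_in_unet, num_heads, mask=masks)

q = k = v = torch.rand(2, 32 * 32, 8)  # (batch*heads, tokens, dim)
out = BoundedAttentionSketch().forward(q, k, v, is_cross=False, place_in_unet='mid', num_heads=2)
print(out.shape)  # torch.Size([2, 1024, 8]); resolution 32 > 16, so the cheap path ran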