omer11a committed on
Commit 8fea73b
1 Parent(s): 39ce2cf

Improved memory efficiency

Files changed (2)
  1. bounded_attention.py +44 -42
  2. injection_utils.py +1 -1
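
The core of the change: instead of appending each layer's attention map to self.cross_maps / self.self_maps and averaging over the stored list later, the class now keeps a single running mean per map type (mean_cross_map / mean_self_map plus a counter). A minimal sketch of that streaming average, with names taken from the diff; the helper function itself is illustrative, not part of the repository:

import torch

def update_running_mean(mean_map, num_maps, attn: torch.Tensor):
    # attn holds one layer's attention, shaped (batch, heads, queries, keys)
    attn = attn.mean(dim=1)  # mean over heads, as in _save
    num_maps += 1
    # streaming average: mean_k = ((k - 1) / k) * mean_{k-1} + (1 / k) * x_k
    mean_map = ((num_maps - 1) / num_maps) * mean_map + (1 / num_maps) * attn
    return mean_map, num_maps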
bounded_attention.py CHANGED
@@ -44,7 +44,7 @@ class BoundedAttention(injection_utils.AttentionBase):
         pca_rank=None,
         num_clusters=None,
         num_clusters_per_box=3,
-        max_resolution=32,
+        max_resolution=None,
         map_dir=None,
         debug=False,
         delta_debug_attention_steps=20,
@@ -95,8 +95,10 @@ class BoundedAttention(injection_utils.AttentionBase):
         self.self_foreground_values = []
         self.cross_background_values = []
         self.self_background_values = []
-        self.cross_maps = []
-        self.self_maps = []
+        self.mean_cross_map = 0
+        self.num_cross_maps = 0
+        self.mean_self_map = 0
+        self.num_self_maps = 0
         self.self_masks = None

     def clear_values(self, include_maps=False):
@@ -107,16 +109,15 @@ class BoundedAttention(injection_utils.AttentionBase):
             self.self_background_values,
         )

-        if include_maps:
-            lists = (
-                *all_values,
-                self.cross_maps,
-                self.self_maps,
-            )
-
         for values in lists:
             values.clear()

+        if include_maps:
+            self.mean_cross_map = 0
+            self.num_cross_maps = 0
+            self.mean_self_map = 0
+            self.num_self_maps = 0
+
     def before_step(self):
         self.clear_values()
         if self.cur_step == 0:
@@ -137,37 +138,31 @@ class BoundedAttention(injection_utils.AttentionBase):
         else:
             masks = self._hide_other_subjects_from_subjects(batch_size // 2, n, dtype, device)

-        if int(n ** 0.5) > self.max_resolution:
-            q = q.reshape(batch_size, num_heads, n, -1)
-            k = k.reshape(batch_size, num_heads, d, -1)
-            v = v.reshape(batch_size, num_heads, d, -1)
-            out = F.scaled_dot_product_attention(q, k, v, attn_mask=masks)
-            out = out.reshape(batch_size * num_heads, n, -1)
-        else:
-            sim = torch.einsum('b i d, b j d -> b i j', q, k) * kwargs['scale']
-            attn = sim.softmax(-1)
-            self._display_attention_maps(attn, is_cross, num_heads)
-            sim = sim.reshape(batch_size, num_heads, n, d) + masks
-            attn = sim.reshape(-1, n, d).softmax(-1)
-            self._save(attn, is_cross, num_heads)
-            self._display_attention_maps(attn, is_cross, num_heads, prefix='masked')
-            self._debug_hook(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
-            out = torch.bmm(attn, v)
-
-        out = einops.rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
-        return out
+        resolution = int(n ** 0.5)
+        if (self.max_resolution is not None) and (resolution > self.max_resolution):
+            return super().forward(q, k, v, is_cross, place_in_unet, num_heads, mask=masks)
+
+        sim = torch.einsum('b i d, b j d -> b i j', q, k) * kwargs['scale']
+        attn = sim.softmax(-1)
+        self._display_attention_maps(attn, is_cross, num_heads)
+        sim = sim.reshape(batch_size, num_heads, n, d) + masks
+        attn = sim.reshape(-1, n, d).softmax(-1)
+        self._save(attn, is_cross, num_heads)
+        self._display_attention_maps(attn, is_cross, num_heads, prefix='masked')
+        self._debug_hook(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+        out = torch.bmm(attn, v)
+        return einops.rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)

     def update_loss(self, forward_pass, latents, i):
         if i >= self.max_guidance_iter:
             return latents

         step_size = self.start_step_size + self.step_size_coef * i
-        updated_latents = latents

         self.optimized = True
         normalized_loss = torch.tensor(10000)
         with torch.enable_grad():
-            latents = updated_latents = updated_latents.clone().detach().requires_grad_(True)
+            latents = latents.clone().detach().requires_grad_(True)
             for guidance_iter in range(self.max_guidance_iter_per_step):
                 if normalized_loss < self.loss_stopping_value:
                     break
@@ -178,8 +173,8 @@ class BoundedAttention(injection_utils.AttentionBase):
                 self.cur_step = cur_step

                 loss, normalized_loss = self._compute_loss()
-                grad_cond = torch.autograd.grad(loss, [updated_latents])[0]
-                latents = updated_latents = updated_latents - step_size * grad_cond
+                grad_cond = torch.autograd.grad(loss, [latents])[0]
+                latents = latents - step_size * grad_cond
                 if self.debug:
                     print(f'Loss at step={i}, iter={guidance_iter}: {normalized_loss}')
                     grad_norms = grad_cond.flatten(start_dim=2).norm(dim=1)
@@ -301,13 +296,21 @@ class BoundedAttention(injection_utils.AttentionBase):

         if is_cross:
             attn = attn[..., self.leading_token_indices]
-            mask_maps = self.cross_maps
+            mean_map = self.mean_cross_map
+            num_maps = self.num_cross_maps
         else:
-            mask_maps = self.self_maps
+            mean_map = self.mean_self_map
+            num_maps = self.num_self_maps

-        mask_maps.append(attn.mean(dim=1)) # mean over heads
-        if self.cur_step > 0:
-            mask_maps.pop(0) # throw away old maps
+        num_maps += 1
+        attn = attn.mean(dim=1) # mean over heads
+        mean_map = ((num_maps - 1) / num_maps) * mean_map + (1 / num_maps) * attn
+        if is_cross:
+            self.mean_cross_map = mean_map
+            self.num_cross_maps = num_maps
+        else:
+            self.mean_self_map = mean_map
+            self.num_self_maps = num_maps

     def _save_loss_values(self, attn, is_cross):
         if (
@@ -404,7 +407,7 @@ class BoundedAttention(injection_utils.AttentionBase):
         return self_masks.flatten(start_dim=2).bool()

     def _cluster_self_maps(self): # b s n
-        self_maps = self._aggregate_maps(self.self_maps) # b n m
+        self_maps = self._compute_maps(self.mean_self_map) # b n m
         if self.pca_rank is not None:
             dtype = self_maps.dtype
             _, _, eigen_vectors = torch.pca_lowrank(self_maps.float(), self.pca_rank)
@@ -442,7 +445,7 @@ class BoundedAttention(injection_utils.AttentionBase):
         return self_masks

     def _obtain_cross_masks(self, resolution, scale=10):
-        maps = self._aggregate_maps(self.cross_maps, resolution=resolution) # b n k
+        maps = self._compute_maps(self.mean_cross_map, resolution=resolution) # b n k
         maps = F.sigmoid(scale * (maps - self.cross_mask_threshold))
         maps = self._normalize_maps(maps, reduce_min=True)
         maps = maps.transpose(1, 2) # b k n
@@ -466,8 +469,7 @@ class BoundedAttention(injection_utils.AttentionBase):

         return masks

-    def _aggregate_maps(self, maps, resolution=None): # b n k
-        maps = torch.stack(maps).mean(0) # mean over layers
+    def _compute_maps(self, maps, resolution=None): # b n k
         if resolution is not None:
             b, n, k = maps.shape
             original_resolution = int(n ** 0.5)
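
With max_resolution now defaulting to None, the explicit similarity-matrix path runs at every resolution unless a limit is set; when a limit is set, forward hands oversized layers to super().forward(..., mask=masks) instead of building the full n x d similarity tensor. A rough sketch of the kind of fused-attention fallback this enables, assuming the base class relies on PyTorch's F.scaled_dot_product_attention (the base-class body is not part of this commit, so the helper below is illustrative only):

import torch.nn.functional as F

def fused_attention(q, k, v, masks, batch_size, num_heads):
    # q, k, v: (batch * heads, tokens, head_dim); masks broadcast over the head dimension
    n, d = q.size(1), k.size(1)
    q = q.reshape(batch_size, num_heads, n, -1)
    k = k.reshape(batch_size, num_heads, d, -1)
    v = v.reshape(batch_size, num_heads, d, -1)
    # the fused kernel can avoid materializing the full (n x d) attention matrix
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=masks)
    return out.reshape(batch_size * num_heads, n, -1)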
injection_utils.py CHANGED
@@ -35,7 +35,7 @@ class AttentionBase:

         return out

-    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
         batch_size = q.size(0) // num_heads
         n = q.size(1)
         d = k.size(1)
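
The base forward no longer receives precomputed sim and attn tensors; subclasses pass it raw q, k, v plus an optional mask= keyword, as BoundedAttention.forward now does. A plausible shape for the updated method is sketched below purely for illustration; the actual AttentionBase implementation is not included in this commit:

import einops
import torch.nn.functional as F

class AttentionBase:
    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, mask=None, **kwargs):
        # q, k, v arrive as (batch * heads, tokens, head_dim)
        batch_size = q.size(0) // num_heads
        q, k, v = (x.reshape(batch_size, num_heads, x.size(1), -1) for x in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        # back to (batch, tokens, heads * head_dim), matching the subclass's output layout
        return einops.rearrange(out, 'b h n d -> b n (h d)')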