omer11a committed
Commit 056b358
1 Parent(s): 1127ecd

Improved memory requirements

Files changed (2):
  1. bounded_attention.py +36 -30
  2. injection_utils.py +15 -122
bounded_attention.py CHANGED
@@ -44,6 +44,7 @@ class BoundedAttention(injection_utils.AttentionBase):
         pca_rank=None,
         num_clusters=None,
         num_clusters_per_box=3,
+        max_resolution=32,
         map_dir=None,
         debug=False,
         delta_debug_attention_steps=20,
@@ -81,6 +82,7 @@ class BoundedAttention(injection_utils.AttentionBase):
         self.clustering = KMeans(n_clusters=num_clusters, num_init=100)
         self.centers = None

+        self.max_resolution = max_resolution
         self.map_dir = map_dir
         self.debug = debug
         self.delta_debug_attention_steps = delta_debug_attention_steps
@@ -124,24 +126,34 @@ class BoundedAttention(injection_utils.AttentionBase):
         self.clear_values(include_maps=True)
         super().reset()

-    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
-        self._display_attention_maps(attn, is_cross, num_heads)
-
-        _, n, d = sim.shape
-        sim_u, sim_c = sim.reshape(-1, num_heads, n, d).chunk(2)  # b h n d
+    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
+        batch_size = q.size(0) // num_heads
+        n = q.size(1)
+        d = k.size(1)
+        dtype = q.dtype
+        device = q.device
         if is_cross:
-            sim_c = self._hide_other_subjects_from_tokens(sim_c)
+            masks = self._hide_other_subjects_from_tokens(batch_size // 2, n, d, dtype, device)
         else:
-            sim_u = self._hide_other_subjects_from_subjects(sim_u)
-            sim_c = self._hide_other_subjects_from_subjects(sim_c)
-
-        sim = torch.cat((sim_u, sim_c)).reshape(-1, n, d)
-        attn = sim.softmax(-1)
-        self._save(attn, is_cross, num_heads)
-        self._display_attention_maps(attn, is_cross, num_heads, prefix='masked')
-        self._debug_hook(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+            masks = self._hide_other_subjects_from_subjects(batch_size // 2, n, dtype, device)
+
+        if int(n ** 0.5) > self.max_resolution:
+            q = q.reshape(batch_size, num_heads, n, -1)
+            k = k.reshape(batch_size, num_heads, d, -1)
+            v = v.reshape(batch_size, num_heads, d, -1)
+            out = F.scaled_dot_product_attention(q, k, v, attn_mask=masks)
+            out = out.reshape(batch_size * num_heads, n, -1)
+        else:
+            sim = torch.einsum('b i d, b j d -> b i j', q, k) * kwargs['scale']
+            attn = sim.softmax(-1)
+            self._display_attention_maps(attn, is_cross, num_heads)
+            sim = sim.reshape(batch_size, num_heads, n, d) + masks
+            attn = sim.reshape(-1, n, d).softmax(-1)
+            self._save(attn, is_cross, num_heads)
+            self._display_attention_maps(attn, is_cross, num_heads, prefix='masked')
+            self._debug_hook(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+            out = torch.bmm(attn, v)

-        out = torch.bmm(attn, v)
         out = einops.rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
         return out

@@ -235,16 +247,13 @@ class BoundedAttention(injection_utils.AttentionBase):
         references = references.reshape(-1, *references_unconditional.shape[2:])
         return batch, references

-    def _hide_other_subjects_from_tokens(self, sim):  # b h i j
-        dtype = sim.dtype
-        device = sim.device
-        batch_size = sim.size(0)
-        resolution = int(sim.size(2) ** 0.5)
+    def _hide_other_subjects_from_tokens(self, batch_size, n, d, dtype, device):  # b h i j
+        resolution = int(n ** 0.5)
         subject_masks, background_masks = self._obtain_masks(resolution, batch_size=batch_size, device=device)  # b s n
         include_background = self.optimized or (not self.mask_cross_during_guidance and self.cur_step < self.max_guidance_iter_per_step)
         subject_masks = torch.logical_or(subject_masks, background_masks.unsqueeze(1)) if include_background else subject_masks
-        min_value = torch.finfo(sim.dtype).min
-        sim_masks = torch.zeros_like(sim[:, 0, :, :])  # b i j
+        min_value = torch.finfo(dtype).min
+        sim_masks = torch.zeros((batch_size, n, d), dtype=dtype, device=device)  # b i j
         for token_indices in (*self.subject_token_indices, self.filter_token_indices):
             sim_masks[:, :, token_indices] = min_value

@@ -257,16 +266,13 @@ class BoundedAttention(injection_utils.AttentionBase):
         for batch_index, background_mask in zip(range(batch_size), background_masks):
             sim_masks[batch_index, background_mask, self.eos_token_index] = min_value

-        return sim + sim_masks.unsqueeze(1)
+        return torch.cat((torch.zeros_like(sim_masks), sim_masks)).unsqueeze(1)

-    def _hide_other_subjects_from_subjects(self, sim):  # b h i j
-        dtype = sim.dtype
-        device = sim.device
-        batch_size = sim.size(0)
-        resolution = int(sim.size(2) ** 0.5)
+    def _hide_other_subjects_from_subjects(self, batch_size, n, dtype, device):  # b h i j
+        resolution = int(n ** 0.5)
         subject_masks, background_masks = self._obtain_masks(resolution, batch_size=batch_size, device=device)  # b s n
         min_value = torch.finfo(dtype).min
-        sim_masks = torch.zeros_like(sim[:, 0, :, :])  # b i j
+        sim_masks = torch.zeros((batch_size, n, n), dtype=dtype, device=device)  # b i j
         for batch_index, background_mask in zip(range(batch_size), background_masks):
             sim_masks[batch_index, ~background_mask, ~background_mask] = min_value

@@ -276,7 +282,7 @@ class BoundedAttention(injection_utils.AttentionBase):
             condition = torch.logical_or(subject_sim_mask == 0, subject_mask.unsqueeze(0))
             sim_masks[batch_index, subject_mask] = torch.where(condition, 0, min_value).to(dtype=dtype)

-        return sim + sim_masks.unsqueeze(1)
+        return torch.cat((sim_masks, sim_masks)).unsqueeze(1)

     def _save(self, attn, is_cross, num_heads):
         _, attn = attn.chunk(2)
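
For context on what the bounded_attention.py change buys: the old forward() received a fully materialized sim/attn pair and added its masks to it, so every layer held a (batch*heads, n, d) score matrix regardless of resolution. The rewritten forward() builds the additive masks directly at mask size and, above max_resolution, hands them to F.scaled_dot_product_attention, which never materializes that matrix. The snippet below is a minimal sketch, not code from the repository; the shapes and the default 1/sqrt(head_dim) scaling are illustrative assumptions. It only checks that the fused call with an additive float mask matches the explicit sim -> softmax -> bmm path the old code used.

# Minimal equivalence check (illustrative, not from this repo): SDPA with an
# additive float mask reproduces the explicit sim -> softmax -> bmm path.
import math
import torch
import torch.nn.functional as F

b, h, n, d_kv, e = 2, 8, 64 ** 2, 77, 40     # latent tokens x text tokens (assumed sizes)
q = torch.randn(b, h, n, e)
k = torch.randn(b, h, d_kv, e)
v = torch.randn(b, h, d_kv, e)

# Additive mask, broadcast over heads: a large negative value hides tokens.
masks = torch.zeros(b, 1, n, d_kv)
masks[..., -5:] = torch.finfo(masks.dtype).min

# Fused kernel: the (b, h, n, d_kv) attention matrix is never materialized.
out_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=masks)

# Explicit path, as in the old forward(): sim is a full attention-sized tensor.
sim = torch.einsum('b h i e, b h j e -> b h i j', q, k) / math.sqrt(e)
out_explicit = (sim + masks).softmax(-1) @ v

assert torch.allclose(out_fused, out_explicit, atol=1e-4)

Below max_resolution the commit keeps the explicit branch, since _save() and _display_attention_maps() still need per-token attention maps at those layers.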
injection_utils.py CHANGED
@@ -22,21 +22,29 @@ class AttentionBase:
     def after_step(self):
         pass

-    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+    def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
         if self.cur_att_layer == 0:
             self.before_step()

-        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+        out = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
         self.cur_att_layer += 1
         if self.cur_att_layer == self.num_att_layers:
             self.cur_att_layer = 0
             self.cur_step += 1
-            # after step
             self.after_step()
+
         return out

     def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
-        out = torch.einsum('b i j, b j d -> b i d', attn, v)
+        batch_size = q.size(0) // num_heads
+        n = q.size(1)
+        d = k.size(1)
+
+        q = q.reshape(batch_size, num_heads, n, -1)
+        k = k.reshape(batch_size, num_heads, d, -1)
+        v = v.reshape(batch_size, num_heads, d, -1)
+        out = F.scaled_dot_product_attention(q, k, v, attn_mask=kwargs['mask'])
+        out = out.reshape(batch_size * num_heads, n, -1)
         out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
         return out

@@ -45,42 +53,6 @@ class AttentionBase:
         self.cur_att_layer = 0


-class AttentionStore(AttentionBase):
-    def __init__(self, res=[32], min_step=0, max_step=1000):
-        super().__init__()
-        self.res = res
-        self.min_step = min_step
-        self.max_step = max_step
-        self.valid_steps = 0
-
-        self.self_attns = []  # store the all attns
-        self.cross_attns = []
-
-        self.self_attns_step = []  # store the attns in each step
-        self.cross_attns_step = []
-
-    def after_step(self):
-        if self.cur_step > self.min_step and self.cur_step < self.max_step:
-            self.valid_steps += 1
-            if len(self.self_attns) == 0:
-                self.self_attns = self.self_attns_step
-                self.cross_attns = self.cross_attns_step
-            else:
-                for i in range(len(self.self_attns)):
-                    self.self_attns[i] += self.self_attns_step[i]
-                    self.cross_attns[i] += self.cross_attns_step[i]
-        self.self_attns_step.clear()
-        self.cross_attns_step.clear()
-
-    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
-        if attn.shape[1] <= 64 ** 2:  # avoid OOM
-            if is_cross:
-                self.cross_attns_step.append(attn)
-            else:
-                self.self_attns_step.append(attn)
-        return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
-
-
 def regiter_attention_editor_diffusers(model, editor: AttentionBase):
     """
     Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]
@@ -109,21 +81,9 @@ def regiter_attention_editor_diffusers(model, editor: AttentionBase):
             k = self.to_k(context)
             v = self.to_v(context)
             q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
-            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
-
-            if mask is not None:
-                mask = rearrange(mask, 'b ... -> b (...)')
-                max_neg_value = -torch.finfo(sim.dtype).max
-                mask = repeat(mask, 'b j -> (b h) () j', h=h)
-                mask = mask[:, None, :].repeat(h, 1, 1)
-                sim.masked_fill_(~mask, max_neg_value)
-
-            attn = sim.softmax(dim=-1)
-            # the only difference
             out = editor(
-                q, k, v, sim, attn, is_cross, place_in_unet,
-                self.heads, scale=self.scale)
+                q, k, v, is_cross, place_in_unet,
+                self.heads, scale=self.scale, mask=mask)

             return to_out(out)

@@ -146,74 +106,7 @@ def regiter_attention_editor_diffusers(model, editor: AttentionBase):
             cross_att_count += register_editor(net, 0, "mid")
         elif "up" in net_name:
             cross_att_count += register_editor(net, 0, "up")
+
     editor.num_att_layers = cross_att_count
     editor.model = model
     model.editor = editor
-
-
-def regiter_attention_editor_ldm(model, editor: AttentionBase):
-    """
-    Register a attention editor to Stable Diffusion model, refer from [Prompt-to-Prompt]
-    """
-    def ca_forward(self, place_in_unet):
-        def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
-            """
-            The attention is similar to the original implementation of LDM CrossAttention class
-            except adding some modifications on the attention
-            """
-            if encoder_hidden_states is not None:
-                context = encoder_hidden_states
-            if attention_mask is not None:
-                mask = attention_mask
-
-            to_out = self.to_out
-            if isinstance(to_out, nn.modules.container.ModuleList):
-                to_out = self.to_out[0]
-            else:
-                to_out = self.to_out
-
-            h = self.heads
-            q = self.to_q(x)
-            is_cross = context is not None
-            context = context if is_cross else x
-            k = self.to_k(context)
-            v = self.to_v(context)
-            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
-            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
-
-            if mask is not None:
-                mask = rearrange(mask, 'b ... -> b (...)')
-                max_neg_value = -torch.finfo(sim.dtype).max
-                mask = repeat(mask, 'b j -> (b h) () j', h=h)
-                mask = mask[:, None, :].repeat(h, 1, 1)
-                sim.masked_fill_(~mask, max_neg_value)
-
-            attn = sim.softmax(dim=-1)
-            # the only difference
-            out = editor(
-                q, k, v, sim, attn, is_cross, place_in_unet,
-                self.heads, scale=self.scale)
-
-            return to_out(out)
-
-        return forward
-
-    def register_editor(net, count, place_in_unet):
-        for name, subnet in net.named_children():
-            if net.__class__.__name__ == 'CrossAttention':  # spatial Transformer layer
-                net.forward = ca_forward(net, place_in_unet)
-                return count + 1
-            elif hasattr(net, 'children'):
-                count = register_editor(subnet, count, place_in_unet)
-        return count
-
-    cross_att_count = 0
-    for net_name, net in model.model.diffusion_model.named_children():
-        if "input" in net_name:
-            cross_att_count += register_editor(net, 0, "input")
-        elif "middle" in net_name:
-            cross_att_count += register_editor(net, 0, "middle")
-        elif "output" in net_name:
-            cross_att_count += register_editor(net, 0, "output")
-    editor.num_att_layers = cross_att_count
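
A note on the mask=mask keyword now threaded through to the editor: the patched attention forward no longer pre-computes sim and attn, and the default AttentionBase.forward feeds the pipeline's mask straight into F.scaled_dot_product_attention. If that mask is the boolean padding mask the removed masked_fill_ branch suggests (True = attend), SDPA accepts it as-is; an additive float mask behaves equivalently. A minimal sketch with made-up shapes, not taken from the repository:

# Illustrative only: both mask flavors accepted by F.scaled_dot_product_attention.
import torch
import torch.nn.functional as F

b, h, n, d_kv, e = 1, 4, 16, 10, 8
q = torch.randn(b, h, n, e)
k = torch.randn(b, h, d_kv, e)
v = torch.randn(b, h, d_kv, e)

bool_mask = torch.ones(b, 1, 1, d_kv, dtype=torch.bool)
bool_mask[..., -3:] = False               # False = do not attend (e.g. padding)

float_mask = torch.zeros(b, 1, 1, d_kv)
float_mask[..., -3:] = float('-inf')      # additive equivalent

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
assert torch.allclose(out_bool, out_float, atol=1e-6)

BoundedAttention overrides forward() with its own float masks, so this default path mainly matters for plain AttentionBase editors registered through regiter_attention_editor_diffusers.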