Improved memory requirements
Files changed:
- bounded_attention.py  +36 -30
- injection_utils.py  +15 -122
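
The gist of the change: above a configurable resolution, attention is now routed through torch.nn.functional.scaled_dot_product_attention instead of materializing the full similarity matrix with einsum, softmax and bmm, which is what dominates memory at high resolutions. A minimal sketch of the two paths (the helper names below are illustrative, not from the repo):

import torch
import torch.nn.functional as F

def explicit_attention(q, k, v, scale):
    # Old path: materializes a (b*h, n, n) similarity matrix before softmax.
    sim = torch.einsum('b i d, b j d -> b i j', q, k) * scale
    attn = sim.softmax(dim=-1)
    return torch.bmm(attn, v)

def fused_attention(q, k, v, num_heads):
    # New path: reshape to (b, h, n, d) and let SDPA pick a fused kernel,
    # which avoids allocating the full attention matrix when a fused backend is available.
    b = q.size(0) // num_heads
    q, k, v = (t.reshape(b, num_heads, t.size(1), -1) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)  # default scale is 1/sqrt(head_dim)
    return out.reshape(b * num_heads, out.size(2), -1)

# Small CPU check that both paths agree numerically.
b, h, n, d_head = 2, 8, 64, 40
q, k, v = (torch.randn(b * h, n, d_head) for _ in range(3))
assert torch.allclose(explicit_attention(q, k, v, d_head ** -0.5),
                      fused_attention(q, k, v, h), atol=1e-5)

The explicit path is kept in the diff only for resolutions at or below max_resolution, where the low-resolution branch of BoundedAttention.forward below still needs the attention maps for saving, display and debugging.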
bounded_attention.py
CHANGED
@@ -44,6 +44,7 @@ class BoundedAttention(injection_utils.AttentionBase):
         pca_rank=None,
         num_clusters=None,
         num_clusters_per_box=3,
+        max_resolution=32,
         map_dir=None,
         debug=False,
         delta_debug_attention_steps=20,
@@ -81,6 +82,7 @@ class BoundedAttention(injection_utils.AttentionBase):
         self.clustering = KMeans(n_clusters=num_clusters, num_init=100)
         self.centers = None
 
+        self.max_resolution = max_resolution
         self.map_dir = map_dir
         self.debug = debug
         self.delta_debug_attention_steps = delta_debug_attention_steps
@@ -124,24 +126,34 @@ class BoundedAttention(injection_utils.AttentionBase):
         self.clear_values(include_maps=True)
         super().reset()
 
-    def forward(self, q, k, v, …
-        …
-        …
-        …
-        …
+    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
+        batch_size = q.size(0) // num_heads
+        n = q.size(1)
+        d = k.size(1)
+        dtype = q.dtype
+        device = q.device
         if is_cross:
-            …
+            masks = self._hide_other_subjects_from_tokens(batch_size // 2, n, d, dtype, device)
         else:
-            …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
+            masks = self._hide_other_subjects_from_subjects(batch_size // 2, n, dtype, device)
+
+        if int(n ** 0.5) > self.max_resolution:
+            q = q.reshape(batch_size, num_heads, n, -1)
+            k = k.reshape(batch_size, num_heads, d, -1)
+            v = v.reshape(batch_size, num_heads, d, -1)
+            out = F.scaled_dot_product_attention(q, k, v, attn_mask=masks)
+            out = out.reshape(batch_size * num_heads, n, -1)
+        else:
+            sim = torch.einsum('b i d, b j d -> b i j', q, k) * kwargs['scale']
+            attn = sim.softmax(-1)
+            self._display_attention_maps(attn, is_cross, num_heads)
+            sim = sim.reshape(batch_size, num_heads, n, d) + masks
+            attn = sim.reshape(-1, n, d).softmax(-1)
+            self._save(attn, is_cross, num_heads)
+            self._display_attention_maps(attn, is_cross, num_heads, prefix='masked')
+            self._debug_hook(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+            out = torch.bmm(attn, v)
 
-        out = torch.bmm(attn, v)
         out = einops.rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
         return out
 
@@ -235,16 +247,13 @@ class BoundedAttention(injection_utils.AttentionBase):
         references = references.reshape(-1, *references_unconditional.shape[2:])
         return batch, references
 
-    def _hide_other_subjects_from_tokens(self, …
-        …
-        device = sim.device
-        batch_size = sim.size(0)
-        resolution = int(sim.size(2) ** 0.5)
+    def _hide_other_subjects_from_tokens(self, batch_size, n, d, dtype, device): # b h i j
+        resolution = int(n ** 0.5)
         subject_masks, background_masks = self._obtain_masks(resolution, batch_size=batch_size, device=device) # b s n
         include_background = self.optimized or (not self.mask_cross_during_guidance and self.cur_step < self.max_guidance_iter_per_step)
         subject_masks = torch.logical_or(subject_masks, background_masks.unsqueeze(1)) if include_background else subject_masks
-        min_value = torch.finfo(…
-        sim_masks = torch.…
+        min_value = torch.finfo(dtype).min
+        sim_masks = torch.zeros((batch_size, n, d), dtype=dtype, device=device) # b i j
         for token_indices in (*self.subject_token_indices, self.filter_token_indices):
             sim_masks[:, :, token_indices] = min_value
 
@@ -257,16 +266,13 @@ class BoundedAttention(injection_utils.AttentionBase):
         for batch_index, background_mask in zip(range(batch_size), background_masks):
             sim_masks[batch_index, background_mask, self.eos_token_index] = min_value
 
-        return …
+        return torch.cat((torch.zeros_like(sim_masks), sim_masks)).unsqueeze(1)
 
-    def _hide_other_subjects_from_subjects(self, …
-        …
-        device = sim.device
-        batch_size = sim.size(0)
-        resolution = int(sim.size(2) ** 0.5)
+    def _hide_other_subjects_from_subjects(self, batch_size, n, dtype, device): # b h i j
+        resolution = int(n ** 0.5)
         subject_masks, background_masks = self._obtain_masks(resolution, batch_size=batch_size, device=device) # b s n
         min_value = torch.finfo(dtype).min
-        sim_masks = torch.…
+        sim_masks = torch.zeros((batch_size, n, n), dtype=dtype, device=device) # b i j
         for batch_index, background_mask in zip(range(batch_size), background_masks):
             sim_masks[batch_index, ~background_mask, ~background_mask] = min_value
 
@@ -276,7 +282,7 @@ class BoundedAttention(injection_utils.AttentionBase):
             condition = torch.logical_or(subject_sim_mask == 0, subject_mask.unsqueeze(0))
            sim_masks[batch_index, subject_mask] = torch.where(condition, 0, min_value).to(dtype=dtype)
 
-        return …
+        return torch.cat((sim_masks, sim_masks)).unsqueeze(1)
 
     def _save(self, attn, is_cross, num_heads):
         _, attn = attn.chunk(2)
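For reference, the reworked _hide_other_subjects_from_* helpers now return an additive float mask of shape (2·batch, 1, i, j), with zeros where attention is allowed and torch.finfo(dtype).min where it is suppressed, so a single tensor broadcasts over the head dimension both when it is added to the reshaped sim and when it is passed as attn_mask to scaled_dot_product_attention. A shape-only sketch with dummy blocked token indices (not the repo's masking logic):

import torch
import torch.nn.functional as F

batch, heads, n, d_tokens = 2, 8, 1024, 77   # e.g. image queries vs. text tokens
dtype = torch.float32

# Additive mask in the same layout the helpers return: (2*batch, 1, n, d_tokens),
# 0 where attention is allowed, finfo.min where it should be suppressed.
sim_masks = torch.zeros((batch, n, d_tokens), dtype=dtype)
sim_masks[:, :, 40:] = torch.finfo(dtype).min           # dummy blocked token indices
masks = torch.cat((torch.zeros_like(sim_masks), sim_masks)).unsqueeze(1)
print(masks.shape)                                       # torch.Size([4, 1, 1024, 77])

q = torch.randn(2 * batch, heads, n, 64)
k = torch.randn(2 * batch, heads, d_tokens, 64)
v = torch.randn(2 * batch, heads, d_tokens, 64)

# High-resolution path: the mask broadcasts over the head dimension inside SDPA.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=masks)
print(out.shape)                                         # torch.Size([4, 8, 1024, 64])

# Low-resolution path: the same mask is simply added to the similarity matrix.
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * 64 ** -0.5
attn = (sim + masks).softmax(-1)
print(attn.shape)                                        # torch.Size([4, 8, 1024, 77])
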
injection_utils.py
CHANGED
@@ -22,21 +22,29 @@ class AttentionBase:
     def after_step(self):
         pass
 
-    def __call__(self, q, k, v, …
+    def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
         if self.cur_att_layer == 0:
             self.before_step()
 
-        out = self.forward(q, k, v, …
+        out = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
         self.cur_att_layer += 1
         if self.cur_att_layer == self.num_att_layers:
             self.cur_att_layer = 0
             self.cur_step += 1
-            # after step
             self.after_step()
+
         return out
 
     def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
-        …
+        batch_size = q.size(0) // num_heads
+        n = q.size(1)
+        d = k.size(1)
+
+        q = q.reshape(batch_size, num_heads, n, -1)
+        k = k.reshape(batch_size, num_heads, d, -1)
+        v = v.reshape(batch_size, num_heads, d, -1)
+        out = F.scaled_dot_product_attention(q, k, v, attn_mask=kwargs['mask'])
+        out = out.reshape(batch_size * num_heads, n, -1)
         out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
         return out
 
@@ -45,42 +53,6 @@ class AttentionBase:
         self.cur_att_layer = 0
 
 
-class AttentionStore(AttentionBase):
-    def __init__(self, res=[32], min_step=0, max_step=1000):
-        super().__init__()
-        self.res = res
-        self.min_step = min_step
-        self.max_step = max_step
-        self.valid_steps = 0
-
-        self.self_attns = [] # store the all attns
-        self.cross_attns = []
-
-        self.self_attns_step = [] # store the attns in each step
-        self.cross_attns_step = []
-
-    def after_step(self):
-        if self.cur_step > self.min_step and self.cur_step < self.max_step:
-            self.valid_steps += 1
-            if len(self.self_attns) == 0:
-                self.self_attns = self.self_attns_step
-                self.cross_attns = self.cross_attns_step
-            else:
-                for i in range(len(self.self_attns)):
-                    self.self_attns[i] += self.self_attns_step[i]
-                    self.cross_attns[i] += self.cross_attns_step[i]
-        self.self_attns_step.clear()
-        self.cross_attns_step.clear()
-
-    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
-        if attn.shape[1] <= 64 ** 2: # avoid OOM
-            if is_cross:
-                self.cross_attns_step.append(attn)
-            else:
-                self.self_attns_step.append(attn)
-        return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
-
-
 def regiter_attention_editor_diffusers(model, editor: AttentionBase):
     """
     Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]
@@ -109,21 +81,9 @@ def regiter_attention_editor_diffusers(model, editor: AttentionBase):
             k = self.to_k(context)
             v = self.to_v(context)
             q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
-            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
-
-            if mask is not None:
-                mask = rearrange(mask, 'b ... -> b (...)')
-                max_neg_value = -torch.finfo(sim.dtype).max
-                mask = repeat(mask, 'b j -> (b h) () j', h=h)
-                mask = mask[:, None, :].repeat(h, 1, 1)
-                sim.masked_fill_(~mask, max_neg_value)
-
-            attn = sim.softmax(dim=-1)
-            # the only difference
             out = editor(
-                q, k, v, …
-                self.heads, scale=self.scale)
+                q, k, v, is_cross, place_in_unet,
+                self.heads, scale=self.scale, mask=mask)
 
             return to_out(out)
 
@@ -146,74 +106,7 @@ def regiter_attention_editor_diffusers(model, editor: AttentionBase):
             cross_att_count += register_editor(net, 0, "mid")
         elif "up" in net_name:
             cross_att_count += register_editor(net, 0, "up")
+
     editor.num_att_layers = cross_att_count
     editor.model = model
     model.editor = editor
-
-
-def regiter_attention_editor_ldm(model, editor: AttentionBase):
-    """
-    Register a attention editor to Stable Diffusion model, refer from [Prompt-to-Prompt]
-    """
-    def ca_forward(self, place_in_unet):
-        def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
-            """
-            The attention is similar to the original implementation of LDM CrossAttention class
-            except adding some modifications on the attention
-            """
-            if encoder_hidden_states is not None:
-                context = encoder_hidden_states
-            if attention_mask is not None:
-                mask = attention_mask
-
-            to_out = self.to_out
-            if isinstance(to_out, nn.modules.container.ModuleList):
-                to_out = self.to_out[0]
-            else:
-                to_out = self.to_out
-
-            h = self.heads
-            q = self.to_q(x)
-            is_cross = context is not None
-            context = context if is_cross else x
-            k = self.to_k(context)
-            v = self.to_v(context)
-            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
-            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
-
-            if mask is not None:
-                mask = rearrange(mask, 'b ... -> b (...)')
-                max_neg_value = -torch.finfo(sim.dtype).max
-                mask = repeat(mask, 'b j -> (b h) () j', h=h)
-                mask = mask[:, None, :].repeat(h, 1, 1)
-                sim.masked_fill_(~mask, max_neg_value)
-
-            attn = sim.softmax(dim=-1)
-            # the only difference
-            out = editor(
-                q, k, v, sim, attn, is_cross, place_in_unet,
-                self.heads, scale=self.scale)
-
-            return to_out(out)
-
-        return forward
-
-    def register_editor(net, count, place_in_unet):
-        for name, subnet in net.named_children():
-            if net.__class__.__name__ == 'CrossAttention': # spatial Transformer layer
-                net.forward = ca_forward(net, place_in_unet)
-                return count + 1
-            elif hasattr(net, 'children'):
-                count = register_editor(subnet, count, place_in_unet)
-        return count
-
-    cross_att_count = 0
-    for net_name, net in model.model.diffusion_model.named_children():
-        if "input" in net_name:
-            cross_att_count += register_editor(net, 0, "input")
-        elif "middle" in net_name:
-            cross_att_count += register_editor(net, 0, "middle")
-        elif "output" in net_name:
-            cross_att_count += register_editor(net, 0, "output")
-    editor.num_att_layers = cross_att_count
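To sanity-check the memory effect of the new AttentionBase.forward on a GPU, a rough, hypothetical harness like the one below compares peak allocation of the explicit path against SDPA at a 64×64 latent resolution; the helper names and sizes are assumptions, and the actual numbers depend on which SDPA backend PyTorch selects:

import torch
import torch.nn.functional as F

def peak_mem(fn, *args):
    # Peak allocated CUDA memory (MiB) while running fn once.
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2 ** 20

def explicit_path(q, k, v):
    # Materializes a (b, h, n, n) similarity matrix.
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * q.size(-1) ** -0.5
    return sim.softmax(-1) @ v

def sdpa_path(q, k, v):
    return F.scaled_dot_product_attention(q, k, v)

if torch.cuda.is_available():
    b, h, n, d = 2, 8, 64 * 64, 40   # 64x64 latent self-attention
    q, k, v = (torch.randn(b, h, n, d, device='cuda', dtype=torch.float16)
               for _ in range(3))
    print('explicit:', peak_mem(explicit_path, q, k, v), 'MiB')
    print('sdpa:    ', peak_mem(sdpa_path, q, k, v), 'MiB')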