Linoy Tsaban committed
Commit 9113152
1 Parent(s): 1a2c8b5

Create tokenflow_utils.py
Files changed (1)
  1. tokenflow_utils.py +448 -0
tokenflow_utils.py ADDED
import os
from typing import Type

import torch

from util import isinstance_str, batch_cosine_sim

def register_pivotal(diffusion_model, is_pivotal):
    """Mark every BasicTransformerBlock with whether the current pass is the pivotal (keyframe) pass."""
    for _, module in diffusion_model.named_modules():
        # If for some reason this has a different name, create an issue and I'll fix it
        if isinstance_str(module, "BasicTransformerBlock"):
            setattr(module, "pivotal_pass", is_pivotal)


def register_batch_idx(diffusion_model, batch_idx):
    """Store the index of the frame batch currently being denoised on every BasicTransformerBlock."""
    for _, module in diffusion_model.named_modules():
        # If for some reason this has a different name, create an issue and I'll fix it
        if isinstance_str(module, "BasicTransformerBlock"):
            setattr(module, "batch_idx", batch_idx)

def register_time(model, t):
    """Attach the current timestep t to the injection conv block and to the attn1/attn2 modules of the UNet's down, mid and up blocks."""
    conv_module = model.unet.up_blocks[1].resnets[1]
    setattr(conv_module, 't', t)
    down_res_dict = {0: [0, 1], 1: [0, 1], 2: [0, 1]}
    up_res_dict = {1: [0, 1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    for res in up_res_dict:
        for block in up_res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            setattr(module, 't', t)
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn2
            setattr(module, 't', t)
    for res in down_res_dict:
        for block in down_res_dict[res]:
            module = model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn1
            setattr(module, 't', t)
            module = model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn2
            setattr(module, 't', t)
    module = model.unet.mid_block.attentions[0].transformer_blocks[0].attn1
    setattr(module, 't', t)
    module = model.unet.mid_block.attentions[0].transformer_blocks[0].attn2
    setattr(module, 't', t)

def load_source_latents_t(t, latents_path):
    """Load the precomputed noisy source latents saved for timestep t under latents_path."""
    latents_t_path = os.path.join(latents_path, f'noisy_latents_{t}.pt')
    assert os.path.exists(latents_t_path), f'Missing latents at t {t} path {latents_t_path}'
    latents = torch.load(latents_t_path)
    return latents

def register_conv_injection(model, injection_schedule):
    """Patch up_blocks[1].resnets[1] so that, at the scheduled timesteps, the source-frame features
    are copied into the unconditional and conditional parts of the batch (PnP feature injection)."""
    def conv_forward(self):
        def forward(input_tensor, temb):
            hidden_states = input_tensor

            hidden_states = self.norm1(hidden_states)
            hidden_states = self.nonlinearity(hidden_states)

            if self.upsample is not None:
                # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
                if hidden_states.shape[0] >= 64:
                    input_tensor = input_tensor.contiguous()
                    hidden_states = hidden_states.contiguous()
                input_tensor = self.upsample(input_tensor)
                hidden_states = self.upsample(hidden_states)
            elif self.downsample is not None:
                input_tensor = self.downsample(input_tensor)
                hidden_states = self.downsample(hidden_states)

            hidden_states = self.conv1(hidden_states)

            if temb is not None:
                temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]

            if temb is not None and self.time_embedding_norm == "default":
                hidden_states = hidden_states + temb

            hidden_states = self.norm2(hidden_states)

            if temb is not None and self.time_embedding_norm == "scale_shift":
                scale, shift = torch.chunk(temb, 2, dim=1)
                hidden_states = hidden_states * (1 + scale) + shift

            hidden_states = self.nonlinearity(hidden_states)

            hidden_states = self.dropout(hidden_states)
            hidden_states = self.conv2(hidden_states)
            if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
                source_batch_size = int(hidden_states.shape[0] // 3)
                # inject unconditional
                hidden_states[source_batch_size:2 * source_batch_size] = hidden_states[:source_batch_size]
                # inject conditional
                hidden_states[2 * source_batch_size:] = hidden_states[:source_batch_size]

            if self.conv_shortcut is not None:
                input_tensor = self.conv_shortcut(input_tensor)

            output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

            return output_tensor

        return forward

    conv_module = model.unet.up_blocks[1].resnets[1]
    conv_module.forward = conv_forward(conv_module)
    setattr(conv_module, 'injection_schedule', injection_schedule)

def register_extended_attention_pnp(model, injection_schedule):
    """Replace self-attention (attn1) with extended attention (the unconditional and conditional batches
    attend over the tokens of all frames); on the scheduled timesteps the source queries and keys are
    also injected into the unconditional and conditional batches."""
    def sa_forward(self):
        to_out = self.to_out[0] if isinstance(self.to_out, torch.nn.modules.container.ModuleList) else self.to_out

        def forward(x, encoder_hidden_states=None):
            batch_size, sequence_length, dim = x.shape
            h = self.heads
            n_frames = batch_size // 3
            is_cross = encoder_hidden_states is not None
            encoder_hidden_states = encoder_hidden_states if is_cross else x
            q = self.to_q(x)
            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)

            if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
                # inject unconditional
                q[n_frames:2 * n_frames] = q[:n_frames]
                k[n_frames:2 * n_frames] = k[:n_frames]
                # inject conditional
                q[2 * n_frames:] = q[:n_frames]
                k[2 * n_frames:] = k[:n_frames]

            k_source = k[:n_frames]
            k_uncond = k[n_frames:2 * n_frames].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
            k_cond = k[2 * n_frames:].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)

            v_source = v[:n_frames]
            v_uncond = v[n_frames:2 * n_frames].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
            v_cond = v[2 * n_frames:].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)

            q_source = self.head_to_batch_dim(q[:n_frames])
            q_uncond = self.head_to_batch_dim(q[n_frames:2 * n_frames])
            q_cond = self.head_to_batch_dim(q[2 * n_frames:])
            k_source = self.head_to_batch_dim(k_source)
            k_uncond = self.head_to_batch_dim(k_uncond)
            k_cond = self.head_to_batch_dim(k_cond)
            v_source = self.head_to_batch_dim(v_source)
            v_uncond = self.head_to_batch_dim(v_uncond)
            v_cond = self.head_to_batch_dim(v_cond)

            q_src = q_source.view(n_frames, h, sequence_length, dim // h)
            k_src = k_source.view(n_frames, h, sequence_length, dim // h)
            v_src = v_source.view(n_frames, h, sequence_length, dim // h)
            q_uncond = q_uncond.view(n_frames, h, sequence_length, dim // h)
            k_uncond = k_uncond.view(n_frames, h, sequence_length * n_frames, dim // h)
            v_uncond = v_uncond.view(n_frames, h, sequence_length * n_frames, dim // h)
            q_cond = q_cond.view(n_frames, h, sequence_length, dim // h)
            k_cond = k_cond.view(n_frames, h, sequence_length * n_frames, dim // h)
            v_cond = v_cond.view(n_frames, h, sequence_length * n_frames, dim // h)

            out_source_all = []
            out_uncond_all = []
            out_cond_all = []

            single_batch = n_frames <= 12
            b = n_frames if single_batch else 1

            for frame in range(0, n_frames, b):
                out_source = []
                out_uncond = []
                out_cond = []
                for j in range(h):
                    sim_source_b = torch.bmm(q_src[frame: frame + b, j], k_src[frame: frame + b, j].transpose(-1, -2)) * self.scale
                    sim_uncond_b = torch.bmm(q_uncond[frame: frame + b, j], k_uncond[frame: frame + b, j].transpose(-1, -2)) * self.scale
                    sim_cond = torch.bmm(q_cond[frame: frame + b, j], k_cond[frame: frame + b, j].transpose(-1, -2)) * self.scale

                    out_source.append(torch.bmm(sim_source_b.softmax(dim=-1), v_src[frame: frame + b, j]))
                    out_uncond.append(torch.bmm(sim_uncond_b.softmax(dim=-1), v_uncond[frame: frame + b, j]))
                    out_cond.append(torch.bmm(sim_cond.softmax(dim=-1), v_cond[frame: frame + b, j]))

                out_source = torch.cat(out_source, dim=0)
                out_uncond = torch.cat(out_uncond, dim=0)
                out_cond = torch.cat(out_cond, dim=0)
                if single_batch:
                    out_source = out_source.view(h, n_frames, sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)
                    out_uncond = out_uncond.view(h, n_frames, sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)
                    out_cond = out_cond.view(h, n_frames, sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)
                out_source_all.append(out_source)
                out_uncond_all.append(out_uncond)
                out_cond_all.append(out_cond)

            out_source = torch.cat(out_source_all, dim=0)
            out_uncond = torch.cat(out_uncond_all, dim=0)
            out_cond = torch.cat(out_cond_all, dim=0)

            out = torch.cat([out_source, out_uncond, out_cond], dim=0)
            out = self.batch_to_head_dim(out)

            return to_out(out)

        return forward

    for _, module in model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module.attn1.forward = sa_forward(module.attn1)
            setattr(module.attn1, 'injection_schedule', [])

    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
            setattr(module, 'injection_schedule', injection_schedule)

def register_extended_attention(model):
    """Replace self-attention (attn1) with extended attention: the unconditional and conditional batches
    attend over the tokens of all frames, while the source batch keeps per-frame attention."""
    def sa_forward(self):
        to_out = self.to_out[0] if isinstance(self.to_out, torch.nn.modules.container.ModuleList) else self.to_out

        def forward(x, encoder_hidden_states=None):
            batch_size, sequence_length, dim = x.shape
            h = self.heads
            n_frames = batch_size // 3
            is_cross = encoder_hidden_states is not None
            encoder_hidden_states = encoder_hidden_states if is_cross else x
            q = self.to_q(x)
            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)

            k_source = k[:n_frames]
            k_uncond = k[n_frames:2 * n_frames].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
            k_cond = k[2 * n_frames:].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
            v_source = v[:n_frames]
            v_uncond = v[n_frames:2 * n_frames].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
            v_cond = v[2 * n_frames:].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)

            q_source = self.head_to_batch_dim(q[:n_frames])
            q_uncond = self.head_to_batch_dim(q[n_frames:2 * n_frames])
            q_cond = self.head_to_batch_dim(q[2 * n_frames:])
            k_source = self.head_to_batch_dim(k_source)
            k_uncond = self.head_to_batch_dim(k_uncond)
            k_cond = self.head_to_batch_dim(k_cond)
            v_source = self.head_to_batch_dim(v_source)
            v_uncond = self.head_to_batch_dim(v_uncond)
            v_cond = self.head_to_batch_dim(v_cond)

            out_source = []
            out_uncond = []
            out_cond = []

            q_src = q_source.view(n_frames, h, sequence_length, dim // h)
            k_src = k_source.view(n_frames, h, sequence_length, dim // h)
            v_src = v_source.view(n_frames, h, sequence_length, dim // h)
            q_uncond = q_uncond.view(n_frames, h, sequence_length, dim // h)
            k_uncond = k_uncond.view(n_frames, h, sequence_length * n_frames, dim // h)
            v_uncond = v_uncond.view(n_frames, h, sequence_length * n_frames, dim // h)
            q_cond = q_cond.view(n_frames, h, sequence_length, dim // h)
            k_cond = k_cond.view(n_frames, h, sequence_length * n_frames, dim // h)
            v_cond = v_cond.view(n_frames, h, sequence_length * n_frames, dim // h)

            for j in range(h):
                sim_source_b = torch.bmm(q_src[:, j], k_src[:, j].transpose(-1, -2)) * self.scale
                sim_uncond_b = torch.bmm(q_uncond[:, j], k_uncond[:, j].transpose(-1, -2)) * self.scale
                sim_cond = torch.bmm(q_cond[:, j], k_cond[:, j].transpose(-1, -2)) * self.scale

                out_source.append(torch.bmm(sim_source_b.softmax(dim=-1), v_src[:, j]))
                out_uncond.append(torch.bmm(sim_uncond_b.softmax(dim=-1), v_uncond[:, j]))
                out_cond.append(torch.bmm(sim_cond.softmax(dim=-1), v_cond[:, j]))

            out_source = torch.cat(out_source, dim=0).view(h, n_frames, sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)
            out_uncond = torch.cat(out_uncond, dim=0).view(h, n_frames, sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)
            out_cond = torch.cat(out_cond, dim=0).view(h, n_frames, sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)

            out = torch.cat([out_source, out_uncond, out_cond], dim=0)
            out = self.batch_to_head_dim(out)

            return to_out(out)

        return forward

    for _, module in model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module.attn1.forward = sa_forward(module.attn1)

    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)

def make_tokenflow_attention_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """Return a subclass of block_class whose self-attention, on non-pivotal passes, reuses the cached
    attention output of the pivotal (keyframe) pass, gathered and blended per token via nearest-neighbor
    cosine-similarity correspondences."""

    class TokenFlowBlock(block_class):

        def forward(
            self,
            hidden_states,
            attention_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            timestep=None,
            cross_attention_kwargs=None,
            class_labels=None,
        ) -> torch.Tensor:

            batch_size, sequence_length, dim = hidden_states.shape
            n_frames = batch_size // 3
            mid_idx = n_frames // 2
            hidden_states = hidden_states.view(3, n_frames, sequence_length, dim)

            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)

            norm_hidden_states = norm_hidden_states.view(3, n_frames, sequence_length, dim)
            if self.pivotal_pass:
                self.pivot_hidden_states = norm_hidden_states
            else:
                idx1 = []
                idx2 = []
                batch_idxs = [self.batch_idx]
                if self.batch_idx > 0:
                    batch_idxs.append(self.batch_idx - 1)

                sim = batch_cosine_sim(norm_hidden_states[0].reshape(-1, dim),
                                       self.pivot_hidden_states[0][batch_idxs].reshape(-1, dim))
                if len(batch_idxs) == 2:
                    sim1, sim2 = sim.chunk(2, dim=1)
                    # sim: n_frames * seq_len, len(batch_idxs) * seq_len
                    idx1.append(sim1.argmax(dim=-1))  # n_frames * seq_len
                    idx2.append(sim2.argmax(dim=-1))  # n_frames * seq_len
                else:
                    idx1.append(sim.argmax(dim=-1))
                idx1 = torch.stack(idx1 * 3, dim=0)  # 3, n_frames * seq_len
                idx1 = idx1.squeeze(1)
                if len(batch_idxs) == 2:
                    idx2 = torch.stack(idx2 * 3, dim=0)  # 3, n_frames * seq_len
                    idx2 = idx2.squeeze(1)

            # 1. Self-Attention
            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
            if self.pivotal_pass:
                # norm_hidden_states.shape = 3, n_frames * seq_len, dim
                self.attn_output = self.attn1(
                    norm_hidden_states.view(batch_size, sequence_length, dim),
                    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                    **cross_attention_kwargs,
                )
                # 3, n_frames * seq_len, dim -> 3 * n_frames, seq_len, dim
                self.kf_attn_output = self.attn_output
            else:
                batch_kf_size, _, _ = self.kf_attn_output.shape
                # 3, n_frames, seq_len, dim --> 3, len(batch_idxs), seq_len, dim
                self.attn_output = self.kf_attn_output.view(3, batch_kf_size // 3, sequence_length, dim)[:, batch_idxs]
            if self.use_ada_layer_norm_zero:
                self.attn_output = gate_msa.unsqueeze(1) * self.attn_output

            # gather values from attn_output, using idx as indices, and get a tensor of shape 3, n_frames, seq_len, dim
            if not self.pivotal_pass:
                if len(batch_idxs) == 2:
                    attn_1, attn_2 = self.attn_output[:, 0], self.attn_output[:, 1]
                    attn_output1 = attn_1.gather(dim=1, index=idx1.unsqueeze(-1).repeat(1, 1, dim))
                    attn_output2 = attn_2.gather(dim=1, index=idx2.unsqueeze(-1).repeat(1, 1, dim))

                    s = torch.arange(0, n_frames).to(idx1.device) + batch_idxs[0] * n_frames
                    # distance from the pivot
                    p1 = batch_idxs[0] * n_frames + n_frames // 2
                    p2 = batch_idxs[1] * n_frames + n_frames // 2
                    d1 = torch.abs(s - p1)
                    d2 = torch.abs(s - p2)
                    # weight
                    w1 = d2 / (d1 + d2)
                    w1 = torch.sigmoid(w1)

                    w1 = w1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).repeat(3, 1, sequence_length, dim)
                    attn_output1 = attn_output1.view(3, n_frames, sequence_length, dim)
                    attn_output2 = attn_output2.view(3, n_frames, sequence_length, dim)
                    attn_output = w1 * attn_output1 + (1 - w1) * attn_output2
                else:
                    attn_output = self.attn_output[:, 0].gather(dim=1, index=idx1.unsqueeze(-1).repeat(1, 1, dim))

                attn_output = attn_output.reshape(batch_size, sequence_length, dim)  # 3 * n_frames, seq_len, dim
            else:
                attn_output = self.attn_output
            hidden_states = hidden_states.reshape(batch_size, sequence_length, dim)  # 3 * n_frames, seq_len, dim
            hidden_states = attn_output + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
                )

                # 2. Cross-Attention
                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            # 3. Feed-forward
            norm_hidden_states = self.norm3(hidden_states)

            if self.use_ada_layer_norm_zero:
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

            ff_output = self.ff(norm_hidden_states)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output

            hidden_states = ff_output + hidden_states

            return hidden_states

    return TokenFlowBlock

def set_tokenflow(model: torch.nn.Module):
    """
    Sets the tokenflow attention blocks in a model.
    """

    for _, module in model.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module.__class__ = make_tokenflow_attention_block(module.__class__)

            # Something needed for older versions of diffusers
            if not hasattr(module, "use_ada_layer_norm_zero"):
                module.use_ada_layer_norm = False
                module.use_ada_layer_norm_zero = False

    return model
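
Usage sketch (not part of the committed file): a minimal example of how these helpers might be wired together, assuming a diffusers StableDiffusionPipeline with the standard SD 1.x UNet layout. The model id, the injection_schedule values, and the sampling-loop comments are illustrative assumptions, not taken from this commit.

import torch
from diffusers import StableDiffusionPipeline

from tokenflow_utils import (
    set_tokenflow,
    register_extended_attention_pnp,
    register_conv_injection,
    register_pivotal,
    register_batch_idx,
    register_time,
)

# Illustrative setup: any SD 1.x checkpoint with the standard UNet layout should work.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)

# Hypothetical schedule: the timesteps at which source features/attention are injected.
injection_schedule = list(range(1, 1000, 25))

set_tokenflow(pipe.unet)                                   # swap BasicTransformerBlocks for TokenFlowBlocks
register_extended_attention_pnp(pipe, injection_schedule)  # extended attention + PnP q/k injection on decoder blocks
register_conv_injection(pipe, injection_schedule)          # feature injection in up_blocks[1].resnets[1]

# Inside the sampling loop one would then, per timestep t:
#   register_time(pipe, t)
#   register_pivotal(pipe.unet, True)    # denoise the keyframe batch, caching its attention output
#   register_pivotal(pipe.unet, False)
#   for b, frame_batch in enumerate(batches):
#       register_batch_idx(pipe.unet, b)  # denoise this batch, reusing the cached keyframe attention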