hideosnes commited on
Commit
1fd342e
·
verified ·
1 Parent(s): 5c45fb4

Create attention_processor.py

Browse files
Files changed (1) hide show
  1. ip_adapter/attention_processor.py +562 -0
ip_adapter/attention_processor.py ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class AttnProcessor(nn.Module):
8
+ r"""
9
+ Default processor for performing attention-related computations.
10
+ """
11
+
12
+ def __init__(
13
+ self,
14
+ hidden_size=None,
15
+ cross_attention_dim=None,
16
+ ):
17
+ super().__init__()
18
+
19
+ def __call__(
20
+ self,
21
+ attn,
22
+ hidden_states,
23
+ encoder_hidden_states=None,
24
+ attention_mask=None,
25
+ temb=None,
26
+ ):
27
+ residual = hidden_states
28
+
29
+ if attn.spatial_norm is not None:
30
+ hidden_states = attn.spatial_norm(hidden_states, temb)
31
+
32
+ input_ndim = hidden_states.ndim
33
+
34
+ if input_ndim == 4:
35
+ batch_size, channel, height, width = hidden_states.shape
36
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
37
+
38
+ batch_size, sequence_length, _ = (
39
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
40
+ )
41
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
42
+
43
+ if attn.group_norm is not None:
44
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
45
+
46
+ query = attn.to_q(hidden_states)
47
+
48
+ if encoder_hidden_states is None:
49
+ encoder_hidden_states = hidden_states
50
+ elif attn.norm_cross:
51
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
52
+
53
+ key = attn.to_k(encoder_hidden_states)
54
+ value = attn.to_v(encoder_hidden_states)
55
+
56
+ query = attn.head_to_batch_dim(query)
57
+ key = attn.head_to_batch_dim(key)
58
+ value = attn.head_to_batch_dim(value)
59
+
60
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
61
+ hidden_states = torch.bmm(attention_probs, value)
62
+ hidden_states = attn.batch_to_head_dim(hidden_states)
63
+
64
+ # linear proj
65
+ hidden_states = attn.to_out[0](hidden_states)
66
+ # dropout
67
+ hidden_states = attn.to_out[1](hidden_states)
68
+
69
+ if input_ndim == 4:
70
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
71
+
72
+ if attn.residual_connection:
73
+ hidden_states = hidden_states + residual
74
+
75
+ hidden_states = hidden_states / attn.rescale_output_factor
76
+
77
+ return hidden_states
78
+
79
+
80
+ class IPAttnProcessor(nn.Module):
81
+ r"""
82
+ Attention processor for IP-Adapater.
83
+ Args:
84
+ hidden_size (`int`):
85
+ The hidden size of the attention layer.
86
+ cross_attention_dim (`int`):
87
+ The number of channels in the `encoder_hidden_states`.
88
+ scale (`float`, defaults to 1.0):
89
+ the weight scale of image prompt.
90
+ num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
91
+ The context length of the image features.
92
+ """
93
+
94
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
95
+ super().__init__()
96
+
97
+ self.hidden_size = hidden_size
98
+ self.cross_attention_dim = cross_attention_dim
99
+ self.scale = scale
100
+ self.num_tokens = num_tokens
101
+ self.skip = skip
102
+
103
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
104
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
105
+
106
+ def __call__(
107
+ self,
108
+ attn,
109
+ hidden_states,
110
+ encoder_hidden_states=None,
111
+ attention_mask=None,
112
+ temb=None,
113
+ ):
114
+ residual = hidden_states
115
+
116
+ if attn.spatial_norm is not None:
117
+ hidden_states = attn.spatial_norm(hidden_states, temb)
118
+
119
+ input_ndim = hidden_states.ndim
120
+
121
+ if input_ndim == 4:
122
+ batch_size, channel, height, width = hidden_states.shape
123
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
124
+
125
+ batch_size, sequence_length, _ = (
126
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
127
+ )
128
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
129
+
130
+ if attn.group_norm is not None:
131
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
132
+
133
+ query = attn.to_q(hidden_states)
134
+
135
+ if encoder_hidden_states is None:
136
+ encoder_hidden_states = hidden_states
137
+ else:
138
+ # get encoder_hidden_states, ip_hidden_states
139
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
140
+ encoder_hidden_states, ip_hidden_states = (
141
+ encoder_hidden_states[:, :end_pos, :],
142
+ encoder_hidden_states[:, end_pos:, :],
143
+ )
144
+ if attn.norm_cross:
145
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
146
+
147
+ key = attn.to_k(encoder_hidden_states)
148
+ value = attn.to_v(encoder_hidden_states)
149
+
150
+ query = attn.head_to_batch_dim(query)
151
+ key = attn.head_to_batch_dim(key)
152
+ value = attn.head_to_batch_dim(value)
153
+
154
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
155
+ hidden_states = torch.bmm(attention_probs, value)
156
+ hidden_states = attn.batch_to_head_dim(hidden_states)
157
+
158
+ if not self.skip:
159
+ # for ip-adapter
160
+ ip_key = self.to_k_ip(ip_hidden_states)
161
+ ip_value = self.to_v_ip(ip_hidden_states)
162
+
163
+ ip_key = attn.head_to_batch_dim(ip_key)
164
+ ip_value = attn.head_to_batch_dim(ip_value)
165
+
166
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
167
+ self.attn_map = ip_attention_probs
168
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
169
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
170
+
171
+ hidden_states = hidden_states + self.scale * ip_hidden_states
172
+
173
+ # linear proj
174
+ hidden_states = attn.to_out[0](hidden_states)
175
+ # dropout
176
+ hidden_states = attn.to_out[1](hidden_states)
177
+
178
+ if input_ndim == 4:
179
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
180
+
181
+ if attn.residual_connection:
182
+ hidden_states = hidden_states + residual
183
+
184
+ hidden_states = hidden_states / attn.rescale_output_factor
185
+
186
+ return hidden_states
187
+
188
+
189
+ class AttnProcessor2_0(torch.nn.Module):
190
+ r"""
191
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
192
+ """
193
+
194
+ def __init__(
195
+ self,
196
+ hidden_size=None,
197
+ cross_attention_dim=None,
198
+ ):
199
+ super().__init__()
200
+ if not hasattr(F, "scaled_dot_product_attention"):
201
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
202
+
203
+ def __call__(
204
+ self,
205
+ attn,
206
+ hidden_states,
207
+ encoder_hidden_states=None,
208
+ attention_mask=None,
209
+ temb=None,
210
+ ):
211
+ residual = hidden_states
212
+
213
+ if attn.spatial_norm is not None:
214
+ hidden_states = attn.spatial_norm(hidden_states, temb)
215
+
216
+ input_ndim = hidden_states.ndim
217
+
218
+ if input_ndim == 4:
219
+ batch_size, channel, height, width = hidden_states.shape
220
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
221
+
222
+ batch_size, sequence_length, _ = (
223
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
224
+ )
225
+
226
+ if attention_mask is not None:
227
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
228
+ # scaled_dot_product_attention expects attention_mask shape to be
229
+ # (batch, heads, source_length, target_length)
230
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
231
+
232
+ if attn.group_norm is not None:
233
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
234
+
235
+ query = attn.to_q(hidden_states)
236
+
237
+ if encoder_hidden_states is None:
238
+ encoder_hidden_states = hidden_states
239
+ elif attn.norm_cross:
240
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
241
+
242
+ key = attn.to_k(encoder_hidden_states)
243
+ value = attn.to_v(encoder_hidden_states)
244
+
245
+ inner_dim = key.shape[-1]
246
+ head_dim = inner_dim // attn.heads
247
+
248
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
249
+
250
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
251
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
252
+
253
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
254
+ # TODO: add support for attn.scale when we move to Torch 2.1
255
+ hidden_states = F.scaled_dot_product_attention(
256
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
257
+ )
258
+
259
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
260
+ hidden_states = hidden_states.to(query.dtype)
261
+
262
+ # linear proj
263
+ hidden_states = attn.to_out[0](hidden_states)
264
+ # dropout
265
+ hidden_states = attn.to_out[1](hidden_states)
266
+
267
+ if input_ndim == 4:
268
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
269
+
270
+ if attn.residual_connection:
271
+ hidden_states = hidden_states + residual
272
+
273
+ hidden_states = hidden_states / attn.rescale_output_factor
274
+
275
+ return hidden_states
276
+
277
+
278
+ class IPAttnProcessor2_0(torch.nn.Module):
279
+ r"""
280
+ Attention processor for IP-Adapater for PyTorch 2.0.
281
+ Args:
282
+ hidden_size (`int`):
283
+ The hidden size of the attention layer.
284
+ cross_attention_dim (`int`):
285
+ The number of channels in the `encoder_hidden_states`.
286
+ scale (`float`, defaults to 1.0):
287
+ the weight scale of image prompt.
288
+ num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
289
+ The context length of the image features.
290
+ """
291
+
292
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False):
293
+ super().__init__()
294
+
295
+ if not hasattr(F, "scaled_dot_product_attention"):
296
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
297
+
298
+ self.hidden_size = hidden_size
299
+ self.cross_attention_dim = cross_attention_dim
300
+ self.scale = scale
301
+ self.num_tokens = num_tokens
302
+ self.skip = skip
303
+
304
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
305
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
306
+
307
+ def __call__(
308
+ self,
309
+ attn,
310
+ hidden_states,
311
+ encoder_hidden_states=None,
312
+ attention_mask=None,
313
+ temb=None,
314
+ ):
315
+ residual = hidden_states
316
+
317
+ if attn.spatial_norm is not None:
318
+ hidden_states = attn.spatial_norm(hidden_states, temb)
319
+
320
+ input_ndim = hidden_states.ndim
321
+
322
+ if input_ndim == 4:
323
+ batch_size, channel, height, width = hidden_states.shape
324
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
325
+
326
+ batch_size, sequence_length, _ = (
327
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
328
+ )
329
+
330
+ if attention_mask is not None:
331
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
332
+ # scaled_dot_product_attention expects attention_mask shape to be
333
+ # (batch, heads, source_length, target_length)
334
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
335
+
336
+ if attn.group_norm is not None:
337
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
338
+
339
+ query = attn.to_q(hidden_states)
340
+
341
+ if encoder_hidden_states is None:
342
+ encoder_hidden_states = hidden_states
343
+ else:
344
+ # get encoder_hidden_states, ip_hidden_states
345
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
346
+ encoder_hidden_states, ip_hidden_states = (
347
+ encoder_hidden_states[:, :end_pos, :],
348
+ encoder_hidden_states[:, end_pos:, :],
349
+ )
350
+ if attn.norm_cross:
351
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
352
+
353
+ key = attn.to_k(encoder_hidden_states)
354
+ value = attn.to_v(encoder_hidden_states)
355
+
356
+ inner_dim = key.shape[-1]
357
+ head_dim = inner_dim // attn.heads
358
+
359
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
360
+
361
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
362
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
363
+
364
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
365
+ # TODO: add support for attn.scale when we move to Torch 2.1
366
+ hidden_states = F.scaled_dot_product_attention(
367
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
368
+ )
369
+
370
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
371
+ hidden_states = hidden_states.to(query.dtype)
372
+
373
+ if not self.skip:
374
+ # for ip-adapter
375
+ ip_key = self.to_k_ip(ip_hidden_states)
376
+ ip_value = self.to_v_ip(ip_hidden_states)
377
+
378
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
379
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
380
+
381
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
382
+ # TODO: add support for attn.scale when we move to Torch 2.1
383
+ ip_hidden_states = F.scaled_dot_product_attention(
384
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
385
+ )
386
+ with torch.no_grad():
387
+ self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
388
+ #print(self.attn_map.shape)
389
+
390
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
391
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
392
+
393
+ hidden_states = hidden_states + self.scale * ip_hidden_states
394
+
395
+ # linear proj
396
+ hidden_states = attn.to_out[0](hidden_states)
397
+ # dropout
398
+ hidden_states = attn.to_out[1](hidden_states)
399
+
400
+ if input_ndim == 4:
401
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
402
+
403
+ if attn.residual_connection:
404
+ hidden_states = hidden_states + residual
405
+
406
+ hidden_states = hidden_states / attn.rescale_output_factor
407
+
408
+ return hidden_states
409
+
410
+
411
+ ## for controlnet
412
+ class CNAttnProcessor:
413
+ r"""
414
+ Default processor for performing attention-related computations.
415
+ """
416
+
417
+ def __init__(self, num_tokens=4):
418
+ self.num_tokens = num_tokens
419
+
420
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
421
+ residual = hidden_states
422
+
423
+ if attn.spatial_norm is not None:
424
+ hidden_states = attn.spatial_norm(hidden_states, temb)
425
+
426
+ input_ndim = hidden_states.ndim
427
+
428
+ if input_ndim == 4:
429
+ batch_size, channel, height, width = hidden_states.shape
430
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
431
+
432
+ batch_size, sequence_length, _ = (
433
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
434
+ )
435
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
436
+
437
+ if attn.group_norm is not None:
438
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
439
+
440
+ query = attn.to_q(hidden_states)
441
+
442
+ if encoder_hidden_states is None:
443
+ encoder_hidden_states = hidden_states
444
+ else:
445
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
446
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
447
+ if attn.norm_cross:
448
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
449
+
450
+ key = attn.to_k(encoder_hidden_states)
451
+ value = attn.to_v(encoder_hidden_states)
452
+
453
+ query = attn.head_to_batch_dim(query)
454
+ key = attn.head_to_batch_dim(key)
455
+ value = attn.head_to_batch_dim(value)
456
+
457
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
458
+ hidden_states = torch.bmm(attention_probs, value)
459
+ hidden_states = attn.batch_to_head_dim(hidden_states)
460
+
461
+ # linear proj
462
+ hidden_states = attn.to_out[0](hidden_states)
463
+ # dropout
464
+ hidden_states = attn.to_out[1](hidden_states)
465
+
466
+ if input_ndim == 4:
467
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
468
+
469
+ if attn.residual_connection:
470
+ hidden_states = hidden_states + residual
471
+
472
+ hidden_states = hidden_states / attn.rescale_output_factor
473
+
474
+ return hidden_states
475
+
476
+
477
+ class CNAttnProcessor2_0:
478
+ r"""
479
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
480
+ """
481
+
482
+ def __init__(self, num_tokens=4):
483
+ if not hasattr(F, "scaled_dot_product_attention"):
484
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
485
+ self.num_tokens = num_tokens
486
+
487
+ def __call__(
488
+ self,
489
+ attn,
490
+ hidden_states,
491
+ encoder_hidden_states=None,
492
+ attention_mask=None,
493
+ temb=None,
494
+ ):
495
+ residual = hidden_states
496
+
497
+ if attn.spatial_norm is not None:
498
+ hidden_states = attn.spatial_norm(hidden_states, temb)
499
+
500
+ input_ndim = hidden_states.ndim
501
+
502
+ if input_ndim == 4:
503
+ batch_size, channel, height, width = hidden_states.shape
504
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
505
+
506
+ batch_size, sequence_length, _ = (
507
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
508
+ )
509
+
510
+ if attention_mask is not None:
511
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
512
+ # scaled_dot_product_attention expects attention_mask shape to be
513
+ # (batch, heads, source_length, target_length)
514
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
515
+
516
+ if attn.group_norm is not None:
517
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
518
+
519
+ query = attn.to_q(hidden_states)
520
+
521
+ if encoder_hidden_states is None:
522
+ encoder_hidden_states = hidden_states
523
+ else:
524
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
525
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
526
+ if attn.norm_cross:
527
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
528
+
529
+ key = attn.to_k(encoder_hidden_states)
530
+ value = attn.to_v(encoder_hidden_states)
531
+
532
+ inner_dim = key.shape[-1]
533
+ head_dim = inner_dim // attn.heads
534
+
535
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
536
+
537
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
538
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
539
+
540
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
541
+ # TODO: add support for attn.scale when we move to Torch 2.1
542
+ hidden_states = F.scaled_dot_product_attention(
543
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
544
+ )
545
+
546
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
547
+ hidden_states = hidden_states.to(query.dtype)
548
+
549
+ # linear proj
550
+ hidden_states = attn.to_out[0](hidden_states)
551
+ # dropout
552
+ hidden_states = attn.to_out[1](hidden_states)
553
+
554
+ if input_ndim == 4:
555
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
556
+
557
+ if attn.residual_connection:
558
+ hidden_states = hidden_states + residual
559
+
560
+ hidden_states = hidden_states / attn.rescale_output_factor
561
+
562
+ return hidden_states