revi13 committed
Commit a44d558 · verified · 1 Parent(s): a37db24

Create attention_processor_faceid.py

ip_adapter/attention_processor_faceid.py ADDED
@@ -0,0 +1,434 @@
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.lora import LoRALinearLayer


class LoRAAttnProcessor(nn.Module):
    r"""
    Attention processor that adds LoRA (low-rank) updates to the query, key, value and output projections.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        rank=4,
        network_alpha=None,
        lora_scale=1.0,
    ):
        super().__init__()

        self.rank = rank
        self.lora_scale = lora_scale

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

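# Illustrative sketch (hypothetical helper, not used by the classes in this file):
# LoRAAttnProcessor augments each frozen projection with a low-rank update,
# i.e. W x + lora_scale * (B A) x. The sizes below are placeholders.
def _lora_projection_example(hidden_size=320, rank=4, lora_scale=1.0):
    to_q = nn.Linear(hidden_size, hidden_size)                   # stands in for the frozen attn.to_q
    to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)  # trainable low-rank update
    x = torch.randn(2, 77, hidden_size)                          # (batch, tokens, channels)
    return to_q(x) + lora_scale * to_q_lora(x)                   # same form as used in __call__ above
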
class LoRAIPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter (FaceID) with LoRA updates.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
        network_alpha (`int`, *optional*):
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
        lora_scale (`float`, defaults to 1.0):
            The weight scale of the LoRA updates.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4):
            The context length of the image features (use 16 for ip_adapter_plus).
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
        super().__init__()

        self.rank = rank
        self.lora_scale = lora_scale

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)

        if encoder_hidden_states is None:
            # this processor is intended for cross-attention layers, where encoder_hidden_states
            # carries the text tokens followed by `num_tokens` image-prompt tokens
            encoder_hidden_states = hidden_states
        else:
            # split encoder_hidden_states into text tokens and image-prompt (ip) tokens
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        self.attn_map = ip_attention_probs
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

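# Illustrative sketch (hypothetical helper, not used by the classes in this file):
# LoRAIPAttnProcessor assumes the last `num_tokens` entries of encoder_hidden_states are
# image-prompt (FaceID) tokens appended to the text embeddings, e.g. built like this:
def _concat_image_prompt_tokens_example(text_embeds, ip_tokens):
    # text_embeds: (batch, text_len, cross_attention_dim); ip_tokens: (batch, num_tokens, cross_attention_dim)
    return torch.cat([text_embeds, ip_tokens], dim=1)  # split back inside __call__ via end_pos
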
class LoRAAttnProcessor2_0(nn.Module):
    r"""
    Attention processor with LoRA updates, using PyTorch 2.0's `scaled_dot_product_attention`.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        rank=4,
        network_alpha=None,
        lora_scale=1.0,
    ):
        super().__init__()

        self.rank = rank
        self.lora_scale = lora_scale

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

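# Illustrative sketch (hypothetical helper, not used by the classes in this file):
# the *2_0 processors rely on F.scaled_dot_product_attention, which is only available in
# PyTorch >= 2.0, so a caller might dispatch between the two processor families like this.
def _pick_processor_classes():
    if hasattr(F, "scaled_dot_product_attention"):
        return LoRAAttnProcessor2_0, LoRAIPAttnProcessor2_0
    return LoRAAttnProcessor, LoRAIPAttnProcessor
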
class LoRAIPAttnProcessor2_0(nn.Module):
    r"""
    Attention processor for IP-Adapter (FaceID) with LoRA updates, using PyTorch 2.0's
    `scaled_dot_product_attention`.

    Args:
        hidden_size (`int`, *optional*):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
        network_alpha (`int`, *optional*):
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
        lora_scale (`float`, defaults to 1.0):
            The weight scale of the LoRA updates.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4):
            The context length of the image features (use 16 for ip_adapter_plus).
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
        super().__init__()

        self.rank = rank
        self.lora_scale = lora_scale
        self.num_tokens = num_tokens

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        scale=1.0,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)

        if encoder_hidden_states is None:
            # this processor is intended for cross-attention layers, where encoder_hidden_states
            # carries the text tokens followed by `num_tokens` image-prompt tokens
            encoder_hidden_states = hidden_states
        else:
            # split encoder_hidden_states into text tokens and image-prompt (ip) tokens
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # for text
        key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # for ip
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        ip_hidden_states = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        ip_hidden_states = ip_hidden_states.to(query.dtype)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
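

# Illustrative usage sketch (hypothetical helper, not used by the classes in this file),
# assuming a standard Stable Diffusion UNet2DConditionModel from diffusers: self-attention
# layers ("attn1") get LoRAAttnProcessor2_0 and cross-attention layers ("attn2") get
# LoRAIPAttnProcessor2_0, following the usual IP-Adapter registration pattern. The default
# rank=128 mirrors the FaceID setup but is an assumption here; adjust for your model.
def _register_faceid_processors_example(unet, rank=128, num_tokens=4, lora_scale=1.0, ip_scale=1.0):
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # "attn1" is self-attention (no text/image context), "attn2" is cross-attention
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = LoRAAttnProcessor2_0(
                hidden_size=hidden_size, cross_attention_dim=None, rank=rank, lora_scale=lora_scale
            )
        else:
            attn_procs[name] = LoRAIPAttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                rank=rank,
                lora_scale=lora_scale,
                scale=ip_scale,
                num_tokens=num_tokens,
            )
    unet.set_attn_processor(attn_procs)
    return attn_procs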