Richard Neuschulz committed on
Commit
f8eedcb
•
1 Parent(s): b883378

added files

Files changed (2)
  1. attention_processor_faceid.py +427 -0
  2. utils.py +81 -0
attention_processor_faceid.py ADDED
@@ -0,0 +1,427 @@
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from diffusers.models.lora import LoRALinearLayer
+
+
+ class LoRAAttnProcessor(nn.Module):
+     r"""
+     Default processor for performing attention-related computations.
+     """
+
+     def __init__(
+         self,
+         hidden_size=None,
+         cross_attention_dim=None,
+         rank=4,
+         network_alpha=None,
+         lora_scale=1.0,
+     ):
+         super().__init__()
+
+         self.rank = rank
+         self.lora_scale = lora_scale
+
+         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+     def __call__(
+         self,
+         attn,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         temb=None,
+     ):
+         residual = hidden_states
+
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
+
+         query = attn.head_to_batch_dim(query)
+         key = attn.head_to_batch_dim(key)
+         value = attn.head_to_batch_dim(value)
+
+         attention_probs = attn.get_attention_scores(query, key, attention_mask)
+         hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
+
+
+ class LoRAIPAttnProcessor(nn.Module):
+     r"""
+     Attention processor for IP-Adapter.
+     Args:
+         hidden_size (`int`):
+             The hidden size of the attention layer.
+         cross_attention_dim (`int`):
+             The number of channels in the `encoder_hidden_states`.
+         scale (`float`, defaults to 1.0):
+             The weight scale of the image prompt.
+         num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
+             The context length of the image features.
+     """
+
+     def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
+         super().__init__()
+
+         self.rank = rank
+         self.lora_scale = lora_scale
+
+         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+         self.hidden_size = hidden_size
+         self.cross_attention_dim = cross_attention_dim
+         self.scale = scale
+         self.num_tokens = num_tokens
+
+         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+         self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+     def __call__(
+         self,
+         attn,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         temb=None,
+     ):
+         residual = hidden_states
+
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         else:
+             # get encoder_hidden_states, ip_hidden_states
+             end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+             encoder_hidden_states, ip_hidden_states = (
+                 encoder_hidden_states[:, :end_pos, :],
+                 encoder_hidden_states[:, end_pos:, :],
+             )
+             if attn.norm_cross:
+                 encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
+
+         query = attn.head_to_batch_dim(query)
+         key = attn.head_to_batch_dim(key)
+         value = attn.head_to_batch_dim(value)
+
+         attention_probs = attn.get_attention_scores(query, key, attention_mask)
+         hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+
+         # for ip-adapter
+         ip_key = self.to_k_ip(ip_hidden_states)
+         ip_value = self.to_v_ip(ip_hidden_states)
+
+         ip_key = attn.head_to_batch_dim(ip_key)
+         ip_value = attn.head_to_batch_dim(ip_value)
+
+         ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+         self.attn_map = ip_attention_probs
+         ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+         ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+
+         hidden_states = hidden_states + self.scale * ip_hidden_states
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
+
+
+ class LoRAAttnProcessor2_0(nn.Module):
+
+     r"""
+     Default processor for performing attention-related computations.
+     """
+
+     def __init__(
+         self,
+         hidden_size=None,
+         cross_attention_dim=None,
+         rank=4,
+         network_alpha=None,
+         lora_scale=1.0,
+     ):
+         super().__init__()
+
+         self.rank = rank
+         self.lora_scale = lora_scale
+
+         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+     def __call__(
+         self,
+         attn,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         temb=None,
+     ):
+         residual = hidden_states
+
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+         # TODO: add support for attn.scale when we move to Torch 2.1
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
+
+
+ class LoRAIPAttnProcessor2_0(nn.Module):
+     r"""
+     Processor for implementing the LoRA attention mechanism together with IP-Adapter image-prompt conditioning (PyTorch 2.0 scaled-dot-product attention).
+
+     Args:
+         hidden_size (`int`, *optional*):
+             The hidden size of the attention layer.
+         cross_attention_dim (`int`, *optional*):
+             The number of channels in the `encoder_hidden_states`.
+         rank (`int`, defaults to 4):
+             The dimension of the LoRA update matrices.
+         network_alpha (`int`, *optional*):
+             Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
+     """
+
+     def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
+         super().__init__()
+
+         self.rank = rank
+         self.lora_scale = lora_scale
+         self.num_tokens = num_tokens
+
+         self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+         self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+
+         self.hidden_size = hidden_size
+         self.cross_attention_dim = cross_attention_dim
+         self.scale = scale
+
+         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+         self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+     def __call__(
+         self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
+     ):
+         residual = hidden_states
+
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
+         #query = attn.head_to_batch_dim(query)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         else:
+             # get encoder_hidden_states, ip_hidden_states
+             end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+             encoder_hidden_states, ip_hidden_states = (
+                 encoder_hidden_states[:, :end_pos, :],
+                 encoder_hidden_states[:, end_pos:, :],
+             )
+             if attn.norm_cross:
+                 encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         # for text
+         key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+         # TODO: add support for attn.scale when we move to Torch 2.1
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
+         # for ip
+         ip_key = self.to_k_ip(ip_hidden_states)
+         ip_value = self.to_v_ip(ip_hidden_states)
+
+         ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+         # TODO: add support for attn.scale when we move to Torch 2.1
+         ip_hidden_states = F.scaled_dot_product_attention(
+             query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+         )
+
+
+         ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+         hidden_states = hidden_states + self.scale * ip_hidden_states
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
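
For review context, a minimal sketch of how these processors might be attached to a diffusers UNet. The hidden-size lookup follows the usual IP-Adapter wiring; the model id, rank=128, and num_tokens=4 are illustrative assumptions, not part of this commit:

import torch
from diffusers import UNet2DConditionModel

from attention_processor_faceid import LoRAAttnProcessor2_0, LoRAIPAttnProcessor2_0

# Assumption: any Stable Diffusion 1.5 style UNet; replace with the model actually used.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1.* are self-attention layers (no cross-attention dim), attn2.* are cross-attention layers
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    if cross_attention_dim is None:
        # plain LoRA processor on self-attention
        attn_procs[name] = LoRAAttnProcessor2_0(hidden_size=hidden_size, rank=128)
    else:
        # LoRA plus image-prompt tokens on cross-attention (rank and num_tokens are illustrative)
        attn_procs[name] = LoRAIPAttnProcessor2_0(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=128, num_tokens=4
        )

unet.set_attn_processor(attn_procs)

The to_k_ip/to_v_ip and LoRA weights would then be loaded from the FaceID checkpoint, and the num_tokens image-prompt embeddings are expected to be concatenated to the end of encoder_hidden_states, which is what the end_pos split in the processors assumes.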
utils.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from PIL import Image
+
+ attn_maps = {}
+ def hook_fn(name):
+     def forward_hook(module, input, output):
+         if hasattr(module.processor, "attn_map"):
+             attn_maps[name] = module.processor.attn_map
+             del module.processor.attn_map
+
+     return forward_hook
+
+ def register_cross_attention_hook(unet):
+     for name, module in unet.named_modules():
+         if name.split('.')[-1].startswith('attn2'):
+             module.register_forward_hook(hook_fn(name))
+
+     return unet
+
+ def upscale(attn_map, target_size):
+     attn_map = torch.mean(attn_map, dim=0)
+     attn_map = attn_map.permute(1,0)
+     temp_size = None
+
+     for i in range(0,5):
+         scale = 2 ** i
+         if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64:
+             temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))
+             break
+
+     assert temp_size is not None, "temp_size cannot be None"
+
+     attn_map = attn_map.view(attn_map.shape[0], *temp_size)
+
+     attn_map = F.interpolate(
+         attn_map.unsqueeze(0).to(dtype=torch.float32),
+         size=target_size,
+         mode='bilinear',
+         align_corners=False
+     )[0]
+
+     attn_map = torch.softmax(attn_map, dim=0)
+     return attn_map
+ def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
+
+     idx = 0 if instance_or_negative else 1
+     net_attn_maps = []
+
+     for name, attn_map in attn_maps.items():
+         attn_map = attn_map.cpu() if detach else attn_map
+         attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
+         attn_map = upscale(attn_map, image_size)
+         net_attn_maps.append(attn_map)
+
+     net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
+
+     return net_attn_maps
+
+ def attnmaps2images(net_attn_maps):
+
+     #total_attn_scores = 0
+     images = []
+
+     for attn_map in net_attn_maps:
+         attn_map = attn_map.cpu().numpy()
+         #total_attn_scores += attn_map.mean().item()
+
+         normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
+         normalized_attn_map = normalized_attn_map.astype(np.uint8)
+         #print("norm: ", normalized_attn_map.shape)
+         image = Image.fromarray(normalized_attn_map)
+
+         #image = fix_save_attn_map(attn_map)
+         images.append(image)
+
+     #print(total_attn_scores)
+     return images
+ def is_torch2_available():
+     return hasattr(F, "scaled_dot_product_attention")
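
Likewise, a minimal sketch of how the utils.py helpers might be used around a pipeline call. Only LoRAIPAttnProcessor (the non-SDPA variant) stores self.attn_map, so the hooks capture maps only when that processor is installed on the attn2 layers; the pipeline, prompt, and file names below are placeholders, not part of this commit:

import torch
from diffusers import StableDiffusionPipeline

from utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images

# Assumption: an SD 1.5 pipeline whose cross-attention layers use LoRAIPAttnProcessor.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.unet = register_cross_attention_hook(pipe.unet)

image = pipe("a photo of a person", num_inference_steps=30).images[0]

# Average the captured maps over all hooked layers, upscale them to the output resolution,
# and turn each image-prompt token's map into a grayscale PIL image.
# batch_size=2 matches the cond/uncond pair used with classifier-free guidance.
attn_map = get_net_attn_map(image.size, batch_size=2)
for i, m in enumerate(attnmaps2images(attn_map)):
    m.save(f"attn_map_token_{i}.png")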