SunderAli17 committed on
Commit
4212736
1 Parent(s): e55e171

Create SAK/models/ipa_faceid_plus/attention_processor.py

SAK/models/ipa_faceid_plus/attention_processor.py ADDED
@@ -0,0 +1,215 @@
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class AttnProcessor2_0(torch.nn.Module):
+     r"""
+     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+     """
+     def __init__(
+         self,
+         hidden_size=None,
+         cross_attention_dim=None,
+     ):
+         super().__init__()
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+     def __call__(
+         self,
+         attn,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         temb=None,
+     ):
+         residual = hidden_states
+
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+
+         if attention_mask is not None:
+             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+             # scaled_dot_product_attention expects attention_mask shape to be
+             # (batch, heads, source_length, target_length)
+             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+         # TODO: add support for attn.scale when we move to Torch 2.1
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
+
+ class IPAttnProcessor2_0(torch.nn.Module):
+     r"""
+     Attention processor for IP-Adapter for PyTorch 2.0.
+     Args:
+         hidden_size (`int`):
+             The hidden size of the attention layer.
+         cross_attention_dim (`int`):
+             The number of channels in the `encoder_hidden_states`.
+         scale (`float`, defaults to 1.0):
+             The weight scale of the image prompt.
+         num_tokens (`int`, defaults to 4; should be 16 for IP-Adapter Plus):
+             The context length of the image features.
+     """
+
+     def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+         super().__init__()
+
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError("IPAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+         self.hidden_size = hidden_size
+         self.cross_attention_dim = cross_attention_dim
+         self.scale = scale
+         self.num_tokens = num_tokens
+
+         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+         self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+     def __call__(
+         self,
+         attn,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         temb=None,
+     ):
+         residual = hidden_states
+
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+
+         if attention_mask is not None:
+             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+             # scaled_dot_product_attention expects attention_mask shape to be
+             # (batch, heads, source_length, target_length)
+             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         else:
+             # get encoder_hidden_states, ip_hidden_states
+             end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+             encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
+             if attn.norm_cross:
+                 encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+         # TODO: add support for attn.scale when we move to Torch 2.1
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
+         # for ip-adapter
+         ip_key = self.to_k_ip(ip_hidden_states)
+         ip_value = self.to_v_ip(ip_hidden_states)
+
+         ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+         # TODO: add support for attn.scale when we move to Torch 2.1
+         ip_hidden_states = F.scaled_dot_product_attention(
+             query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+         )
+
+         ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+         hidden_states = hidden_states + self.scale * ip_hidden_states
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
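
For context, the sketch below shows one way these processors are typically wired into a diffusers UNet2DConditionModel: cross-attention layers get IPAttnProcessor2_0 (so the trailing image-prompt tokens are attended to separately and merged with weight `scale`), while self-attention layers keep the plain AttnProcessor2_0. This is a hypothetical usage sketch, not part of this commit; the checkpoint ID and the num_tokens value are placeholder assumptions and should match the image-projection model actually used.

# Hypothetical wiring sketch (not part of this file).
from diffusers import UNet2DConditionModel
from SAK.models.ipa_faceid_plus.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")  # placeholder checkpoint

attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no cross-attention dim); attn2 is cross-attention.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    if cross_attention_dim is None:
        attn_procs[name] = AttnProcessor2_0()
    else:
        # num_tokens must match the number of image-prompt tokens appended to the
        # text embeddings (4 here as a placeholder; the docstring notes 16 for the Plus variant).
        attn_procs[name] = IPAttnProcessor2_0(
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
            scale=1.0,
            num_tokens=4,
        )
unet.set_attn_processor(attn_procs)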