ydshieh committed
Commit a01b02a • 1 Parent(s): f082d66
remove only_self_attn
vit_gpt2/modeling_flax_gpt2.py
CHANGED
@@ -299,15 +299,13 @@ class FlaxGPT2Block(nn.Module):
 
     def setup(self):
 
-        self.only_self_attn = not self.config.add_cross_attention
-
         hidden_size = self.config.hidden_size
         inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size
 
         self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
         self.attn = FlaxGPT2Attention(self.config, dtype=self.dtype)
 
-        if not self.only_self_attn:
+        if self.config.add_cross_attention:
             self.cross_attn_ln = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
             # [IMPORTANT] Cross attention requires ``causal=False``! This is a bug I made previously.
             self.cross_attn = FlaxGPT2Attention(config=self.config, dtype=self.dtype, causal=False, self_attn=False)
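The ``[IMPORTANT]`` note above is the substance of the earlier bug: a cross-attention module left with ``causal=True`` lets decoder position i attend to only the first i+1 encoder positions. A minimal, self-contained JAX sketch (illustrative only, not code from this file) of the difference:

import jax.numpy as jnp
from jax import nn, random

# Toy cross-attention shapes: 3 decoder positions attending over 5 encoder positions.
tgt_len, src_len, dim = 3, 5, 8
q = random.normal(random.PRNGKey(0), (tgt_len, dim))   # decoder queries
k = random.normal(random.PRNGKey(1), (src_len, dim))   # encoder keys

scores = q @ k.T / jnp.sqrt(dim)                        # (3, 5) attention scores

# What a causal (self-attention style) mask would do to those scores:
causal_mask = jnp.tril(jnp.ones((tgt_len, src_len), dtype=bool))
masked = jnp.where(causal_mask, scores, jnp.finfo(scores.dtype).min)

print(nn.softmax(masked, axis=-1)[0])  # row 0 can only see encoder token 0 -- wrong
print(nn.softmax(scores, axis=-1)[0])  # causal=False: row 0 sees all 5 encoder tokens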
@@ -343,16 +341,17 @@ class FlaxGPT2Block(nn.Module):
         attn_output = outputs[0]
         hidden_states = attn_output + residual
 
-        # sanity check
-        if not self.only_self_attn:
-            assert encoder_hidden_states is not None
-        else:
-            assert encoder_hidden_states is None
-
         # Cross-Attention Block
         cross_attn_weights = None
         if encoder_hidden_states is not None:
 
+            # add one self-attention block for cross-attention
+            if not hasattr(self, "cross_attn"):
+                raise ValueError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+                    "cross-attention layers by setting `config.add_cross_attention=True`"
+                )
+
             project_encoder = getattr(self.config, "project_encoder", None)
             if project_encoder:
                 residual = encoder_hidden_states
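With ``only_self_attn`` gone, the block no longer tracks a separate flag; it simply checks whether the optional sub-module was created in ``setup()``. A hypothetical, stripped-down stand-in (not the repository class) showing how the guard behaves:

# Hypothetical stand-in for illustration only; names and behavior are simplified.
class ToyBlock:
    def __init__(self, add_cross_attention: bool):
        if add_cross_attention:
            self.cross_attn = object()  # placeholder for the real FlaxGPT2Attention

    def __call__(self, encoder_hidden_states=None):
        if encoder_hidden_states is not None:
            if not hasattr(self, "cross_attn"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            return "cross-attention path"
        return "self-attention only"

print(ToyBlock(add_cross_attention=True)(encoder_hidden_states=[1.0]))  # cross-attention path
print(ToyBlock(add_cross_attention=False)())                            # self-attention only
# ToyBlock(add_cross_attention=False)(encoder_hidden_states=[1.0])      # raises ValueError

Compared with the removed asserts, a block built with ``add_cross_attention=True`` can now also be called without ``encoder_hidden_states``, and a misconfigured call gets an explicit ``ValueError`` instead of an ``AssertionError``.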
@@ -393,7 +392,7 @@ class FlaxGPT2Block(nn.Module):
         if output_attentions:
             self_attn_weights = attn_output[1]
             outputs += (self_attn_weights,)
-            if not self.only_self_attn:
+            if cross_attn_weights is not None:
                 outputs += (cross_attn_weights,)
 
         return outputs
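The last hunk mirrors that change in the returned tuple: ``cross_attn_weights`` is appended only when cross-attention actually ran, not whenever the block was merely capable of it. A tiny sketch of the resulting tuple shape (the leading element of ``outputs`` is not shown in the diff and is assumed here):

# Assumed minimal reproduction of the tuple-building logic; `outputs` normally
# starts with the block's hidden states, which the hunk above does not show.
outputs = ("hidden_states",)
output_attentions = True
self_attn_weights = "self_attn_weights"
cross_attn_weights = None  # e.g. no encoder_hidden_states were passed

if output_attentions:
    outputs += (self_attn_weights,)
    if cross_attn_weights is not None:  # new guard: nothing is appended when cross-attention was skipped
        outputs += (cross_attn_weights,)

print(outputs)  # ('hidden_states', 'self_attn_weights') -- no trailing None entry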