ydshieh
committed on
Commit
•
58a121d
1
Parent(s):
3ed2a5d
change to add_cross_attention
Browse files
vit_gpt2/modeling_flax_gpt2.py
CHANGED
@@ -299,7 +299,7 @@ class FlaxGPT2Block(nn.Module):
|
|
299 |
|
300 |
def setup(self):
|
301 |
|
302 |
-
self.only_self_attn = not self.config.
|
303 |
|
304 |
hidden_size = self.config.hidden_size
|
305 |
inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size
|
@@ -412,7 +412,7 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
|
|
412 |
params_rng, dropout_rng = jax.random.split(rng)
|
413 |
rngs = {"params": params_rng, "dropout": dropout_rng}
|
414 |
|
415 |
-
if self.config.
|
416 |
encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
|
417 |
encoder_attention_mask = attention_mask
|
418 |
module_init_outputs = self.module.init(
|
|
|
299 |
|
300 |
def setup(self):
|
301 |
|
302 |
+
self.only_self_attn = not self.config.add_cross_attention
|
303 |
|
304 |
hidden_size = self.config.hidden_size
|
305 |
inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size
|
|
|
412 |
params_rng, dropout_rng = jax.random.split(rng)
|
413 |
rngs = {"params": params_rng, "dropout": dropout_rng}
|
414 |
|
415 |
+
if self.config.add_cross_attention:
|
416 |
encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
|
417 |
encoder_attention_mask = attention_mask
|
418 |
module_init_outputs = self.module.init(
|