ydshieh committed on
Commit 64afcd5
1 Parent(s): ec3ceb6
Files changed (1)
  1. vit_gpt2/modeling_flax_gpt2.py  +27 -11
vit_gpt2/modeling_flax_gpt2.py CHANGED
@@ -24,7 +24,10 @@ from flax.linen.attention import dot_product_attention_weights
 from jax import lax

 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_flax_outputs import FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxCausalLMOutputWithCrossAttentions
+from ...modeling_flax_outputs import (
+    FlaxBaseModelOutputWithPastAndCrossAttentions,
+    FlaxCausalLMOutputWithCrossAttentions,
+)
 from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
 from ...utils import logging
 from .configuration_gpt2 import GPT2Config
@@ -301,7 +304,9 @@ class FlaxGPT2Block(nn.Module):
         self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

         if self.config.add_cross_attention:
-            self.crossattention = FlaxGPT2Attention(config=self.config, dtype=self.dtype, causal=False, is_cross_attention=True)
+            self.crossattention = FlaxGPT2Attention(
+                config=self.config, dtype=self.dtype, causal=False, is_cross_attention=True
+            )
             self.ln_cross_attn = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

         project_encoder = getattr(self.config, "project_encoder", None)
@@ -337,7 +342,6 @@ class FlaxGPT2Block(nn.Module):
         hidden_states = attn_output + residual

         # Cross-Attention Block
-        cross_attn_weights = None
         if encoder_hidden_states is not None:
             # add one self-attention block for cross-attention
             if not hasattr(self, "crossattention"):
@@ -413,13 +417,16 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
             encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
             encoder_attention_mask = attention_mask
             module_init_outputs = self.module.init(
-                rngs, input_ids, attention_mask, position_ids,
-                encoder_hidden_states, encoder_attention_mask, return_dict=False
+                rngs,
+                input_ids,
+                attention_mask,
+                position_ids,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                return_dict=False,
             )
         else:
-            module_init_outputs = self.module.init(
-                rngs, input_ids, attention_mask, position_ids, return_dict=False
-            )
+            module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)

         return module_init_outputs["params"]

@@ -660,7 +667,11 @@ class FlaxGPT2Model(FlaxGPT2PreTrainedModel):


 append_call_sample_docstring(
-    FlaxGPT2Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPastAndCrossAttentions, _CONFIG_FOR_DOC
+    FlaxGPT2Model,
+    _TOKENIZER_FOR_DOC,
+    _CHECKPOINT_FOR_DOC,
+    FlaxBaseModelOutputWithPastAndCrossAttentions,
+    _CONFIG_FOR_DOC,
 )


@@ -718,9 +729,10 @@ class FlaxGPT2LMHeadModule(nn.Module):
             logits=lm_logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-            cross_attentions=outputs.cross_attentions
+            cross_attentions=outputs.cross_attentions,
         )

+
 @add_start_docstrings(
     """
     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
@@ -759,5 +771,9 @@ class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):


 append_call_sample_docstring(
-    FlaxGPT2LMHeadModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutputWithCrossAttentions, _CONFIG_FOR_DOC
+    FlaxGPT2LMHeadModel,
+    _TOKENIZER_FOR_DOC,
+    _CHECKPOINT_FOR_DOC,
+    FlaxCausalLMOutputWithCrossAttentions,
+    _CONFIG_FOR_DOC,
 )
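
For context, a minimal sketch of how the cross-attention path touched by this commit might be exercised. This is an illustration only: it assumes the public transformers FlaxGPT2LMHeadModel behaves like this repo's vit_gpt2 copy (i.e. that add_cross_attention=True builds the FlaxGPT2Attention(..., is_cross_attention=True) layer and that the call accepts encoder_hidden_states), and the tiny config values are arbitrary.

    import jax.numpy as jnp
    from transformers import GPT2Config, FlaxGPT2LMHeadModel  # assumed stand-in for vit_gpt2/modeling_flax_gpt2.py

    # add_cross_attention=True triggers the cross-attention setup in FlaxGPT2Block
    # and the encoder_hidden_states branch of module.init() shown in the diff above.
    config = GPT2Config(n_layer=2, n_head=2, n_embd=64, add_cross_attention=True)
    model = FlaxGPT2LMHeadModel(config)

    input_ids = jnp.ones((1, 8), dtype="i4")
    # Encoder states feed the cross-attention block; the last dimension matches config.n_embd.
    encoder_hidden_states = jnp.zeros((1, 8, config.n_embd))

    outputs = model(input_ids, encoder_hidden_states=encoder_hidden_states)
    print(outputs.logits.shape)  # (1, 8, config.vocab_size)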