ydshieh committed
Commit: 165ad1e
1 Parent(s): a01b02a

Change Flax GPT2 with cross-attn outputs to be the same as PyTorch's version

Files changed (1):
  vit_gpt2/modeling_flax_gpt2.py  (+24 -40)
vit_gpt2/modeling_flax_gpt2.py CHANGED

@@ -593,28 +593,21 @@ class FlaxGPT2BlockCollection(nn.Module):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

+        # In Flax, `past_key_values` is not contained in modules' outputs.
         outputs = [hidden_states, all_hidden_states, all_attentions, all_cross_attentions]

         if not return_dict:
             return tuple(v for v in outputs if v is not None)

-        if encoder_hidden_states is None:
-            # only self_attn
-            return FlaxBaseModelOutputWithPast(
-                last_hidden_state=hidden_states,
-                past_key_values=None,
-                hidden_states=all_hidden_states,
-                attentions=all_attentions,
-            )
-        else:
-            # with cross_attn
-            return FlaxBaseModelOutputWithPastAndCrossAttentions(
-                last_hidden_state=hidden_states,
-                past_key_values=None,
-                hidden_states=all_hidden_states,
-                attentions=all_attentions,
-                cross_attentions=all_cross_attentions,
-            )
+        # with cross_attn
+        return FlaxBaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=None,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+

 class FlaxGPT2Module(nn.Module):
     config: GPT2Config
@@ -676,19 +669,13 @@ class FlaxGPT2Module(nn.Module):
         if not return_dict:
             return (hidden_states,) + outputs[1:]

-        if encoder_hidden_states is None:
-            return FlaxBaseModelOutput(
-                last_hidden_state=hidden_states,
-                hidden_states=outputs.hidden_states,
-                attentions=outputs.attentions,
-            )
-        else:
-            return FlaxBaseModelOutputWithPastAndCrossAttentions(
-                last_hidden_state=hidden_states,
-                hidden_states=outputs.hidden_states,
-                attentions=outputs.attentions,
-                cross_attentions=outputs.cross_attentions,
-            )
+        return FlaxBaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+

 @add_start_docstrings(
     "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
@@ -753,16 +740,13 @@ class FlaxGPT2LMHeadModule(nn.Module):
         if not return_dict:
             return (lm_logits,) + outputs[1:]

-        if encoder_hidden_states is None:
-            return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
-        else:
-            return FlaxCausalLMOutputWithCrossAttentions(
-                logits=lm_logits,
-                past_key_values=None,
-                hidden_states=outputs.hidden_states,
-                attentions=outputs.attentions,
-                cross_attentions=outputs.cross_attentions
-            )
+        return FlaxCausalLMOutputWithCrossAttentions(
+            logits=lm_logits,
+            past_key_values=None,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions
+        )

 @add_start_docstrings(
     """