ydshieh commited on
Commit
bfef308
1 Parent(s): 0817d00

fix FlaxViTGPT2LMModule return value

Browse files
vit_gpt2/modeling_flax_vit_gpt2_lm.py CHANGED
@@ -78,8 +78,8 @@ class FlaxViTGPT2LMModule(nn.Module):
78
 
79
  return FlaxSeq2SeqLMOutput(
80
  logits=decoder_outputs.logits,
81
- decoder_hidden_states=decoder_outputs.decoder_hidden_states,
82
- decoder_attentions=decoder_outputs.decoder_attentions,
83
  cross_attentions=decoder_outputs.cross_attentions,
84
  encoder_last_hidden_state=encoder_outputs.last_hidden_state,
85
  encoder_hidden_states=encoder_outputs.hidden_states,
 
78
 
79
  return FlaxSeq2SeqLMOutput(
80
  logits=decoder_outputs.logits,
81
+ decoder_hidden_states=decoder_outputs.hidden_states,
82
+ decoder_attentions=decoder_outputs.attentions,
83
  cross_attentions=decoder_outputs.cross_attentions,
84
  encoder_last_hidden_state=encoder_outputs.last_hidden_state,
85
  encoder_hidden_states=encoder_outputs.hidden_states,