TCMVince commited on
Commit
2cfe05f
1 Parent(s): 9c29a9d

Update flaubert2_model.py

Browse files
Files changed (1) hide show
  1. flaubert2_model.py +5 -4
flaubert2_model.py CHANGED
@@ -388,11 +388,12 @@ class Flaubert2Model(RobertaModel):
388
 
389
  sequence_output = encoder_outputs[0].transpose(0,1)
390
 
391
- # Fairseq Linformer implementation works with transposed hidden states -> we transpose them back for HF implementation.
392
- hidden_states = [h.transpose(0,1) for h in encoder_outputs.hidden_states]
393
-
394
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
395
 
 
 
 
 
396
  if not return_dict:
397
  return (sequence_output, pooled_output) + encoder_outputs[1:]
398
 
@@ -400,7 +401,7 @@ class Flaubert2Model(RobertaModel):
400
  last_hidden_state=sequence_output,
401
  pooler_output=pooled_output,
402
  past_key_values=encoder_outputs.past_key_values,
403
- hidden_states=hidden_states,
404
  attentions=encoder_outputs.attentions,
405
  cross_attentions=encoder_outputs.cross_attentions,
406
  )
 
388
 
389
  sequence_output = encoder_outputs[0].transpose(0,1)
390
 
 
 
 
391
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
392
 
393
+ # Fairseq Linformer implementation works with transposed hidden states -> we transpose them back for HF implementation.
394
+ if output_hidden_states:
395
+ encoder_outputs.hidden_states = [h.transpose(0,1) for h in encoder_outputs.hidden_states]
396
+
397
  if not return_dict:
398
  return (sequence_output, pooled_output) + encoder_outputs[1:]
399
 
 
401
  last_hidden_state=sequence_output,
402
  pooler_output=pooled_output,
403
  past_key_values=encoder_outputs.past_key_values,
404
+ hidden_states=encoder_outputs.hidden_states,
405
  attentions=encoder_outputs.attentions,
406
  cross_attentions=encoder_outputs.cross_attentions,
407
  )