Raghavan committed on
Commit
121271e
1 Parent(s): f6f7bf9

Upload 7 files

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +2 -2
modeling_indictrans.py CHANGED
@@ -689,7 +689,7 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
689
  if self.layernorm_embedding is not None:
690
  x = self.layernorm_embedding(hidden_states)
691
  hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
692
-
693
  # expand attention_mask
694
  if attention_mask is not None:
695
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -754,7 +754,7 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
754
  if output_hidden_states:
755
  encoder_states = encoder_states + (hidden_states,)
756
 
757
- hidden_states = self.get_pooled_representation(hidden_states, attention_mask)
758
 
759
  if not return_dict:
760
  return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
 
689
  if self.layernorm_embedding is not None:
690
  x = self.layernorm_embedding(hidden_states)
691
  hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
692
+ original_attention_mask = attention_mask.clone()
693
  # expand attention_mask
694
  if attention_mask is not None:
695
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
 
754
  if output_hidden_states:
755
  encoder_states = encoder_states + (hidden_states,)
756
 
757
+ hidden_states = self.get_pooled_representation(hidden_states, original_attention_mask)
758
 
759
  if not return_dict:
760
  return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)