ccdv commited on
Commit
6990f99
1 Parent(s): 3aadd6d

fix encoder decoder

Browse files
Files changed (1) hide show
  1. modeling_lsg_pegasus.py +1 -1
modeling_lsg_pegasus.py CHANGED
@@ -1033,7 +1033,7 @@ class LSGPegasusModel(LSGPegasusPreTrainedModel, PegasusModel):
1033
  )
1034
 
1035
  # Pad mask if we keep globals
1036
- if self.pass_global_tokens_to_decoder:
1037
  attention_mask = torch.nn.functional.pad(attention_mask, pad=(self.num_global_tokens, 0), value=1)
1038
 
1039
  # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
 
1033
  )
1034
 
1035
  # Pad mask if we keep globals
1036
+ if self.pass_global_tokens_to_decoder and attention_mask is not None:
1037
  attention_mask = torch.nn.functional.pad(attention_mask, pad=(self.num_global_tokens, 0), value=1)
1038
 
1039
  # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)