ccdv commited on
Commit
416eafe
1 Parent(s): 34f38d8

fix encoder decoder

Browse files
Files changed (1) hide show
  1. modeling_lsg_bart.py +1 -1
modeling_lsg_bart.py CHANGED
@@ -998,7 +998,7 @@ class LSGBartModel(LSGBartPretrainedModel, BartModel):
998
  )
999
 
1000
  # Pad mask for global tokens
1001
- if self.pass_global_tokens_to_decoder:
1002
  attention_mask = torch.nn.functional.pad(attention_mask, pad=(self.num_global_tokens, 0), value=1)
1003
 
1004
  # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
998
  )
999
 
1000
  # Pad mask for global tokens
1001
+ if self.pass_global_tokens_to_decoder and attention_mask is not None:
1002
  attention_mask = torch.nn.functional.pad(attention_mask, pad=(self.num_global_tokens, 0), value=1)
1003
 
1004
  # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)