fix encoder decoder
Browse files- modeling_lsg_bart.py +1 -1
modeling_lsg_bart.py
CHANGED
@@ -998,7 +998,7 @@ class LSGBartModel(LSGBartPretrainedModel, BartModel):
|
|
998 |
)
|
999 |
|
1000 |
# Pad mask for global tokens
|
1001 |
-
if self.pass_global_tokens_to_decoder:
|
1002 |
attention_mask = torch.nn.functional.pad(attention_mask, pad=(self.num_global_tokens, 0), value=1)
|
1003 |
|
1004 |
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
|
|
|
998 |
)
|
999 |
|
1000 |
# Pad mask for global tokens
|
1001 |
+
if self.pass_global_tokens_to_decoder and attention_mask is not None:
|
1002 |
attention_mask = torch.nn.functional.pad(attention_mask, pad=(self.num_global_tokens, 0), value=1)
|
1003 |
|
1004 |
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
|