Update modeling_llama.py
modeling_llama.py (+8 -8)
```diff
@@ -646,14 +646,14 @@ class LlamaModel(LlamaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
         # embed positions
-
-
-
-
-
-
-
-        attention_mask = None
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
+        # attention_mask = None


         hidden_states = inputs_embeds
```
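This commit re-enables the attention-mask construction that had been stubbed out with a hard-coded `attention_mask = None`, so padding and causal masking are applied to the decoder again. For context, here is a minimal standalone sketch of what a `_prepare_decoder_attention_mask`-style helper produces in the transformers LLaMA code of this era: a `[batch, 1, tgt_len, src_len]` additive float mask combining a causal mask with the user-supplied 2D padding mask. The function name and exact details below are an approximation for illustration, not the library's implementation:

```python
import torch

def make_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length):
    # Hypothetical sketch: build the additive mask that gets added to the
    # attention scores before softmax. Masked positions get the dtype's
    # minimum value so they vanish after softmax.
    batch_size, tgt_len = input_shape
    dtype = inputs_embeds.dtype
    device = inputs_embeds.device
    src_len = tgt_len + past_key_values_length
    min_value = torch.finfo(dtype).min

    # Causal part: a query at position i (offset by the KV-cache length) may
    # only attend to keys at absolute positions <= i + past_key_values_length.
    causal = torch.full((tgt_len, src_len), min_value, dtype=dtype, device=device)
    causal = causal.triu(diagonal=1 + past_key_values_length)
    mask = causal[None, None, :, :].expand(batch_size, 1, tgt_len, src_len).clone()

    # Padding part: key positions where attention_mask == 0 are masked out
    # for every query (masked_fill avoids adding min_value twice).
    if attention_mask is not None:
        keep = attention_mask[:, None, None, :].to(torch.bool)
        mask = mask.masked_fill(~keep, min_value)
    return mask

# Example: batch of 2, 4 new tokens, no KV cache.
emb = torch.zeros(2, 4, 8)
pad_mask = torch.ones(2, 4, dtype=torch.bool)
m = make_decoder_attention_mask(pad_mask, (2, 4), emb, past_key_values_length=0)
print(m.shape)  # torch.Size([2, 1, 4, 4])
```

Returning an additive mask of shape `(batch, 1, tgt_len, src_len)` lets it broadcast across attention heads and be summed directly onto the score matrix, which is why skipping this step (the old `attention_mask = None`) silently disables both causal and padding masking.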