phoebeklett
commited on
Commit
•
c3edc15
1
Parent(s):
1df9b46
Upload 2 files
Browse files- modeling.py +2 -1
modeling.py
CHANGED
@@ -654,7 +654,7 @@ class ExtendedLlamaAttention(nn.Module):
|
|
654 |
if not output_attentions:
|
655 |
attn_weights = None
|
656 |
|
657 |
-
if not output_retrieved_memory_idx:
|
658 |
reshaped_idx = None
|
659 |
return attn_output, attn_weights, past_key_value, reshaped_idx
|
660 |
|
@@ -1568,6 +1568,7 @@ class ExtendedLlamaForCausalLM(LlamaPreTrainedModel):
|
|
1568 |
"attention_mask": attention_mask,
|
1569 |
"use_external_mind": kwargs.get("use_external_mind"), # EM: Add config here
|
1570 |
"topk": kwargs.get("topk"),
|
|
|
1571 |
}
|
1572 |
)
|
1573 |
return model_inputs
|
|
|
654 |
if not output_attentions:
|
655 |
attn_weights = None
|
656 |
|
657 |
+
if not output_retrieved_memory_idx or (long_range_past_key_value is None and faiss_indexes is None):
|
658 |
reshaped_idx = None
|
659 |
return attn_output, attn_weights, past_key_value, reshaped_idx
|
660 |
|
|
|
1568 |
"attention_mask": attention_mask,
|
1569 |
"use_external_mind": kwargs.get("use_external_mind"), # EM: Add config here
|
1570 |
"topk": kwargs.get("topk"),
|
1571 |
+
"output_retrieved_memory_idx": kwargs.get("output_retrieved_memory_idx"),
|
1572 |
}
|
1573 |
)
|
1574 |
return model_inputs
|