phoebeklett
commited on
Commit
•
1748dc3
1
Parent(s):
c04fa98
Upload 2 files
Browse files- modeling.py +2 -1
modeling.py
CHANGED
@@ -356,7 +356,7 @@ class ExtendedMptAttention(nn.Module):
|
|
356 |
)
|
357 |
attn_output = self.out_proj(context_states)
|
358 |
|
359 |
-
if not output_retrieved_memory_idx:
|
360 |
reshaped_idx = None
|
361 |
|
362 |
return attn_output, attn_weights, past_key_value, reshaped_idx
|
@@ -977,6 +977,7 @@ class ExtendedMptForCausalLM(MptPreTrainedModel):
|
|
977 |
"attention_mask": attention_mask,
|
978 |
"use_external_mind": kwargs.get("use_external_mind"), # EM: Add config here
|
979 |
"topk": kwargs.get("topk"),
|
|
|
980 |
}
|
981 |
)
|
982 |
return model_inputs
|
|
|
356 |
)
|
357 |
attn_output = self.out_proj(context_states)
|
358 |
|
359 |
+
if not output_retrieved_memory_idx or (long_range_past_key_value is None and faiss_indexes is None):
|
360 |
reshaped_idx = None
|
361 |
|
362 |
return attn_output, attn_weights, past_key_value, reshaped_idx
|
|
|
977 |
"attention_mask": attention_mask,
|
978 |
"use_external_mind": kwargs.get("use_external_mind"), # EM: Add config here
|
979 |
"topk": kwargs.get("topk"),
|
980 |
+
"output_retrieved_memory_idx": kwargs.get("output_retrieved_memory_idx"),
|
981 |
}
|
982 |
)
|
983 |
return model_inputs
|