# TMCRA-agent-memory-algorithm / code / memory_adapters.py
# Repository page header captured with this file: "2009YU's picture",
# "Add files using upload-large-folder tool", commit 3a64edb (verified).
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
import json
import math
import os
from pathlib import Path
import re
import tempfile
import time
import uuid
from typing import Any, Dict, Iterable, List, Mapping, Sequence
from core.session_memory import SessionMemoryExtractor
from experiments.replacement.memory_graph import (
SessionMemoryGraphV2,
SessionMemoryEdgeV2,
SessionMemoryRecordV2,
SQLiteSessionMemoryStore,
_clean_text,
_estimate_tokens,
_normalize,
_public_query_subject,
_public_slot_root,
_public_subject_signature,
stable_slot_key,
_tokenize,
guess_slot_key,
infer_category_hints,
)
from experiments.replacement.node_memory import (
_call_with_supported_kwargs,
DEFAULT_MATRIX_EVENT_TOP_K,
LoadedNodeMemoryScorer,
MEMORY_ROUTER_LAYERS,
build_default_path_templates,
extract_question_features,
)
from experiments.replacement.public_event_signature import compute_public_event_signature
from experiments.replacement.memory_profiles import TMCRAProfile
from experiments.replacement.typed_tunnel_augmentation import (
annotate_memory_record,
merge_typed_metadata,
typed_edge_tags_between,
typed_tunnel_signature_text,
)
from experiments.replacement.profile_layer import infer_profile_query_intent, is_profile_layer_record, profile_query_score_delta
from experiments.replacement import injection_planner as injection_planner_runtime
from experiments.replacement.temporal_modeling_types import TemporalFrame, TemporalQueryPlan
from experiments.replacement.temporal_organizer import TemporalOrganizer
from experiments.replacement.temporal_query_planner import TemporalQueryPlanner
from experiments.replacement.temporal_router_runtime import LoadedTemporalRouter
from experiments.replacement.timeline_evidence_pack import TimelineEvidencePackBuilder
from experiments.replacement.timeline_state_layer import TimelineStateLayer
from .base import MemoryAdapter, MemoryHit, MemoryRetrieval
def _dedupe(items: Iterable[Any], *, max_items: int | None = None) -> List[str]:
    """Return the cleaned, non-empty texts of *items* with duplicates removed.

    Every item goes through ``_clean_text``; duplicates are detected on the
    ``_normalize``-d form and the first cleaned spelling is the one kept.
    Input order is preserved.  When ``max_items`` is given, collection stops
    as soon as that many unique values have been gathered.
    """
    unique_texts: List[str] = []
    seen_keys = set()
    for raw in items:
        cleaned = _clean_text(raw)
        if cleaned:
            fingerprint = _normalize(cleaned)
            if fingerprint not in seen_keys:
                seen_keys.add(fingerprint)
                unique_texts.append(cleaned)
                # The cap is only checked right after a value was accepted.
                if max_items is not None and len(unique_texts) >= max_items:
                    break
    return unique_texts
def _apply_typed_tunnel_annotations(
    records: List[SessionMemoryRecordV2],
    *,
    source_text: str = "",
) -> List[SessionMemoryRecordV2]:
    """Run ``annotate_memory_record`` over each record and hand back the same list.

    The list object passed in is the one returned; each record is given to
    ``annotate_memory_record`` together with ``source_text``.
    """
    for entry in records:
        annotate_memory_record(entry, source_text=source_text)
    return records
def _estimate_tokens_from_hits(hits: Sequence[MemoryHit]) -> int:
    """Estimate the combined token cost of *hits*.

    For every hit the estimate of its ``value`` and of each of its anchors
    (both via ``_estimate_tokens``) is summed.  A hit whose ``anchors`` is
    ``None`` or empty contributes only its value, which is how the other
    helpers in this module read the attribute (``hit.anchors or []``).
    """
    return sum(
        _estimate_tokens(hit.value)
        + sum(_estimate_tokens(anchor) for anchor in (hit.anchors or []))
        for hit in hits
    )
def _float_env(name: str, default: float) -> float:
    """Read environment variable *name* as a float.

    Falls back to ``float(default)`` when the variable is unset, empty after
    ``_clean_text``, or cannot be parsed by ``float``.
    """
    raw_value = _clean_text(os.getenv(name, ""))
    if raw_value:
        try:
            return float(raw_value)
        except ValueError:
            pass  # Unparsable setting: use the default below.
    return float(default)
# --- Size limits; their consumers live outside this section of the file. ---
_GRAPH_PROMPT_MAX_CHARS = 12_000
_GRAPH_PROMPT_MAX_HITS = 8
_GRAPH_PROMPT_MAX_ACTIVE_SLOTS = 12
_GRAPH_PROMPT_MAX_RELATIONS = 10
_HYBRID_SELECTED_EVENT_FLOOR = 8
_HYBRID_SELECTED_PATH_CAP = 6
_HYBRID_TEMPORAL_PATH_CAP = 2
_HYBRID_PROFILE_PATH_CAP = 3

# --- Memory router: mode vocabularies matched against the `_normalize`-d mode
# string in `_memory_router_decision`; the defaults replace falsy
# threshold/margin arguments there. ---
_MEMORY_ROUTER_GUIDED_MODES = {"guided", "route", "routing", "enforce"}
_MEMORY_ROUTER_FORCE_MODES = {"force", "forced"}
_MEMORY_ROUTER_DISABLED_MODES = {"", "off", "disabled", "none", "false", "0"}
_MEMORY_ROUTER_OBSERVE_MODES = {"observe", "observer", "telemetry", "shadow"}
_MEMORY_ROUTER_DEFAULT_THRESHOLD = 0.55
_MEMORY_ROUTER_DEFAULT_MARGIN = 0.08

# --- Injection planner mode vocabularies (same spellings as the router's). ---
_INJECTION_PLANNER_GUIDED_MODES = {"guided", "route", "routing", "enforce"}
_INJECTION_PLANNER_FORCE_MODES = {"force", "forced"}
_INJECTION_PLANNER_DISABLED_MODES = {"", "off", "disabled", "none", "false", "0"}
_INJECTION_PLANNER_OBSERVE_MODES = {"observe", "observer", "telemetry", "shadow"}

# --- Temporal layer / temporal router settings. ---
# NOTE: unlike the other *_DISABLED_MODES sets, this one does not contain "".
_TEMPORAL_LAYER_DISABLED_MODES = {"off", "disabled", "none", "false", "0"}
_TEMPORAL_ROUTER_DEFAULT_WRITER_MIN_CONFIDENCE = 0.72
_TEMPORAL_ROUTER_DEFAULT_QUERY_MIN_CONFIDENCE = 0.85
_TEMPORAL_ROUTER_DEFAULT_QUERY_INTENT_MIN_CONFIDENCE = 0.60

# --- Write-time embedder index: mode vocabularies, version tags, model cache. ---
_EMBEDDER_INDEX_DISABLED_MODES = {"", "off", "disabled", "none", "false", "0"}
_EMBEDDER_INDEX_BGE_M3_MODES = {"bge", "bge_m3", "bge-m3", "baai_bge_m3", "baai/bge-m3"}
_EMBEDDER_INDEX_VERSION = "write_hash_sparse_v1"
_EMBEDDER_INDEX_BGE_M3_VERSION = "write_bge_m3_dense_sparse_v1"
# Process-wide cache of loaded (tokenizer, model) pairs, filled by
# `_embedder_dense_vectors_for_texts`.
_EMBEDDER_MODEL_CACHE: Dict[str, Any] = {}

# Tokens dropped by `_hybrid_symbolic_tokens` before symbolic matching.
_HYBRID_SYMBOLIC_STOPWORDS = {
    "a", "an", "and", "are", "as", "at",
    "be", "did", "do", "does", "event", "for",
    "from", "had", "has", "have", "how", "in",
    "is", "of", "on", "or", "should", "the",
    "to", "turn", "was", "were", "what", "when",
    "where", "which", "who", "why", "will", "with",
}
def _coerce_memory_router_scores(payload: Mapping[str, Any]) -> Dict[str, float]:
    """Extract one float score per router layer from ``payload["memory_router_scores"]``.

    The result has an entry for every name in ``MEMORY_ROUTER_LAYERS``; a layer
    that is missing, or whose value ``float`` rejects, is scored ``0.0``.
    """
    supplied = dict(payload.get("memory_router_scores", {}) or {})

    def as_score(layer: str) -> float:
        try:
            return float(supplied.get(layer, 0.0))
        except (TypeError, ValueError):
            return 0.0

    return {layer: as_score(layer) for layer in MEMORY_ROUTER_LAYERS}
def _memory_router_decision(
    payload: Mapping[str, Any],
    *,
    mode: str,
    threshold: float,
    margin: float,
) -> Dict[str, Any]:
    """Turn the router scores in *payload* into a routing decision dict.

    Args:
        payload: Mapping that may carry ``"memory_router_scores"``.
        mode: Router mode name; compared (after ``_normalize``) against the
            ``_MEMORY_ROUTER_*_MODES`` vocabularies.
        threshold: Minimum score for a layer to be active.  A falsy value is
            replaced by ``_MEMORY_ROUTER_DEFAULT_THRESHOLD``; the result is
            clamped to ``[0, 1]``.
        margin: Minimum score spread *or* distance-from-0.5 needed to count as
            confident.  A falsy value is replaced by
            ``_MEMORY_ROUTER_DEFAULT_MARGIN``; the result is clamped to
            ``[0, 0.5]``.

    Returns:
        A dict of ``memory_router_*`` telemetry/decision fields.  When the
        payload names none of ``MEMORY_ROUTER_LAYERS`` the router is reported as
        not enabled (reason ``"no_router_scores"``) and the threshold/margin/mode
        keys are absent.
    """
    mode_key = _normalize(mode)
    layer_scores = _coerce_memory_router_scores(payload)
    supplied_scores = dict(payload.get("memory_router_scores", {}) or {})
    if not layer_scores or not any(layer in supplied_scores for layer in MEMORY_ROUTER_LAYERS):
        return {
            "memory_router_enabled": False,
            "memory_router_guided": False,
            "memory_router_reason": "no_router_scores",
            "memory_router_scores": {},
            "memory_router_top_layers": [],
            "memory_router_active_layers": [],
            "memory_router_score_spread": 0.0,
            "memory_router_confidence": 0.0,
        }

    # Highest score first; ties broken alphabetically by layer name.
    ranking = sorted(layer_scores, key=lambda layer: (-float(layer_scores[layer]), layer))
    values = list(layer_scores.values())
    spread = max(values) - min(values)
    certainty = max(abs(float(value) - 0.5) for value in values)
    threshold_used = max(0.0, min(1.0, float(threshold or _MEMORY_ROUTER_DEFAULT_THRESHOLD)))
    margin_used = max(0.0, min(0.5, float(margin or _MEMORY_ROUTER_DEFAULT_MARGIN)))

    active = [layer for layer in ranking if float(layer_scores.get(layer, 0.0)) >= threshold_used]
    # Confidence is judged on the threshold-selected layers, before any
    # forced fallback or "event" injection below.
    is_confident = bool(active) and (spread >= margin_used or certainty >= margin_used)
    wants_guidance = mode_key in _MEMORY_ROUTER_GUIDED_MODES
    is_forced = mode_key in _MEMORY_ROUTER_FORCE_MODES
    is_disabled = mode_key in _MEMORY_ROUTER_DISABLED_MODES
    is_guided = bool(is_forced or (wants_guidance and is_confident))

    if is_forced and not active and ranking:
        # Forced mode always routes somewhere: fall back to the top layer.
        active = ranking[:1]
    if active and "event" not in active:
        # Whenever any layer is active, "event" is put in front of the list.
        active = ["event"] + active

    if is_disabled:
        reason = "disabled"
    elif mode_key in _MEMORY_ROUTER_OBSERVE_MODES:
        reason = "observe"
    elif is_forced:
        reason = "forced"
    elif wants_guidance and not is_confident:
        reason = "low_confidence"
    elif is_guided:
        reason = "guided"
    else:
        reason = "observe"

    return {
        "memory_router_enabled": True,
        "memory_router_guided": bool(is_guided and not is_disabled),
        "memory_router_reason": reason,
        "memory_router_scores": layer_scores,
        "memory_router_top_layers": ranking,
        "memory_router_active_layers": active,
        "memory_router_score_spread": round(float(spread), 6),
        "memory_router_confidence": round(float(certainty), 6),
        "memory_router_threshold": round(float(threshold_used), 6),
        "memory_router_margin": round(float(margin_used), 6),
        "memory_router_mode": mode_key or "observe",
    }
def _memory_router_allows(decision: Mapping[str, Any], *layers: str) -> bool:
    """Tell whether a router *decision* permits work on any of *layers*.

    Everything is allowed when the decision is not guided, when no usable layer
    name is given, or when ``"event"`` is among the requested layers.  Otherwise
    at least one requested layer must appear in
    ``decision["memory_router_active_layers"]`` (names compared after
    ``_normalize``).
    """
    if not decision.get("memory_router_guided"):
        return True
    wanted = set()
    for layer in layers:
        layer_key = _normalize(layer)
        if layer_key:
            wanted.add(layer_key)
    if not wanted or "event" in wanted:
        return True
    permitted = {
        _normalize(layer)
        for layer in list(decision.get("memory_router_active_layers", []) or [])
    }
    return not wanted.isdisjoint(permitted)
def _hybrid_symbolic_tokens(value: Any) -> List[str]:
    """Tokenize *value* with ``_tokenize`` and keep only content-bearing tokens.

    Dropped: empty tokens, members of ``_HYBRID_SYMBOLIC_STOPWORDS`` and tokens
    made up solely of digits.  Order and repeats of the remaining tokens are kept.
    """
    kept: List[str] = []
    for token in _tokenize(value):
        if not token or token in _HYBRID_SYMBOLIC_STOPWORDS:
            continue
        if re.fullmatch(r"\d+", token):
            continue
        kept.append(token)
    return kept
# Stopwords for `_path_utility_tokens`: the symbolic stopwords plus
# conversational filler words.
_PATH_UTILITY_STOPWORDS = set(_HYBRID_SYMBOLIC_STOPWORDS) | {
    "actually", "also", "bit", "feel", "feels", "get", "guess",
    "i", "im", "it", "its", "just", "kind", "like",
    "maybe", "me", "more", "much", "my", "really", "something",
    "still", "think", "thats", "there", "this", "way", "would",
}
def _path_utility_tokens(value: Any) -> List[str]:
    """Return the ``_hybrid_symbolic_tokens`` of *value* minus ``_PATH_UTILITY_STOPWORDS``."""
    survivors: List[str] = []
    for token in _hybrid_symbolic_tokens(value):
        if token and token not in _PATH_UTILITY_STOPWORDS:
            survivors.append(token)
    return survivors
# Groups of related words: `_profile_query_expanded_tokens` adds a whole group
# to a token set as soon as the set contains any one member of it.
_PROFILE_QUERY_ALIAS_GROUPS: tuple[set[str], ...] = (
    # photography gear
    {
        "accessory", "accessories", "bag", "camera", "cameras", "equipment", "flash",
        "gear", "lens", "lenses", "photo", "photography", "sony", "tripod",
    },
    # software / tooling
    {
        "app", "apps", "dashboard", "interface", "layout", "panel",
        "software", "tool", "tools", "ui", "workflow",
    },
    # food and drink
    {"diet", "drink", "food", "meal", "restaurant", "snack", "taste"},
    # occupation / work history
    {
        "background", "career", "job", "occupation", "position",
        "previous", "profession", "role", "worked", "work",
    },
)
# Tokens that `_profile_specific_tokens` discards: every path-utility stopword
# plus generic request/recommendation vocabulary.
_PROFILE_QUERY_GENERIC_TOKENS = set(_PATH_UTILITY_STOPWORDS) | {
    "able", "about", "any", "anything", "based", "best",
    "can", "complement", "could", "current", "give", "help",
    "information", "looking", "make", "need", "please",
    "recommend", "recommendation", "recommendations",
    "should", "some", "suggest", "suggestion", "suggestions",
    "tell", "that", "using", "you",
}
def _profile_query_expanded_tokens(value: Any) -> set[str]:
    """Return the path-utility tokens of *value* widened by matching alias groups.

    Each group in ``_PROFILE_QUERY_ALIAS_GROUPS`` that shares at least one token
    with the original (unexpanded) token set is merged in wholesale.
    """
    base_tokens = set(_path_utility_tokens(value))
    matching_groups = [group for group in _PROFILE_QUERY_ALIAS_GROUPS if base_tokens & group]
    return base_tokens.union(*matching_groups)
def _profile_specific_tokens(tokens: Iterable[Any]) -> set[str]:
    """Normalize *tokens* and keep only the ones specific enough to match on.

    A token is dropped when it is empty after ``_clean_text``, when its
    normalized form is in ``_PROFILE_QUERY_GENERIC_TOKENS``, or when that form
    consists only of digits.
    """
    specific: set[str] = set()
    for token in tokens:
        if not _clean_text(token):
            continue
        token_key = _normalize(token)
        if token_key in _PROFILE_QUERY_GENERIC_TOKENS or re.fullmatch(r"\d+", token_key):
            continue
        specific.add(token_key)
    return specific
def _profile_hit_match_score(
    query_tokens: set[str],
    expanded_query_tokens: set[str],
    hit: MemoryHit,
) -> tuple[float, List[str], List[str]]:
    """Score how well a profile *hit* matches the query tokens.

    The hit's category, relation, slot key, value, anchors and a fixed list of
    profile metadata fields are stringified and joined into one text, which is
    tokenized both raw and alias-expanded.

    Returns:
        ``(score, overlap_tokens, raw_overlap_tokens)`` where ``score`` is rounded
        to 6 places, ``raw_overlap_tokens`` is the sorted overlap on unexpanded
        specific tokens, and ``overlap_tokens`` is that list when non-empty,
        otherwise the sorted overlap on alias-expanded tokens.
    """
    meta = dict(hit.metadata or {})
    scalar_keys = (
        "profile_summary",
        "profile_value",
        "profile_type",
        "profile_domain",
        "profile_domain_label",
        "semantic_slot",
        "subject",
        "extracted_subject",
    )
    list_keys = (
        "profile_cluster_domains",
        "profile_cluster_types",
        "profile_cluster_route_terms",
        "profile_route_terms",
        "profile_support_values",
    )
    fragments: List[Any] = [hit.category, hit.relation, hit.slot_key, hit.value]
    fragments.extend(list(hit.anchors or []))
    fragments.extend(meta.get(key, "") for key in scalar_keys)
    fragments.extend(meta.get(key, []) for key in list_keys)
    searchable_text = " ".join(map(str, fragments))

    hit_raw_tokens = set(_path_utility_tokens(searchable_text))
    hit_expanded_tokens = _profile_query_expanded_tokens(searchable_text)
    query_specific = _profile_specific_tokens(query_tokens)
    raw_overlap = sorted(query_specific & _profile_specific_tokens(hit_raw_tokens))
    expanded_overlap = sorted(
        _profile_specific_tokens(expanded_query_tokens) & _profile_specific_tokens(hit_expanded_tokens)
    )
    overlap = raw_overlap or expanded_overlap
    overlap_ratio = float(len(overlap)) / float(max(1, len(query_specific)))

    profile_type = _normalize(meta.get("profile_type", ""))
    source_kind = _normalize(hit.source_kind)

    # Provenance adjustments: dialog-derived and consolidated records get a
    # boost, cluster nodes a penalty.
    source_bonus = 0.0
    if source_kind in {"public_dialog_preference", "public_dialog_goal", "public_dialog_profile"}:
        source_bonus += 0.08
    if meta.get("profile_candidate_status") == "consolidated":
        source_bonus += 0.04
    if meta.get("profile_cluster_node"):
        source_bonus -= 0.16

    # Profile-type adjustments, triggered by advisory / setup words in the query.
    advisory_terms = {
        "recommend", "suggest", "suited", "accessory", "accessories",
        "gear", "equipment", "setup", "current",
    }
    setup_terms = {"current", "setup", "profile"}
    type_bonus = 0.0
    if profile_type in {"preference", "goal", "setup", "usage_context"} and (advisory_terms & query_tokens):
        type_bonus += 0.10
    if profile_type in {"setup", "usage_context"} and setup_terms & query_tokens:
        type_bonus += 0.08

    raw_bonus = 0.18 * len(raw_overlap)
    # Only alias-expanded matches beyond the raw ones earn the smaller bonus.
    expanded_bonus = 0.08 * max(0, len(expanded_overlap) - len(raw_overlap))
    score = raw_bonus + expanded_bonus + (0.58 * overlap_ratio) + source_bonus + type_bonus
    return round(score, 6), overlap, raw_overlap
def _bounded_event_id_union(*groups: Iterable[Any], max_items: int) -> List[str]:
    """Concatenate *groups* in order and de-duplicate, keeping at most ``max(1, max_items)`` ids."""
    cap = max(1, int(max_items))
    flattened = (member for group in groups for member in group)
    return _dedupe(flattened, max_items=cap)
def _symbolic_recall_event_ids(
    query: str,
    runtime_graph: Mapping[str, Any],
    *,
    grouped_hits: Mapping[str, Sequence[MemoryHit]],
    limit: int,
) -> List[str]:
    """Rank event ids by token overlap between *query* and each event's payload.

    Payload text per event is gathered from three places: the event node itself
    (selected fields plus all ``teacher_fields`` and ``metadata`` values), the
    text of nodes targeted by paths that name the event, and the hits grouped
    under the event id in *grouped_hits*.

    Args:
        query: The user question.  Its ``question_anchor_tokens`` (from
            ``extract_question_features``) are used when present, otherwise the
            raw query is tokenized.
        runtime_graph: Mapping with optional ``"nodes"`` and ``"paths"`` lists.
        grouped_hits: Hits keyed by event id.  Keys that are empty after
            ``_clean_text`` are ignored, so an empty id is never returned.
        limit: Maximum number of ids to return (at least one is allowed).

    Returns:
        Event ids with a non-empty overlap, best first; ``[]`` when the query
        yields no usable tokens.
    """
    question_features = extract_question_features(query)
    query_tokens = set(
        _hybrid_symbolic_tokens(question_features.get("question_anchor_tokens", []) or query)
    )
    if not query_tokens:
        return []

    nodes_by_id: Dict[str, Dict[str, Any]] = {}
    for node in runtime_graph.get("nodes", []) or []:
        node_id = _clean_text(node.get("id", ""))
        if node_id:
            nodes_by_id[node_id] = dict(node)

    event_payloads: Dict[str, List[Any]] = {}
    for node_id, node in nodes_by_id.items():
        if _clean_text(node.get("type", "")) != "event":
            continue
        metadata = dict(node.get("metadata", {}) or {})
        teacher_fields = dict(node.get("teacher_fields", {}) or {})
        event_payloads.setdefault(node_id, []).extend(
            [
                node.get("text", ""),
                node.get("speaker", ""),
                node.get("slot_key", ""),
                node.get("target_status", ""),
                node.get("profile_type", ""),
                node.get("profile_value", ""),
                *teacher_fields.values(),
                *metadata.values(),
            ]
        )

    for path in runtime_graph.get("paths", []) or []:
        event_id = _clean_text(path.get("event_id", ""))
        support_node = nodes_by_id.get(_clean_text(path.get("target", "")), {})
        if event_id and support_node:
            event_payloads.setdefault(event_id, []).append(support_node.get("text", ""))

    for raw_event_id, group_hits in grouped_hits.items():
        event_id = _clean_text(raw_event_id)
        if not event_id:
            # An id-less group cannot be referenced by callers; skip it, just
            # as the node and path loops above skip empty ids.
            continue
        payloads = event_payloads.setdefault(event_id, [])
        for hit in group_hits or []:
            metadata = dict(hit.metadata or {})
            payloads.extend(
                [
                    hit.value,
                    hit.slot_key,
                    hit.category,
                    hit.relation,
                    # anchors is read as optional elsewhere in this module.
                    *(hit.anchors or []),
                    *metadata.values(),
                ]
            )

    scored_events: List[tuple[str, float]] = []
    for event_id, payloads in event_payloads.items():
        # The whole payload list is handed to the tokenizer in one call.
        event_tokens = set(_hybrid_symbolic_tokens(payloads))
        overlap = query_tokens & event_tokens
        if not overlap:
            continue
        overlap_ratio = float(len(overlap)) / float(max(1, len(query_tokens)))
        event_node = nodes_by_id.get(event_id, {})
        turn_index = int(event_node.get("turn_index", 0) or 0)
        # Overlap count dominates; the ratio and a tiny recency term (turn index
        # capped at 1000) only separate events with the same count.
        score = (len(overlap) * 4.0) + overlap_ratio + min(turn_index, 1000) * 0.000001
        scored_events.append((event_id, score))

    scored_events.sort(key=lambda item: (-float(item[1]), item[0]))
    return [event_id for event_id, _ in scored_events][: max(1, int(limit))]
# Record-metadata keys folded into the index text by
# `_embedder_index_record_text`, in this order (up to 12 snippets per key).
_EMBEDDER_INDEX_METADATA_TEXT_KEYS = (
    # source text and event wording
    "raw_text", "source_turn_text", "source_span", "event_phrase", "event_summary",
    # profile fields
    "profile_value", "profile_summary", "profile_type", "profile_domain", "profile_domain_label",
    # slot / subject identity
    "semantic_slot", "target_status", "subject", "subject_signature", "canonical_slot_key", "resource_key",
    # time
    "resolved_date", "resolved_time_value", "time_value", "time_display_value", "time_granularity",
    # conversation context
    "speaker", "session_name", "topic_label", "topic_bucket_id", "origin_query",
    # write-back bookkeeping
    "writeback_class", "depth_layer", "memory_chain_depth_layer",
)
# Metadata keys read after the ones above, with a larger budget (up to 24
# snippets per key).
_EMBEDDER_INDEX_METADATA_LIST_KEYS = (
    "topic_keywords", "profile_route_terms", "profile_cluster_route_terms", "profile_support_values",
    "evidence_anchors", "support_memory_ids", "support_fact_refs", "support_path_refs",
)
def _embedder_index_enabled(mode: Any) -> bool:
    """Return ``True`` unless the normalized *mode* is one of the disabled spellings."""
    mode_key = _normalize(mode)
    return mode_key not in _EMBEDDER_INDEX_DISABLED_MODES
def _embedder_index_uses_bge_m3(mode: Any) -> bool:
    """Return ``True`` when *mode* names the BGE-M3 backend.

    Hyphens and underscores are treated as the same character on both sides of
    the comparison with ``_EMBEDDER_INDEX_BGE_M3_MODES``.
    """
    mode_key = _normalize(mode).replace("-", "_")
    return any(alias.replace("-", "_") == mode_key for alias in _EMBEDDER_INDEX_BGE_M3_MODES)
def _embedder_index_version_for_mode(mode: Any) -> str:
    """Pick the index version tag: the BGE-M3 tag for BGE-M3 modes, the sparse tag otherwise."""
    if _embedder_index_uses_bge_m3(mode):
        return _EMBEDDER_INDEX_BGE_M3_VERSION
    return _EMBEDDER_INDEX_VERSION
def _embedder_index_text_items(value: Any, *, max_items: int = 64) -> List[str]:
    """Flatten *value* into at most ``max_items`` short text snippets.

    Traversal is depth-first and in container order:

    * ``None`` yields nothing.
    * A mapping yields ``"<key> <value>"`` for each str/int/float/bool value
      (skipped when the cleaned value is empty) and recurses into any other value.
    * A list, tuple or set is walked element by element.
    * Anything else yields its cleaned text.

    Each value text is cut to 800 characters.  Traversal stops as soon as the
    snippet budget is reached.
    """

    def snippets(node: Any):
        if node is None:
            return
        if isinstance(node, Mapping):
            for key, child in list(node.items()):
                label = _clean_text(key)
                if isinstance(child, (str, int, float, bool)):
                    child_text = _clean_text(child)[:800]
                    if child_text:
                        yield f"{label} {child_text}".strip()
                else:
                    yield from snippets(child)
        elif isinstance(node, (list, tuple, set)):
            for child in list(node):
                yield from snippets(child)
        else:
            leaf_text = _clean_text(node)[:800]
            if leaf_text:
                yield leaf_text

    collected: List[str] = []
    if max_items > 0:
        # The generator is abandoned the moment the budget is met, so nothing
        # past that point is visited or cleaned.
        for snippet in snippets(value):
            collected.append(snippet)
            if len(collected) >= max_items:
                break
    return collected[:max_items]
def _embedder_index_term_weights(value: Any, *, max_terms: int = 96) -> Dict[str, float]:
    """Build an L2-normalized sparse term-weight map for *value*.

    Word terms come from ``_path_utility_tokens`` (normalized, at most 64
    characters, weight 1.0 per occurrence).  In addition, the characters of the
    normalized text that fall in U+4E00..U+9FFF are turned into overlapping
    2-grams (weight 1.35 each) and 3-grams (weight 1.15 each).  The
    ``max_terms`` heaviest terms are kept (ties broken alphabetically; at least
    one term) and their weights are divided by the Euclidean norm of the kept
    weights, rounded to 6 places.

    Returns ``{}`` for empty text or when no term survives.
    """
    text = _clean_text(value)
    if not text:
        return {}

    weights: Dict[str, float] = {}

    def bump(term: str, amount: float) -> None:
        weights[term] = weights.get(term, 0.0) + amount

    for raw_token in _path_utility_tokens(text):
        token = _normalize(raw_token)
        if token and len(token) <= 64:
            bump(token, 1.0)

    han_chars = [ch for ch in _normalize(text) if "\u4e00" <= ch <= "\u9fff"]
    for width, gram_weight in ((2, 1.35), (3, 1.15)):
        # range() is empty when there are fewer characters than `width`.
        for start in range(len(han_chars) - width + 1):
            bump("".join(han_chars[start : start + width]), gram_weight)

    if not weights:
        return {}
    keep = max(1, int(max_terms or 1))
    top_terms = sorted(weights.items(), key=lambda pair: (-float(pair[1]), pair[0]))[:keep]
    norm = math.sqrt(sum(float(weight) * float(weight) for _, weight in top_terms)) or 1.0
    return {term: round(float(weight) / norm, 6) for term, weight in top_terms}
def _embedder_dense_vectors_for_texts(texts: Sequence[str], *, mode: str) -> tuple[List[List[float]], Dict[str, Any]]:
    """Encode *texts* into dense vectors when *mode* selects the BGE-M3 backend.

    Returns:
        ``(vectors, metadata)``.  On success ``vectors`` holds one L2-normalized
        vector (values rounded to 6 places) per input text and ``metadata`` has
        ``write_embedder_dense_enabled: True`` plus device, dimension and
        max-length fields.  On every other path ``vectors`` is a list of empty
        lists with the same length as *texts* and ``metadata`` has
        ``write_embedder_dense_enabled: False``:

        * mode is not a BGE-M3 mode (no further metadata keys);
        * every text is empty after cleaning (``write_embedder_dense_error ==
          "empty_texts"``);
        * ``torch`` / ``transformers`` cannot be imported
          (``"dependency_unavailable:<ExceptionClass>"``);
        * loading or running the model raises (``"<ExceptionClass>:<message>"``).

    Environment variables read:
        ``TMCRA_EMBEDDER_MODEL_PATH`` (default ``"BAAI/bge-m3"``),
        ``TMCRA_EMBEDDER_DEVICE`` (default ``"cuda"`` when available, else ``"cpu"``),
        ``TMCRA_EMBEDDER_MODEL_MAX_LENGTH`` (default 512, never below 64).
    """
    normalized_mode = _normalize(mode)
    if not _embedder_index_uses_bge_m3(normalized_mode):
        return [[] for _ in texts], {"write_embedder_dense_enabled": False}
    clean_texts = [_clean_text(text) for text in texts]
    metadata: Dict[str, Any] = {
        "write_embedder_dense_enabled": False,
        "write_embedder_dense_backend": "bge_m3_transformers",
        "write_embedder_dense_model": _clean_text(os.getenv("TMCRA_EMBEDDER_MODEL_PATH", "")) or "BAAI/bge-m3",
    }
    if not any(clean_texts):
        metadata["write_embedder_dense_error"] = "empty_texts"
        return [[] for _ in texts], metadata
    # Heavy dependencies are imported lazily so the module works without them.
    try:
        import torch  # type: ignore
        from transformers import AutoModel, AutoTokenizer  # type: ignore
    except Exception as exc:
        metadata["write_embedder_dense_error"] = f"dependency_unavailable:{exc.__class__.__name__}"
        return [[] for _ in texts], metadata
    model_name = metadata["write_embedder_dense_model"]
    device = _clean_text(os.getenv("TMCRA_EMBEDDER_DEVICE", ""))
    if not device:
        device = "cuda" if bool(getattr(torch, "cuda", None) and torch.cuda.is_available()) else "cpu"
    try:
        max_length = max(64, int(os.getenv("TMCRA_EMBEDDER_MODEL_MAX_LENGTH", "512") or 512))
    except (TypeError, ValueError):
        max_length = 512
    # One cache entry per (model, device, max_length) combination.
    cache_key = f"bge_m3::{model_name}::{device}::{max_length}"
    try:
        cached = _EMBEDDER_MODEL_CACHE.get(cache_key)
        if cached is None:
            # NOTE(review): trust_remote_code=True lets the model repository run
            # its own Python code at load time, and the model path comes from an
            # environment variable — confirm the configured source is trusted.
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
            model.to(device)
            model.eval()
            cached = (tokenizer, model)
            # Only a successful load is cached; a failed load is retried on the next call.
            _EMBEDDER_MODEL_CACHE[cache_key] = cached
        tokenizer, model = cached
        # All texts are encoded as a single batch.
        encoded = tokenizer(
            clean_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        encoded = {key: value.to(device) for key, value in encoded.items()}
        with torch.no_grad():
            output = model(**encoded)
        hidden = output.last_hidden_state
        mask = encoded.get("attention_mask")
        if mask is not None:
            # Mean of the hidden states over non-padding positions.
            mask = mask.unsqueeze(-1).expand(hidden.size()).float()
            pooled = torch.sum(hidden * mask, dim=1) / torch.clamp(mask.sum(dim=1), min=1e-9)
        else:
            # No mask available: use the hidden state at position 0.
            pooled = hidden[:, 0]
        pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
        vectors = [
            [round(float(value), 6) for value in row.detach().cpu().tolist()]
            for row in pooled
        ]
        metadata.update(
            {
                "write_embedder_dense_enabled": True,
                "write_embedder_dense_device": device,
                "write_embedder_dense_dim": int(len(vectors[0]) if vectors else 0),
                "write_embedder_dense_max_length": int(max_length),
            }
        )
        return vectors, metadata
    except Exception as exc:
        # Best effort: any load/encode failure is reported via metadata (message
        # truncated to 240 characters) instead of propagating to the caller.
        metadata["write_embedder_dense_error"] = f"{exc.__class__.__name__}:{_clean_text(exc)[:240]}"
        return [[] for _ in texts], metadata
def _prewarm_embedder_dense_if_requested(*, mode: str) -> Dict[str, Any]:
    """Optionally run one dense encode so the embedder model is loaded ahead of use.

    Pre-warming happens only when ``TMCRA_EMBEDDER_PREWARM`` normalizes to one of
    ``1/true/yes/on/auto`` and *mode* is a BGE-M3 mode.  The warm-up text comes
    from ``TMCRA_EMBEDDER_PREWARM_TEXT`` (with a built-in default).

    Returns a dict of ``embedder_prewarm_*`` fields; ``embedder_prewarm_enabled``
    is ``True`` only when a non-empty vector came back.
    """
    prewarm_flag = _normalize(os.getenv("TMCRA_EMBEDDER_PREWARM", ""))
    opted_in = prewarm_flag in {"1", "true", "yes", "on", "auto"}
    if prewarm_flag in _EMBEDDER_INDEX_DISABLED_MODES or not opted_in:
        return {"embedder_prewarm_enabled": False}

    mode_key = _normalize(mode)
    if not _embedder_index_uses_bge_m3(mode_key):
        return {
            "embedder_prewarm_enabled": False,
            "embedder_prewarm_reason": "mode_not_dense",
            "embedder_prewarm_mode": mode_key or "off",
        }

    probe_text = _clean_text(os.getenv("TMCRA_EMBEDDER_PREWARM_TEXT", "")) or "tmcra memory retrieval warmup"
    probe_vectors, dense_info = _embedder_dense_vectors_for_texts([probe_text], mode=mode_key)
    return {
        "embedder_prewarm_enabled": bool(probe_vectors and probe_vectors[0]),
        "embedder_prewarm_mode": mode_key,
        "embedder_prewarm_dense_enabled": bool(dense_info.get("write_embedder_dense_enabled")),
        "embedder_prewarm_dense_device": dense_info.get("write_embedder_dense_device", ""),
        "embedder_prewarm_dense_error": dense_info.get("write_embedder_dense_error", ""),
    }
def _embedder_index_record_text(record: SessionMemoryRecordV2, *, turn_text: str = "") -> str:
    """Compose the text that gets indexed for *record*.

    The text is built, in order, from: category, slot key, the value (three
    times), relation, the anchor concepts (twice), the evidence anchors, then
    optionally the turn text, then ``"<key> <snippets>"`` entries for the
    metadata keys in ``_EMBEDDER_INDEX_METADATA_TEXT_KEYS`` and
    ``_EMBEDDER_INDEX_METADATA_LIST_KEYS``.  Repeating a field raises its term
    counts in ``_embedder_index_term_weights``.

    ``turn_text`` (cut to 1000 characters) is included only when the record's
    metadata carries none of its own evidence text fields.
    """
    meta = dict(record.metadata or {})
    anchor_concepts = list(record.anchor_concepts or [])

    fragments: List[str] = [record.category, record.slot_key]
    fragments += [record.value] * 3
    fragments.append(record.relation)
    fragments += anchor_concepts * 2
    fragments += list(record.evidence_anchors or [])

    evidence_keys = ("raw_text", "source_turn_text", "source_span", "event_phrase", "profile_value", "profile_summary")
    carries_own_evidence = any(_clean_text(meta.get(key, "")) for key in evidence_keys)
    if turn_text and not carries_own_evidence:
        fragments.append(_clean_text(turn_text)[:1000])

    for keys, item_cap in (
        (_EMBEDDER_INDEX_METADATA_TEXT_KEYS, 12),
        (_EMBEDDER_INDEX_METADATA_LIST_KEYS, 24),
    ):
        for key in keys:
            field_value = meta.get(key)
            if field_value:
                fragments.append(
                    f"{key} {' '.join(_embedder_index_text_items(field_value, max_items=item_cap))}".strip()
                )

    cleaned_fragments = [text for text in map(_clean_text, fragments) if text]
    return _clean_text(" ".join(cleaned_fragments))
def _apply_write_embedder_index_to_graph(
    graph: SessionMemoryGraphV2,
    *,
    stored_ids: Sequence[str],
    turn_text: str,
    turn_index: int,
    mode: str,
    max_terms: int,
) -> Dict[str, Any]:
    """Attach write-time index data to the metadata of freshly stored records.

    For each distinct id in *stored_ids* that resolves through
    ``graph.records_by_id`` and produces at least one sparse term, the record's
    ``metadata`` is replaced by a copy extended with ``write_embedder_index_*``
    fields and, in BGE-M3 modes, either the dense vector fields or a dense error
    marker.

    Returns:
        A summary dict (``write_embedder_index_*`` keys plus whatever
        ``_embedder_dense_vectors_for_texts`` reported).  When the index is
        disabled or *stored_ids* is empty, the summary is returned untouched and
        no record is modified.
    """
    mode_key = _normalize(mode)
    version = _embedder_index_version_for_mode(mode_key)
    summary: Dict[str, Any] = {
        "write_embedder_index_enabled": False,
        "write_embedder_index_mode": mode_key or "off",
        "write_embedder_index_version": version,
        "write_embedder_index_record_count": 0,
    }
    if not _embedder_index_enabled(mode_key) or not stored_ids:
        return summary

    # Pass 1: sparse terms per record; records without any term are left alone.
    pending: List[tuple[SessionMemoryRecordV2, str, Dict[str, float]]] = []
    for memory_id in _dedupe(stored_ids):
        record = getattr(graph, "records_by_id", {}).get(memory_id)
        if record is None:
            continue
        text = _embedder_index_record_text(record, turn_text=turn_text)
        term_weights = _embedder_index_term_weights(text, max_terms=max_terms)
        if term_weights:
            pending.append((record, text, term_weights))

    # Pass 2: one batched dense encode for all pending texts.
    dense_vectors, dense_info = _embedder_dense_vectors_for_texts(
        [text for _, text, _ in pending],
        mode=mode_key,
    )
    summary.update(dense_info)
    dense_mode = _embedder_index_uses_bge_m3(mode_key)

    # Pass 3: write the index fields back onto each record.
    indexed_ids: List[str] = []
    for position, (record, _text, term_weights) in enumerate(pending):
        vector = dense_vectors[position] if position < len(dense_vectors) else []
        updated = dict(record.metadata or {})
        updated["write_embedder_index_enabled"] = True
        updated["write_embedder_index_mode"] = mode_key
        updated["write_embedder_index_version"] = version
        updated["write_embedder_index_source_turn"] = int(turn_index)
        updated["write_embedder_index_term_count"] = int(len(term_weights))
        updated["write_embedder_index_terms"] = dict(term_weights)
        updated["write_embedder_index_top_terms"] = list(term_weights.keys())[:24]
        if vector:
            updated["write_embedder_dense_enabled"] = True
            updated["write_embedder_dense_backend"] = dense_info.get("write_embedder_dense_backend", "")
            updated["write_embedder_dense_model"] = dense_info.get("write_embedder_dense_model", "")
            updated["write_embedder_dense_dim"] = int(len(vector))
            updated["write_embedder_dense_vector"] = list(vector)
        elif dense_mode:
            # Dense was requested but no vector came back: record why.
            updated["write_embedder_dense_enabled"] = False
            updated["write_embedder_dense_error"] = dense_info.get(
                "write_embedder_dense_error", "dense_vector_unavailable"
            )
        record.metadata = updated
        indexed_ids.append(record.memory_id)

    summary["write_embedder_index_enabled"] = bool(indexed_ids)
    summary["write_embedder_index_record_count"] = int(len(indexed_ids))
    summary["write_embedder_index_record_ids"] = list(indexed_ids[:24])
    return summary
def _embedder_index_recall_event_ids(
    query: str,
    *,
    grouped_hits: Mapping[str, Sequence[MemoryHit]],
    mode: str,
    limit: int,
    max_terms: int,
) -> Dict[str, Any]:
    """Rank events by how well their write-time index data matches *query*.

    Args:
        query: The user question; turned into sparse term weights and, in
            BGE-M3 modes, a dense vector.
        grouped_hits: Hits keyed by event id.  Only hits whose metadata carries
            ``write_embedder_index_terms`` (a mapping) and/or
            ``write_embedder_dense_vector`` (a list) take part.
        mode: Embedder index mode; a disabled mode returns an empty result.
        limit: Maximum number of events returned (at least one is allowed).
        max_terms: Cap on the number of query terms.

    Returns:
        ``{"event_ids": [...], "metadata": {...}}`` with ``embedder_index_*`` /
        ``embedder_dense_*`` telemetry in ``metadata``.
    """
    normalized_mode = _normalize(mode)
    index_version = _embedder_index_version_for_mode(normalized_mode)
    metadata: Dict[str, Any] = {
        "embedder_index_recall_enabled": False,
        "embedder_index_recall_mode": normalized_mode or "off",
        "embedder_index_recall_version": index_version,
        "embedder_index_event_ids": [],
        "embedder_index_event_scores": {},
        "embedder_index_record_count": 0,
    }
    if not _embedder_index_enabled(normalized_mode):
        return {"event_ids": [], "metadata": metadata}
    query_terms = _embedder_index_term_weights(query, max_terms=max_terms)
    query_vectors, query_dense_metadata = _embedder_dense_vectors_for_texts([query], mode=normalized_mode)
    query_vector = query_vectors[0] if query_vectors else []
    metadata.update(
        {
            "embedder_dense_recall_enabled": bool(query_vector),
            "embedder_dense_recall_backend": query_dense_metadata.get("write_embedder_dense_backend", ""),
            "embedder_dense_recall_model": query_dense_metadata.get("write_embedder_dense_model", ""),
            "embedder_dense_recall_error": query_dense_metadata.get("write_embedder_dense_error", ""),
        }
    )
    if not query_terms and not query_vector:
        # Nothing to match with, sparse or dense.
        metadata["embedder_index_recall_reason"] = "empty_query_terms"
        return {"event_ids": [], "metadata": metadata}
    scored_events: List[tuple[str, float, int]] = []
    indexed_record_count = 0
    dense_record_count = 0
    for event_id, group_hits in grouped_hits.items():
        # An event is scored by its best-matching hit; event_turn is the latest
        # turn among the hits that scored above zero.
        event_score = 0.0
        event_turn = 0
        for hit in list(group_hits or []):
            hit_metadata = dict(hit.metadata or {})
            raw_terms = hit_metadata.get("write_embedder_index_terms")
            raw_vector = hit_metadata.get("write_embedder_dense_vector")
            if not isinstance(raw_terms, Mapping) and not isinstance(raw_vector, list):
                # Hit was never indexed at write time.
                continue
            indexed_record_count += int(isinstance(raw_terms, Mapping))
            dense_record_count += int(isinstance(raw_vector, list) and bool(raw_vector))
            # Sparse score: dot product of query and record term weights;
            # a term whose stored weight is not numeric is skipped.
            sparse_score = 0.0
            if isinstance(raw_terms, Mapping):
                for term, query_weight in query_terms.items():
                    try:
                        sparse_score += float(query_weight) * float(raw_terms.get(term, 0.0) or 0.0)
                    except (TypeError, ValueError):
                        continue
            # Dense score: dot product of the two vectors; zip() stops at the
            # shorter one, so a dimension mismatch is not detected here.
            dense_score = 0.0
            if query_vector and isinstance(raw_vector, list) and raw_vector:
                for query_value, record_value in zip(query_vector, raw_vector):
                    try:
                        dense_score += float(query_value) * float(record_value)
                    except (TypeError, ValueError):
                        continue
            # Take the better of the pure sparse score and the dense score, the
            # latter topped up by 15% of the sparse score when it is positive.
            score = max(float(sparse_score), float(dense_score) + (0.15 * float(sparse_score) if dense_score > 0.0 else 0.0))
            if score <= 0.0:
                continue
            event_score = max(event_score, float(score))
            event_turn = max(event_turn, int(hit.turn_index or 0))
        if event_score > 0.0:
            scored_events.append((_clean_text(event_id), round(event_score, 6), event_turn))
    # Best score first, then most recent turn, then event id.
    scored_events.sort(key=lambda item: (-float(item[1]), -int(item[2]), item[0]))
    selected = scored_events[: max(1, int(limit or 1))]
    # Entries whose id cleaned to an empty string are dropped after the cut.
    event_ids = [event_id for event_id, _, _ in selected if event_id]
    event_scores = {event_id: score for event_id, score, _ in selected if event_id}
    metadata.update(
        {
            "embedder_index_recall_enabled": True,
            "embedder_index_event_ids": list(event_ids),
            "embedder_index_event_scores": dict(event_scores),
            "embedder_index_query_terms": list(query_terms.keys())[:32],
            "embedder_index_record_count": int(indexed_record_count),
            "embedder_dense_record_count": int(dense_record_count),
        }
    )
    return {"event_ids": event_ids, "metadata": metadata}
# Lower-cased words that `_query_identifier_tokens` never treats as an
# identifier, even when they match the identifier token pattern.
_IDENTIFIER_GENERIC_TOKENS = {
    "api", "agent", "code", "codename", "context", "debug",
    "goal", "memory", "model", "name", "project", "retrieval",
    "runtime", "session", "target", "test", "turn",
}

# Matches queries asking for a name-like fact, in English (case-insensitive)
# or Chinese.
_IDENTIFIER_REQUEST_RE = re.compile(
    r"\b(code\s*name|codename|identifier|alias|project\s+name|model\s+name|api\s+name)\b|"
    r"(代号|编号|标识符|名称|名字|别名|项目名|模型名|接口名)",
    re.IGNORECASE,
)

# Candidate identifier: an ASCII letter followed by 2-63 letters, digits,
# underscores or hyphens.
_IDENTIFIER_TOKEN_RE = re.compile(r"\b[A-Za-z][A-Za-z0-9_-]{2,63}\b")
def _query_identifier_tokens(query: Any) -> List[str]:
    """Pull identifier-looking tokens out of *query* (at most 8, de-duplicated).

    A token matched by ``_IDENTIFIER_TOKEN_RE`` counts as an identifier when it
    is not a generic word and has an upper-case letter after its first
    character, a digit, a hyphen or an underscore.
    """
    found: List[str] = []
    for candidate in _IDENTIFIER_TOKEN_RE.findall(_clean_text(query)):
        if candidate.lower() in _IDENTIFIER_GENERIC_TOKENS:
            continue
        looks_like_identifier = (
            any(ch.isupper() for ch in candidate[1:])
            or any(ch.isdigit() for ch in candidate)
            or "-" in candidate
            or "_" in candidate
        )
        if looks_like_identifier:
            found.append(candidate)
    return _dedupe(found, max_items=8)
def _query_requests_identifier_fact(query: Any) -> bool:
    """Return ``True`` when the cleaned *query* matches ``_IDENTIFIER_REQUEST_RE``."""
    match = _IDENTIFIER_REQUEST_RE.search(_clean_text(query))
    return match is not None
def _hit_text_for_identifier_match(hit: MemoryHit) -> str:
    """Join the searchable text of *hit* into one space-separated string.

    Covers the hit's id, category, value, relation, slot key, source kind and
    anchors, followed by a fixed set of text fields from its metadata.  Parts
    that are empty after ``_clean_text`` are left out.
    """
    meta = dict(hit.metadata or {})
    fields: List[Any] = [
        hit.memory_id,
        hit.category,
        hit.value,
        hit.relation,
        hit.slot_key,
        hit.source_kind,
    ]
    fields.extend(list(hit.anchors or []))
    fields.extend(
        meta.get(key, "")
        for key in (
            "raw_text",
            "source_turn_text",
            "source_span",
            "event_phrase",
            "profile_value",
            "target_status",
        )
    )
    cleaned_fields = (_clean_text(field) for field in fields)
    return " ".join(text for text in cleaned_fields if text)
def _copy_hit_with_identifier_boost(
    hit: MemoryHit,
    *,
    score: float,
    reasons: Sequence[str],
    matched_tokens: Sequence[str],
) -> MemoryHit:
    """Return a copy of *hit* marked as identifier-protected with a raised score.

    The copy's score is the larger of the original score and ``1.0 + score``
    (rounded to 6 places).  Its metadata is a copy of the original extended with
    ``identifier_protected_*`` fields, including the original score; reasons and
    matched tokens are capped at 8 entries each.  *hit* itself is not modified.

    ``anchors`` and ``turn_index`` are read as optional (``None`` becomes ``[]``
    and ``0``), which is how the other helpers in this module read them.
    """
    metadata = dict(hit.metadata or {})
    metadata.update(
        {
            "identifier_protected": True,
            "identifier_protected_score": round(float(score), 6),
            "identifier_protected_reasons": list(reasons)[:8],
            "identifier_protected_matched_tokens": list(matched_tokens)[:8],
            "identifier_protected_original_score": round(float(hit.score), 6),
        }
    )
    return MemoryHit(
        memory_id=hit.memory_id,
        category=hit.category,
        value=hit.value,
        relation=hit.relation,
        anchors=list(hit.anchors or []),
        score=max(float(hit.score), round(1.0 + float(score), 6)),
        source_kind=hit.source_kind,
        slot_key=hit.slot_key,
        state=hit.state,
        turn_index=int(hit.turn_index or 0),
        metadata=metadata,
    )
def _identifier_protected_hits(
    *,
    query: str,
    final_hits: Sequence[MemoryHit],
    candidate_hits: Sequence[MemoryHit],
    top_k: int,
) -> Dict[str, Any]:
    """Promote hits that match identifiers named or requested in *query*.

    Args:
        query: The user query text.
        final_hits: Hits already selected for the answer, in rank order.
        candidate_hits: Wider candidate pool that may contain identifier matches.
        top_k: Maximum number of hits to return (treated as at least 1).

    Returns:
        A dict with ``enabled``, ``hits`` (at most two promoted hits followed by
        the remaining final hits, capped at ``top_k``), ``promoted_hits``, and
        ``metadata``. When the query has neither identifier tokens nor an
        identifier request, ``final_hits`` is returned unchanged with
        ``enabled`` set to ``False``.
    """
    identifier_tokens = _query_identifier_tokens(query)
    identifier_request = _query_requests_identifier_fact(query)
    if not identifier_tokens and not identifier_request:
        return {"enabled": False, "hits": list(final_hits), "promoted_hits": [], "metadata": {"identifier_protected_enabled": False}}

    def _hit_key(hit: MemoryHit) -> str:
        # Single definition of the dedupe key (previously repeated three times);
        # hits without a memory id fall back to event / slot / value prefix.
        if hit.memory_id:
            return hit.memory_id
        value_prefix = str(hit.value or "")[:80]
        return f"{_hit_event_id(hit)}::{hit.slot_key}::{value_prefix}"

    budget = max(1, int(top_k or 1))
    query_text = _clean_text(query)
    lowered_tokens = [(token, token.lower()) for token in identifier_tokens]
    latin_field_terms = ("codename", "code name", "identifier", "alias", "project_codename")
    cjk_field_terms = ("代号", "编号", "标识符", "名称", "名字", "别名")

    # First occurrence wins, so final hits take precedence over candidates.
    pool: Dict[str, MemoryHit] = {}
    for hit in [*final_hits, *candidate_hits]:
        pool.setdefault(_hit_key(hit), hit)

    scored: List[tuple[float, MemoryHit, List[str], List[str]]] = []
    for hit in pool.values():
        hit_text = _hit_text_for_identifier_match(hit)
        hit_lower = hit_text.lower()
        matched_tokens = [token for token, lowered in lowered_tokens if lowered in hit_lower]
        reasons: List[str] = []
        score = 0.0
        if matched_tokens:
            score += 12.0 + float(len(matched_tokens))
            reasons.append("exact_identifier_match")
        if identifier_request:
            if any(term in hit_lower for term in latin_field_terms):
                score += 8.0
                reasons.append("identifier_field_match")
            if any(term in hit_text for term in cjk_field_terms):
                score += 8.0
                reasons.append("identifier_cjk_field_match")
            if "项目" in query_text and "项目" in hit_text:
                score += 1.5
                reasons.append("project_term_match")
        if score <= 0:
            continue
        scored.append((score, hit, matched_tokens, reasons))

    # Highest identifier score first; ties fall back to hit score, then recency.
    scored.sort(key=lambda item: (item[0], float(item[1].score), int(item[1].turn_index or 0)), reverse=True)
    promoted = [
        _copy_hit_with_identifier_boost(hit, score=score, reasons=reasons, matched_tokens=matched_tokens)
        for score, hit, matched_tokens, reasons in scored[:2]
    ]

    base_metadata: Dict[str, Any] = {
        "identifier_protected_enabled": True,
        "identifier_query_tokens": identifier_tokens,
        "identifier_request": bool(identifier_request),
    }
    if not promoted:
        return {
            "enabled": True,
            "hits": list(final_hits),
            "promoted_hits": [],
            "metadata": {**base_metadata, "identifier_promoted_count": 0},
        }

    promoted_keys = {_hit_key(hit) for hit in promoted}
    merged = list(promoted)
    for hit in final_hits:
        if len(merged) >= budget:
            break
        if _hit_key(hit) not in promoted_keys:
            merged.append(hit)
    return {
        "enabled": True,
        "hits": merged[:budget],
        "promoted_hits": promoted,
        "metadata": {
            **base_metadata,
            "identifier_promoted_count": len(promoted),
            "identifier_promoted_ids": [hit.memory_id for hit in promoted],
        },
    }
def _trim_prompt_text(value: Any, *, max_chars: int = 220) -> str:
    """Clean *value* and cap it at *max_chars* characters.

    Text longer than the cap is cut, right-stripped, and suffixed with ``...``
    so the result never exceeds the cap. When the cap is smaller than the
    ellipsis itself (fewer than 3 characters) the text is hard-cut instead, and
    a negative cap is treated as zero.
    """
    text = _clean_text(value)
    limit = max(0, int(max_chars))
    if len(text) <= limit:
        return text
    if limit < 3:
        # No room for the ellipsis; appending it would overshoot the cap.
        return text[:limit]
    return f"{text[: limit - 3].rstrip()}..."
def _prompt_hit_payload(hit: MemoryHit) -> Dict[str, Any]:
    """Build the compact prompt representation of a retrieval hit.

    The value and the first four anchors are trimmed via ``_trim_prompt_text``;
    the score is rounded to six decimals.
    """
    anchor_preview = []
    for anchor in list(hit.anchors)[:4]:
        anchor_preview.append(_trim_prompt_text(anchor, max_chars=80))
    rounded_score = round(float(hit.score), 6)
    return dict(
        memory_id=hit.memory_id,
        slot_key=hit.slot_key,
        category=hit.category,
        value=_trim_prompt_text(hit.value),
        relation=hit.relation,
        anchors=anchor_preview,
        score=rounded_score,
        state=hit.state,
        turn_index=int(hit.turn_index),
        source_kind=hit.source_kind,
    )
def _prompt_record_payload(record: SessionMemoryRecordV2) -> Dict[str, Any]:
    """Build the compact prompt representation of a stored memory record.

    Includes the trimmed value, up to four trimmed anchor concepts, and the
    ``memory_role`` / ``authority`` metadata fields (empty when absent).
    """
    meta = dict(record.metadata or {})
    anchor_preview = []
    for anchor in list(record.anchor_concepts)[:4]:
        anchor_preview.append(_trim_prompt_text(anchor, max_chars=80))
    return dict(
        slot_key=record.slot_key,
        category=record.category,
        value=_trim_prompt_text(record.value),
        state=record.state,
        turn_index=int(record.turn_index),
        anchors=anchor_preview,
        memory_role=_clean_text(meta.get("memory_role", "")),
        authority=_clean_text(meta.get("authority", "")),
    )
def _graph_prompt_state_summary(graph: SessionMemoryGraphV2, retrieval: MemoryRetrieval) -> Dict[str, Any]:
    """Summarise graph state and retrieval results for prompt injection.

    The summary holds graph counters plus previews of the top hits, active slot
    records, and relations. While its JSON form is longer than
    ``_GRAPH_PROMPT_MAX_CHARS``, previews are shortened one entry at a time:
    top hits first (keeping one), then active slot records (keeping one), then
    relations (keeping two). ``truncation_reason`` names the last list trimmed.
    """
    head_record_ids = list(graph.slot_heads.values())[:_GRAPH_PROMPT_MAX_ACTIVE_SLOTS]
    head_records = (graph.records_by_id.get(record_id) for record_id in head_record_ids)
    active_slot_records = [_prompt_record_payload(record) for record in head_records if record is not None]

    hit_preview: List[Dict[str, Any]] = []
    for hit in list(retrieval.hits)[:_GRAPH_PROMPT_MAX_HITS]:
        hit_preview.append(_prompt_hit_payload(hit))

    relation_preview: List[Dict[str, Any]] = []
    for item in list(retrieval.relations)[:_GRAPH_PROMPT_MAX_RELATIONS]:
        relation_preview.append(
            {
                "from": _trim_prompt_text(item.get("from", ""), max_chars=72),
                "to": _trim_prompt_text(item.get("to", ""), max_chars=72),
                "relation": _clean_text(item.get("relation", "")),
            }
        )

    summary: Dict[str, Any] = {
        "records": len(graph.records_by_id),
        "active_slots": len(graph.slot_heads),
        "turn_index": int(graph.turn_index),
        "noise_turn_count": int(graph.noise_turn_count),
        "answer_support_events": len(graph.answer_support_log),
        "top_hits": hit_preview,
        "active_slot_records": active_slot_records,
        "relation_preview": relation_preview,
        "context_truncated": False,
        "truncation_reason": "",
    }

    # (summary key, minimum entries to keep, reason label), in trimming priority order.
    trim_plan = (
        ("top_hits", 1, "trimmed_top_hits"),
        ("active_slot_records", 1, "trimmed_active_slots"),
        ("relation_preview", 2, "trimmed_relations"),
    )
    was_trimmed = False
    last_reason = ""
    while len(json.dumps(summary, ensure_ascii=False)) > _GRAPH_PROMPT_MAX_CHARS:
        for key, keep_at_least, reason in trim_plan:
            if len(summary[key]) > keep_at_least:
                summary[key] = summary[key][:-1]
                was_trimmed = True
                last_reason = reason
                break
        else:
            # Every preview is already at its floor; accept the oversize summary.
            break

    summary["context_truncated"] = was_trimmed
    summary["truncation_reason"] = last_reason
    return summary
def _state_stats(*, storage_bytes: int, retrieval_context_tokens: int, total_state_tokens: int, **extra: Any) -> Dict[str, Any]:
    """Assemble the state-statistics dict shared by the memory adapters.

    Extra keyword arguments are copied in first; the four standard size fields
    are then written as integers and override any same-named extra. The
    retrieval context size is reported under both ``context_token_estimate``
    and ``retrieval_context_token_estimate``.
    """
    stats: Dict[str, Any] = dict(extra)
    stats["storage_bytes"] = int(storage_bytes)
    retrieval_estimate = int(retrieval_context_tokens)
    stats["context_token_estimate"] = retrieval_estimate
    stats["retrieval_context_token_estimate"] = retrieval_estimate
    stats["total_state_token_estimate"] = int(total_state_tokens)
    return stats
def _relation_hit(hit: MemoryHit, *, weight_bias: float = 0.0) -> Dict[str, Any]:
    """Describe the relation from a hit's first anchor to its value.

    Returns an empty dict when the hit has no anchors, or when the first anchor
    is empty or equal to the value. The weight is
    ``0.42 + score * 0.4 + weight_bias`` clamped to ``[0.25, 0.98]`` and rounded
    to six decimals.
    """
    anchors = hit.anchors
    if not anchors:
        return {}
    source = anchors[0]
    if not source or source == hit.value:
        return {}
    raw_weight = 0.42 + hit.score * 0.4 + weight_bias
    clamped_weight = max(0.25, min(0.98, raw_weight))
    return {
        "from": source,
        "to": hit.value,
        "relation": hit.relation,
        "weight": round(clamped_weight, 6),
        "source_kind": hit.source_kind,
        "memory_id": hit.memory_id,
    }
def _raw_hit_to_memory_hit(payload: Dict[str, Any]) -> MemoryHit:
    """Convert a raw hit payload dict into a ``MemoryHit``.

    Anchors are read from ``anchor_concepts`` (falling back to ``anchors``) and
    entries that are empty after ``_clean_text`` are dropped. The score falls
    back to ``relevance``, and the state falls back to the metadata ``state``
    and finally to ``"active"``. A truthy ``supersedes`` entry is copied into
    the metadata. The slot key is resolved through ``stable_slot_key``.
    """
    metadata = dict(payload.get("metadata", {}) or {})
    if payload.get("supersedes"):
        metadata["supersedes"] = list(payload.get("supersedes", []) or [])
    raw_anchors = payload.get("anchor_concepts", payload.get("anchors", [])) or []
    anchors = [str(anchor) for anchor in raw_anchors if _clean_text(anchor)]
    slot_key = stable_slot_key(
        category=str(payload.get("category", "")),
        value=str(payload.get("value", "")),
        # Hand over a copy so the list stored on the hit is never shared.
        anchors=list(anchors),
        slot_key=str(payload.get("slot_key", metadata.get("slot", ""))),
        relation=str(payload.get("relation", "related_to")),
        metadata=metadata,
    )
    # Read the fallback state from the None-safe `metadata` copy. The previous
    # `payload.get("metadata", {}).get(...)` default was evaluated eagerly and
    # raised AttributeError whenever payload["metadata"] was None.
    state = str(payload.get("state", metadata.get("state", "active")) or "active")
    return MemoryHit(
        memory_id=str(payload.get("memory_id", "")),
        category=str(payload.get("category", "")),
        value=str(payload.get("value", "")),
        relation=str(payload.get("relation", "related_to")),
        anchors=anchors,
        score=float(payload.get("score", payload.get("relevance", 0.0)) or 0.0),
        source_kind=str(payload.get("source_kind", "memory")),
        slot_key=slot_key,
        state=state,
        turn_index=int(payload.get("turn_index", 0) or 0),
        metadata=metadata,
    )
def _restore_hit_scores(hits: List[MemoryHit], scored_lookup: Dict[str, MemoryHit]) -> List[MemoryHit]:
    """Raise each hit's score to at least its counterpart's in *scored_lookup*.

    Hits are updated in place (matched by ``memory_id``) and returned in a new
    list in their original order. A hit with no anchors also receives a copy of
    the counterpart's anchors when those are non-empty.
    """
    for hit in hits:
        reference = scored_lookup.get(hit.memory_id)
        if not reference:
            continue
        hit.score = max(float(hit.score), float(reference.score))
        if reference.anchors and not hit.anchors:
            hit.anchors = list(reference.anchors)
    return list(hits)
def _current_subject_query(query: str) -> bool:
    """Tell whether *query* asks about the present state of a recognised subject.

    True only when ``_public_query_subject`` finds a subject and the normalized
    query carries a "current state" marker: the substrings ``current``
    (which also covers ``currently``), ``active``, ``当前`` or ``现在``, or the
    standalone word ``now`` (which also covers ``right now``).
    """
    if not _public_query_subject(query):
        return False
    lowered = _normalize(query)
    if any(marker in lowered for marker in ("current", "active", "当前", "现在")):
        return True
    # "now" must not be flanked by ASCII letters: a plain substring test also
    # fired on "know", "snow", "knowledge", ... and misrouted ordinary questions
    # into the current-subject resolver.
    return re.search(r"(?<![a-z])now(?![a-z])", lowered) is not None
def _record_subject_signatures(record: SessionMemoryRecordV2) -> set[str]:
    """Collect every non-empty subject signature attached to *record*.

    Sources: the ``subject_signature`` metadata field (normalized, hyphens
    turned into underscores), the ``subject`` metadata field, and the part after
    ``.subject.`` in both the canonical slot key and the record's own slot key.
    """
    meta = dict(record.metadata or {})
    found = set()
    found.add(_normalize(meta.get("subject_signature", "")).replace("-", "_"))
    found.add(_public_subject_signature(meta.get("subject", "")))
    canonical_key = _clean_text(meta.get("canonical_slot_key", "") or record.slot_key)
    for slot_text in (canonical_key, record.slot_key):
        if ".subject." in slot_text:
            subject_part = slot_text.split(".subject.", 1)[-1]
            found.add(_public_subject_signature(subject_part))
    found.discard("")
    return found
def _current_subject_protected_hits(
    *,
    query: str,
    graph: SessionMemoryGraphV2,
    final_hits: Sequence[MemoryHit],
    top_k: int,
) -> Dict[str, Any]:
    """Put the active record(s) for the queried subject ahead of *final_hits*.

    Only applies to queries accepted by ``_current_subject_query``. Candidate
    records are public-dialog, non-question records whose subject signature
    matches the query subject; slot heads are preferred and any active record is
    the fallback. Up to two candidates (never more than ``top_k``) are promoted
    with a score of at least 1.75. Remaining final hits follow, except
    same-subject hits whose state is superseded, evidence, historical, stale,
    or false.

    Returns:
        A dict with ``hits`` and a ``metadata`` dict describing whether and why
        the resolver acted.
    """
    if not _current_subject_query(query):
        return {
            "hits": list(final_hits),
            "metadata": {"current_subject_resolver_enabled": False},
        }
    subject = _public_query_subject(query)
    subject_signature = _public_subject_signature(subject)
    if not subject_signature:
        return {
            "hits": list(final_hits),
            "metadata": {
                "current_subject_resolver_enabled": True,
                "current_subject_resolver_reason": "no_subject_signature",
            },
        }

    def _candidate_record(record: SessionMemoryRecordV2 | None) -> bool:
        return bool(
            record is not None
            and _clean_text(record.source_kind).startswith("public_dialog")
            and _normalize(record.category) != "question"
            and subject_signature in _record_subject_signatures(record)
        )

    def _unique_candidates(records: Iterable[SessionMemoryRecordV2 | None]) -> List[SessionMemoryRecordV2]:
        # Several slot heads can point at the same record; keep each record
        # object once so it cannot occupy both promoted positions.
        seen_objects = set()
        unique: List[SessionMemoryRecordV2] = []
        for record in records:
            if not _candidate_record(record) or id(record) in seen_objects:
                continue
            seen_objects.add(id(record))
            unique.append(record)
        return unique

    candidates = _unique_candidates(
        graph.records_by_id.get(memory_id) for memory_id in graph.slot_heads.values()
    )
    if not candidates:
        candidates = _unique_candidates(
            record for record in graph.records_by_id.values() if record.state == "active"
        )
    # Records explicitly marked current come first, then the most recent turn,
    # then confidence and salience.
    candidates.sort(
        key=lambda record: (
            int(
                _normalize((record.metadata or {}).get("target_status", "")) == "current"
                or _normalize(record.relation) == "current_subject_value"
            ),
            int(record.turn_index),
            float(record.confidence),
            float(record.salience),
        ),
        reverse=True,
    )

    promoted: List[MemoryHit] = []
    for index, record in enumerate(candidates[: max(1, min(2, int(top_k or 1)))], start=1):
        metadata = dict(record.metadata or {})
        metadata.update(
            {
                "current_subject_resolver": True,
                "current_subject_resolver_rank": index,
                "current_subject_query_subject": subject,
                "current_subject_query_signature": subject_signature,
                "public_subject_match": True,
                "public_subject_overlap": 1.0,
            }
        )
        promoted.append(
            MemoryHit(
                memory_id=record.memory_id,
                category=record.category,
                value=record.value,
                relation=record.relation,
                anchors=list(record.anchor_concepts),
                score=max(float(record.confidence), float(record.salience), 1.75),
                source_kind=record.source_kind,
                slot_key=record.slot_key,
                state=record.state,
                turn_index=int(record.turn_index),
                metadata=metadata,
            )
        )
    if not promoted:
        return {
            "hits": list(final_hits),
            "metadata": {
                "current_subject_resolver_enabled": True,
                "current_subject_resolver_reason": "no_active_subject_head",
                "current_subject_query_subject": subject,
                "current_subject_query_signature": subject_signature,
            },
        }

    promoted_ids = {hit.memory_id for hit in promoted}
    stale_states = {"superseded", "evidence", "historical", "stale", "false"}
    merged_tail: List[MemoryHit] = []
    for hit in final_hits:
        if hit.memory_id in promoted_ids:
            continue
        metadata = dict(hit.metadata or {})
        canonical_slot_key = _clean_text(metadata.get("canonical_slot_key", ""))
        hit_signatures = {
            _normalize(metadata.get("subject_signature", "")).replace("-", "_"),
            _public_subject_signature(metadata.get("subject", "")),
            _public_subject_signature(hit.slot_key.split(".subject.", 1)[-1]) if ".subject." in hit.slot_key else "",
            _public_subject_signature(canonical_slot_key.split(".subject.", 1)[-1])
            if ".subject." in canonical_slot_key
            else "",
        }
        # Outdated views of the same subject would contradict the promoted head.
        if subject_signature in hit_signatures and _normalize(hit.state) in stale_states:
            continue
        merged_tail.append(hit)

    merged = [*promoted, *merged_tail]
    limit = max(int(top_k or 1), len(promoted))
    return {
        "hits": merged[:limit],
        "metadata": {
            "current_subject_resolver_enabled": True,
            "current_subject_resolver_reason": "promoted_active_subject_head",
            "current_subject_query_subject": subject,
            "current_subject_query_signature": subject_signature,
            "current_subject_promoted_memory_ids": [hit.memory_id for hit in promoted],
        },
    }
def _depth_chain_protected_hits(
    *,
    query: str,
    graph: SessionMemoryGraphV2,
    final_hits: Sequence[MemoryHit],
    top_k: int,
) -> Dict[str, Any]:
    """Append depth-chain nodes from the graph after the already selected hits.

    The chain comes from ``graph.depth_chain_for_query``, seeded with the memory
    ids of *final_hits*. Chain nodes not already among the hits are converted to
    hits with a floor score of ``0.62 - rank * 0.01`` and tagged with
    ``memory_chain_*`` metadata. The result is capped at ``top_k`` plus the
    number of added chain hits, bounded by 12 but never below ``top_k``.
    """
    requested = int(top_k or 1)
    chain = graph.depth_chain_for_query(
        query,
        seed_memory_ids=[hit.memory_id for hit in final_hits if hit.memory_id],
        top_k=max(3, min(8, requested)),
    )
    chain_reason = _clean_text(chain.get("reason", ""))
    if not (chain.get("enabled") and chain.get("nodes")):
        return {
            "hits": list(final_hits),
            "metadata": {
                "memory_chain_enabled": bool(chain.get("enabled", False)),
                "memory_chain_reason": chain_reason,
                "memory_chain_node_count": 0,
                "memory_chain_edge_count": 0,
                "memory_chain": chain,
            },
        }

    chain_subject = _clean_text(chain.get("subject_signature", ""))
    known_ids = {hit.memory_id for hit in final_hits if hit.memory_id}
    chain_hits: List[MemoryHit] = []
    # Rank counts every chain node, including ones skipped below.
    for rank, node in enumerate(list(chain.get("nodes", []) or []), start=1):
        if not isinstance(node, Mapping):
            continue
        node_memory_id = _clean_text(node.get("memory_id", ""))
        if not node_memory_id or node_memory_id in known_ids:
            continue
        known_ids.add(node_memory_id)
        node_payload = dict(node)
        node_metadata = dict(node_payload.get("metadata", {}) or {})
        depth_layer = _clean_text(node_metadata.get("depth_layer", "")) or "core_view"
        node_metadata.update(
            {
                "memory_chain_protected": True,
                "memory_chain_rank": int(rank),
                "memory_chain_subject_signature": chain_subject,
                "memory_chain_depth_layer": depth_layer,
            }
        )
        node_payload["metadata"] = node_metadata
        floor_score = 0.62 - (rank * 0.01)
        node_payload["score"] = max(float(node_payload.get("score", 0.0) or 0.0), floor_score)
        chain_hits.append(_raw_hit_to_memory_hit(node_payload))

    limit = max(requested, min(12, requested + max(0, len(chain_hits))))
    combined = [*list(final_hits), *chain_hits]
    return {
        "hits": combined[:limit],
        "metadata": {
            "memory_chain_enabled": True,
            "memory_chain_reason": chain_reason,
            "memory_chain_subject_signature": chain_subject,
            "memory_chain_node_count": int(chain.get("node_count", 0) or 0),
            "memory_chain_edge_count": int(chain.get("edge_count", 0) or 0),
            "memory_chain_depth_layers": list(chain.get("depth_layers", []) or []),
            "memory_chain": chain,
        },
    }
def _is_public_dialog_hit(hit: MemoryHit) -> bool:
    """Tell whether the hit's cleaned source kind starts with ``public_dialog``."""
    source_kind = _clean_text(hit.source_kind)
    return source_kind.startswith("public_dialog")
def _normalized_runtime_signature(prefix: str, value: str) -> str:
    """Prefix the normalized *value*, replacing ``|`` and ``:`` with ``_``.

    Returns an empty string when *value* normalizes to nothing.
    """
    normalized = _normalize(value)
    if normalized:
        sanitized = normalized.replace("|", "_").replace(":", "_")
        return f"{prefix}{sanitized}"
    return ""
def _runtime_event_key(hit: MemoryHit) -> str:
    """Derive the runtime event key under which *hit* is grouped.

    Precedence: metadata ``event_id``; ``event::<dia_id>``; for hits that are
    not public-dialog, a key built from ``state_signature`` and then
    ``memory_signature``; the public slot root of the slot key; finally the
    cleaned memory id.
    """
    meta = dict(hit.metadata or {})
    event_id = _clean_text(meta.get("event_id", ""))
    if event_id:
        return event_id
    dialog_id = _clean_text(meta.get("dia_id", ""))
    if dialog_id:
        return f"event::{dialog_id}"
    if not _is_public_dialog_hit(hit):
        signature_sources = (
            ("state_signature", "event::state::"),
            ("memory_signature", "event::memory::"),
        )
        for field_name, prefix in signature_sources:
            signature = _clean_text(meta.get(field_name, ""))
            if signature:
                return _normalized_runtime_signature(prefix, signature)
    return _public_slot_root(_clean_text(hit.slot_key)) or _clean_text(hit.memory_id)
def _runtime_event_turn_index_from_id(event_id: str) -> int:
    """Extract a turn index from an event id.

    Prefers digits that end the id directly after a ``:`` or ``_``; otherwise
    uses the last run of digits anywhere in the id. Returns 0 when the id is
    empty or has no digits.
    """
    text = _clean_text(event_id)
    if not text:
        return 0
    trailing = re.search(r"(?::|_)(\d+)$", text)
    if trailing is not None:
        return int(trailing.group(1))
    digit_runs = re.findall(r"\d+", text)
    if not digit_runs:
        return 0
    return int(digit_runs[-1])
def _representative_event_hit(group_hits: Sequence[MemoryHit], *, query: str = "") -> MemoryHit | None:
    """Pick the hit that best represents an event group, or ``None`` if empty.

    Hits are compared on, in order: being a semantic hit that contains a number
    and shares at least two tokens with the query; query token overlap (plus
    0.75 when a number is also present); being semantic; being a
    ``public_dialog_event``; being a ``public_dialog_turn``; coming from
    replacement/session memory; raw score. On a full tie the earliest hit wins.
    """
    semantic_kinds = frozenset(
        {
            "public_dialog_fact",
            "public_dialog_preference",
            "public_dialog_goal",
            "public_dialog_constraint",
            "public_dialog_status",
            "public_dialog_profile",
            "replacement_memory",
            "session_memory",
        }
    )
    memory_store_kinds = ("replacement_memory", "session_memory")
    query_tokens = set(_path_utility_tokens(query))

    def _preference(hit: MemoryHit) -> tuple[bool, float, bool, bool, bool, bool, float]:
        meta = dict(hit.metadata or {})
        kind = _clean_text(hit.source_kind)
        is_turn = kind == "public_dialog_turn"
        fragments: List[Any] = [hit.value, hit.slot_key, hit.category]
        if not is_turn:
            # Raw turn text is only consulted for hits that are not raw turns.
            fragments += [meta.get("source_turn_text", ""), meta.get("raw_text", "")]
        cleaned = [_clean_text(fragment) for fragment in fragments]
        text = " ".join(piece for piece in cleaned if piece)
        hit_tokens = set(_path_utility_tokens(text))
        overlap = len(query_tokens & hit_tokens) if query_tokens else 0
        has_number = re.search(r"\b\d+\b", text) is not None
        is_semantic = kind in semantic_kinds or (
            not is_turn and bool(_clean_text(meta.get("memory_writer_role", "")))
        )
        bonus = 0.75 if has_number and overlap else 0.0
        return (
            bool(is_semantic and has_number and overlap >= 2),
            float(overlap) + bonus,
            is_semantic,
            kind == "public_dialog_event",
            is_turn,
            kind in memory_store_kinds,
            float(hit.score),
        )

    # max() keeps the first of equally ranked hits, like a stable descending sort.
    return max(list(group_hits), key=_preference, default=None)
def _event_record_hits_from_graph(graph: SessionMemoryGraphV2, event_id: str) -> List[MemoryHit]:
    """Build hits for every graph record that belongs to *event_id*.

    Only records whose metadata ``event_id`` equals the cleaned id and whose
    state is active, parallel_active, or evidence are included. Each hit is
    scored with the larger of the record's confidence and salience, with a
    floor of 0.01.
    """
    wanted_event = _clean_text(event_id)
    if not wanted_event:
        return []
    visible_states = {"active", "parallel_active", "evidence"}
    collected: List[MemoryHit] = []
    for record in graph.records_by_id.values():
        meta = dict(record.metadata or {})
        if _clean_text(meta.get("event_id", "")) != wanted_event:
            continue
        if _normalize(record.state) not in visible_states:
            continue
        hit_score = max(float(record.confidence), float(record.salience), 0.01)
        collected.append(
            MemoryHit(
                memory_id=record.memory_id,
                category=record.category,
                value=record.value,
                relation=record.relation,
                anchors=list(record.anchor_concepts),
                score=hit_score,
                source_kind=record.source_kind,
                slot_key=record.slot_key,
                state=record.state,
                turn_index=int(record.turn_index),
                metadata=meta,
            )
        )
    return collected
def _hit_matches_path_support(path_type: str, hit: MemoryHit) -> bool:
    """Tell whether *hit* can serve as the support node for *path_type*.

    Recognised path types are ``speaker_event_time``, ``speaker_event_profile``,
    ``speaker_event_status`` and ``speaker_event_source_turn``; any other path
    type never matches.
    """
    meta = dict(hit.metadata or {})
    source_kind = _clean_text(hit.source_kind)
    category = _clean_text(hit.category)
    relation = _clean_text(hit.relation)

    def _has_any(*field_names: str) -> bool:
        # True when at least one metadata field is non-empty after cleaning.
        return any(_clean_text(meta.get(name, "")) for name in field_names)

    if path_type == "speaker_event_time":
        if source_kind == "public_dialog_time":
            return True
        if _has_any(
            "resolved_time_value",
            "resolved_date",
            "time_value",
            "time_display_value",
            "time_granularity",
        ):
            return True
        return relation == "event_date" or category in {"time", "event_time"}

    if path_type == "speaker_event_profile":
        semantic_slot = _clean_text(meta.get("semantic_slot", "")) or _clean_text(meta.get("profile_type", ""))
        if source_kind == "public_dialog_profile":
            return True
        if semantic_slot in {"identity", "research_topic", "education", "occupation", "profile"}:
            return True
        return _has_any("profile_value") or category == "profile"

    if path_type == "speaker_event_status":
        return (
            _has_any("target_status")
            or category in {"status", "stage_state"}
            or relation == "status_of"
        )

    if path_type == "speaker_event_source_turn":
        if source_kind in {"public_dialog_turn", "public_dialog_text", "public_dialog_auxiliary_evidence"}:
            return True
        if _has_any("raw_text", "origin_query", "source_turn_text"):
            return True
        # Hits that are not public-dialog always count as source-turn support.
        return not _is_public_dialog_hit(hit)

    return False
def _support_hit_for_path(path_type: str, group_hits: Sequence[MemoryHit]) -> MemoryHit | None:
    """Choose the support hit for *path_type* within an event group.

    Among hits accepted by ``_hit_matches_path_support`` the one with the
    highest (score, turn index) wins. When none match, the group's
    representative hit is returned instead (``None`` for an empty group).
    """
    matching = [hit for hit in group_hits if _hit_matches_path_support(path_type, hit)]
    if not matching:
        return _representative_event_hit(group_hits)
    # max() returns the first of equal keys, like a stable descending sort.
    return max(matching, key=lambda item: (float(item.score), int(item.turn_index)))
def _path_support_node_id(path: Dict[str, Any]) -> str:
    """Return the cleaned third entry of the path's ``node_ids``, or ``""`` if absent."""
    node_ids = list(path.get("node_ids", []) or [])
    return _clean_text(node_ids[2]) if len(node_ids) >= 3 else ""
def _event_ids_from_hits(hits: Sequence[MemoryHit]) -> List[str]:
    """Collect the non-empty metadata ``event_id`` values of *hits*, deduplicated via ``_dedupe``."""
    cleaned_ids = (
        _clean_text(dict(hit.metadata or {}).get("event_id", ""))
        for hit in hits
    )
    return _dedupe(event_id for event_id in cleaned_ids if event_id)
def _dia_ids_from_hits(hits: Sequence[MemoryHit]) -> List[str]:
    """Collect the non-empty metadata ``dia_id`` values of *hits*, deduplicated via ``_dedupe``."""
    cleaned_ids = (
        _clean_text(dict(hit.metadata or {}).get("dia_id", ""))
        for hit in hits
    )
    return _dedupe(dia_id for dia_id in cleaned_ids if dia_id)
def _final_hit_role_priority(hit: MemoryHit) -> int:
    """Rank a final hit by its role; lower numbers are preferred.

    0 for semantic memory hits (by source kind, or any hit other than a dialog
    turn that carries a ``memory_writer_role``) and for hits flagged
    ``profile_first_source_support``; 1-3 for the ``evidence_snippet_role``
    values path support, event representative, and path event; 4 otherwise.
    """
    meta = dict(hit.metadata or {})
    kind = _clean_text(hit.source_kind)
    semantic_kinds = {
        "public_dialog_fact",
        "public_dialog_preference",
        "public_dialog_goal",
        "public_dialog_constraint",
        "public_dialog_status",
        "public_dialog_profile",
        "public_dialog_profile_cluster",
        "replacement_memory",
        "session_memory",
    }
    if kind in semantic_kinds:
        return 0
    if kind != "public_dialog_turn" and bool(_clean_text(meta.get("memory_writer_role", ""))):
        return 0
    if bool(meta.get("profile_first_source_support")):
        return 0
    role_priority = {
        "selected_path_support": 1,
        "selected_event_representative": 2,
        "selected_path_event": 3,
    }
    return role_priority.get(_clean_text(meta.get("evidence_snippet_role", "")), 4)
def _coverage_preserving_final_hits(
    final_hits: Sequence[MemoryHit],
    *,
    selected_event_ids: Sequence[str],
    top_k: int,
) -> List[MemoryHit]:
    """Fill the prompt budget while keeping one hit per selected event first.

    Learned selection may emit both a path-support and an event-representative
    snippet for one event. Sorting purely by score could then push a different
    selected event past the top-k cut-off, hiding recall/rerank successes from
    the answer head. Each selected event (in order) therefore contributes its
    best-role, highest-scoring hit first; leftover budget is filled by score.
    Without selected events the hits are simply sorted by score.
    """
    budget = max(1, int(top_k or 1))
    pool = list(final_hits)
    wanted_events = [
        event_id
        for event_id in (_clean_text(raw_id) for raw_id in selected_event_ids)
        if event_id
    ][:budget]
    if not wanted_events:
        pool.sort(key=lambda item: float(item.score), reverse=True)
        return pool[:budget]

    by_event: Dict[str, List[MemoryHit]] = {}
    for hit in pool:
        hit_event = _clean_text(dict(hit.metadata or {}).get("event_id", ""))
        if hit_event:
            by_event.setdefault(hit_event, []).append(hit)

    picked: List[MemoryHit] = []
    taken_ids = set()
    for event_id in wanted_events:
        options = [hit for hit in by_event.get(event_id, []) if hit.memory_id not in taken_ids]
        if not options:
            continue
        # Best role first, then highest score, then memory id for determinism.
        best = min(
            options,
            key=lambda item: (_final_hit_role_priority(item), -float(item.score), item.memory_id),
        )
        picked.append(best)
        taken_ids.add(best.memory_id)
        if len(picked) >= budget:
            return picked

    leftovers = sorted(
        (hit for hit in pool if hit.memory_id not in taken_ids),
        key=lambda item: (-float(item.score), _final_hit_role_priority(item), item.memory_id),
    )
    picked.extend(leftovers[: budget - len(picked)])
    return picked[:budget]
def _dominant_answer_type(question_analysis: Dict[str, Any], answer_type_scores: Dict[str, Any]) -> str:
    """Pick the answer type that should drive path selection.

    Temporal questions resolve to ``time`` unless ``abstain`` beats it by more
    than 0.05. Non-temporal questions take the best-scoring type other than
    ``time``/``abstain`` when one exists. Otherwise the overall best-scoring
    type wins; with no scores at all the default is ``time`` for temporal
    questions and ``event_text`` for the rest. Ties go to the larger label.
    """
    scores: Dict[str, float] = {}
    for raw_type, raw_value in dict(answer_type_scores or {}).items():
        label = _clean_text(raw_type)
        if label:
            scores[label] = float(raw_value or 0.0)

    def _best(table: Dict[str, float]) -> str:
        return max(table.items(), key=lambda item: (float(item[1]), item[0]))[0]

    is_temporal = bool(question_analysis.get("is_temporal", False))
    if is_temporal:
        if scores.get("time", 0.0) >= (scores.get("abstain", 0.0) - 0.05):
            return "time"
    elif scores:
        content_scores = {
            label: value
            for label, value in scores.items()
            if label not in {"time", "abstain"}
        }
        if content_scores:
            return _best(content_scores)
    if not scores:
        return "time" if is_temporal else "event_text"
    return _best(scores)
def _answer_type_preferred_path_types(question_analysis: Dict[str, Any], answer_type_scores: Dict[str, Any]) -> List[str]:
    """Order the four path types by preference for the dominant answer type.

    Time-dominant or temporal questions lead with the time path, profile
    questions lead with the profile path, and every other answer type
    (including ``multi_evidence``) leads with the source turn.
    """
    dominant = _dominant_answer_type(question_analysis, answer_type_scores)
    is_temporal = bool(question_analysis.get("is_temporal", False))
    if dominant == "time" or is_temporal:
        return ["speaker_event_time", "speaker_event_source_turn", "speaker_event_status", "speaker_event_profile"]
    if dominant == "profile":
        return ["speaker_event_profile", "speaker_event_source_turn", "speaker_event_status", "speaker_event_time"]
    # "multi_evidence" shares the default source-turn-first ordering.
    return ["speaker_event_source_turn", "speaker_event_time", "speaker_event_profile", "speaker_event_status"]
def _reconciled_focused_answer_type(
    question_analysis: Dict[str, Any],
    answer_type_scores: Dict[str, Any],
    model_answer_type: str,
) -> str:
    """Reconcile the model's answer type with the question analysis.

    Temporal questions force ``time`` unless the model said ``time``,
    ``abstain`` or nothing. A model answer of ``time`` on a non-temporal
    question is replaced by the dominant type (or ``event_text`` if that is
    also ``time``). Otherwise the model's type is kept, falling back to the
    dominant type when it is empty.
    """
    model_type = _clean_text(model_answer_type)
    dominant_type = _dominant_answer_type(question_analysis, answer_type_scores)
    is_temporal = bool(question_analysis.get("is_temporal", False))
    if is_temporal:
        if model_type not in {"", "time", "abstain"}:
            return "time"
    elif model_type == "time":
        return "event_text" if dominant_type == "time" else dominant_type
    return model_type or dominant_type
def _path_type_is_focus_compatible(path_type: str, *, focused_answer_type: str, question_analysis: Dict[str, Any]) -> bool:
    """Tell whether *path_type* may be used for the focused answer type.

    Time focus (or a temporal question) allows only time and source-turn paths;
    profile focus allows only profile and source-turn paths; any other focus
    allows every path type.
    """
    candidate = _clean_text(path_type)
    focus = _clean_text(focused_answer_type)
    if focus == "time" or bool(question_analysis.get("is_temporal", False)):
        allowed = {"speaker_event_time", "speaker_event_source_turn"}
    elif focus == "profile":
        allowed = {"speaker_event_profile", "speaker_event_source_turn"}
    else:
        return True
    return candidate in allowed
def _rank_focus_compatible_path_ids(
    *,
    runtime_paths: Mapping[str, Dict[str, Any]],
    selected_event_ids: Sequence[str],
    path_scores: Mapping[str, Any],
    event_scores: Mapping[str, Any],
    temporal_scores: Mapping[str, Any],
    question_analysis: Dict[str, Any],
    answer_type_scores: Dict[str, Any],
    focused_answer_type: str,
) -> List[str]:
    """Rank runtime path ids that are compatible with the focused answer type.

    Path types are visited in preference order. Within a type, paths (limited to
    the selected events when any are given) are ordered by
    ``event_score + 0.20 * path_score + 0.15 * temporal_score(support node)``,
    with ties broken by path id. For time-focused questions the ranking stops
    after the first time or source-turn type once any path has been ranked.

    Returns:
        Path ids in rank order, without duplicates.
    """
    focus_type = _clean_text(focused_answer_type)
    time_focused = focus_type == "time" or bool(question_analysis.get("is_temporal", False))
    selected_event_id_set = {_clean_text(event_id) for event_id in selected_event_ids if _clean_text(event_id)}

    preferred_types = _answer_type_preferred_path_types(question_analysis, answer_type_scores)
    if focus_type == "profile" and "speaker_event_profile" not in preferred_types:
        preferred_types = ["speaker_event_profile", *preferred_types]
    if time_focused and "speaker_event_time" not in preferred_types:
        preferred_types = ["speaker_event_time", "speaker_event_source_turn", *preferred_types]

    ranked: List[str] = []
    seen = set()
    for preferred_type in preferred_types:
        # Compatibility depends only on the path type, so decide it once per
        # type rather than once per path.
        if _path_type_is_focus_compatible(
            preferred_type,
            focused_answer_type=focused_answer_type,
            question_analysis=question_analysis,
        ):
            candidates: List[tuple[str, float]] = []
            for path_id, path in runtime_paths.items():
                if _clean_text(path.get("type", "")) != preferred_type:
                    continue
                event_id = _clean_text(path.get("event_id", ""))
                if selected_event_id_set and event_id not in selected_event_id_set:
                    continue
                support_node_id = _path_support_node_id(path)
                score = (
                    float(event_scores.get(event_id, 0.0) or 0.0)
                    + (0.20 * float(path_scores.get(path_id, 0.0) or 0.0))
                    + (0.15 * float(temporal_scores.get(support_node_id, 0.0) or 0.0))
                )
                candidates.append((path_id, score))
            for path_id, _ in sorted(candidates, key=lambda item: (-float(item[1]), item[0])):
                if path_id in seen:
                    continue
                seen.add(path_id)
                ranked.append(path_id)
        if ranked and time_focused and preferred_type in {"speaker_event_time", "speaker_event_source_turn"}:
            break
    return ranked
def _repair_selected_paths_for_focus(
    selected_path_ids: Sequence[str],
    *,
    runtime_paths: Mapping[str, Dict[str, Any]],
    selected_event_ids: Sequence[str],
    path_scores: Mapping[str, Any],
    event_scores: Mapping[str, Any],
    temporal_scores: Mapping[str, Any],
    question_analysis: Dict[str, Any],
    answer_type_scores: Dict[str, Any],
    focused_answer_type: str,
    limit: int,
) -> tuple[List[str], bool, str]:
    """Swap focus-incompatible selected paths for compatible ranked ones.

    Returns:
        ``(path_ids, repaired, reason)``. The selection is returned unchanged
        with ``repaired=False`` when it is empty, already compatible, has no
        compatible replacement, or the repair yields the same list. Otherwise
        the compatible selected paths followed by the ranked compatible paths
        (deduplicated, capped at ``max(1, limit)``) are returned with the
        reason ``replaced_focus_incompatible_model_paths``.
    """
    unchanged_reason = ""
    chosen = [_clean_text(path_id) for path_id in selected_path_ids if _clean_text(path_id)]
    if not chosen:
        return [], False, unchanged_reason

    rejected = set()
    for path_id in chosen:
        path_type = _clean_text(runtime_paths.get(path_id, {}).get("type", ""))
        if not _path_type_is_focus_compatible(
            path_type,
            focused_answer_type=focused_answer_type,
            question_analysis=question_analysis,
        ):
            rejected.add(path_id)
    if not rejected:
        return chosen, False, unchanged_reason

    replacements = _rank_focus_compatible_path_ids(
        runtime_paths=runtime_paths,
        selected_event_ids=selected_event_ids,
        path_scores=path_scores,
        event_scores=event_scores,
        temporal_scores=temporal_scores,
        question_analysis=question_analysis,
        answer_type_scores=answer_type_scores,
        focused_answer_type=focused_answer_type,
    )
    if not replacements:
        return chosen, False, unchanged_reason

    kept = [path_id for path_id in chosen if path_id not in rejected]
    repaired = _dedupe([*kept, *replacements])[: max(1, limit)]
    if repaired == chosen:
        return chosen, False, unchanged_reason
    return repaired, True, "replaced_focus_incompatible_model_paths"
def _runtime_node_by_id(runtime_graph: Mapping[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Index copies of the runtime graph's nodes by cleaned id.

    Nodes without an id are skipped; when ids repeat, the last node wins.
    """
    index: Dict[str, Dict[str, Any]] = {}
    for node in list(runtime_graph.get("nodes", []) or []):
        node_id = _clean_text(node.get("id", ""))
        if node_id:
            index[node_id] = dict(node)
    return index
def _runtime_event_subject_signature(runtime_nodes: Mapping[str, Dict[str, Any]], event_id: str) -> str:
    """Look up an event node's subject signature.

    The node-level ``subject_signature`` is preferred; the one in the node's
    metadata is the fallback. Returns ``""`` when neither is set.
    """
    node = dict(runtime_nodes.get(_clean_text(event_id), {}) or {})
    node_metadata = dict(node.get("metadata", {}) or {})
    direct_signature = _clean_text(node.get("subject_signature", ""))
    if direct_signature:
        return direct_signature
    return _clean_text(node_metadata.get("subject_signature", ""))
def _path_utility_candidate_text(
    path: Mapping[str, Any],
    *,
    runtime_nodes: Mapping[str, Dict[str, Any]],
    grouped_hits: Mapping[str, Sequence[MemoryHit]],
) -> str:
    """Assemble the text used to judge a candidate path's utility.

    Joins, after ``_dedupe`` (at most six items): the text of the path's event
    node and support node, the values of the support hit and the representative
    hit of the event group, and the source-turn text of whichever of those two
    hits is available.
    """
    event_id = _clean_text(path.get("event_id", ""))
    support_node_id = _path_support_node_id(dict(path))
    fragments = [
        _clean_text(dict(runtime_nodes.get(node_id, {}) or {}).get("text", ""))
        for node_id in (event_id, support_node_id)
    ]
    path_type = _clean_text(path.get("type", ""))
    event_hits = grouped_hits.get(event_id, [])
    support_hit = _support_hit_for_path(path_type, event_hits)
    event_hit = _representative_event_hit(event_hits)
    fragments.append(_clean_text(support_hit.value if support_hit is not None else ""))
    fragments.append(_clean_text(event_hit.value if event_hit is not None else ""))
    fragments.append(_runtime_source_turn_text(support_hit or event_hit, speaker=""))
    return " ".join(_dedupe(fragments, max_items=6))
def _path_utility_gate(
    candidate_path_ids: Sequence[str],
    *,
    query: str,
    runtime_graph: Mapping[str, Any],
    runtime_paths: Mapping[str, Dict[str, Any]],
    grouped_hits: Mapping[str, Sequence[MemoryHit]],
    selected_path_ids: Sequence[str],
    selected_event_ids_from_model: Sequence[str],
    path_scores: Mapping[str, Any],
    path_tunnel_support_scores: Mapping[str, Any],
    question_analysis: Dict[str, Any],
    focused_answer_type: str,
    score_threshold: float,
    limit: int,
) -> Dict[str, Any]:
    """Assign a utility role to each candidate path and pick the ones to inject.

    For every known candidate path the gate combines:

    * ``path_score`` - the larger of the tunnel support score and the decision
      score (the decision score defaults to the support score);
    * ``overlap_ratio`` - shared query/candidate tokens divided by the size of
      the smaller token set (at least 1);
    * ``same_subject_chain`` - whether the path's event shares a subject
      signature with one of the anchor events (events of the already selected
      paths plus the model-selected events).

    Roles, checked in this order:

    * ``drift_noise`` (``focus_incompatible``) - path type fails
      ``_path_type_is_focus_compatible``;
    * ``direct_support`` - ``overlap_ratio >= 0.18`` and score at threshold;
    * ``contrast_support`` - same subject chain and score at threshold;
    * ``latent_context`` - score at threshold only;
    * ``drift_noise`` (``below_utility_threshold``) - everything else.

    Direct then contrast paths are handed to ``_dedupe`` with
    ``max_items=max(0, int(limit))`` to form ``injected_path_ids``; supported
    paths that did not make that cut are appended to the latent list.

    Returns:
        A report dict with the per-role id lists, the injected ids, per-path
        ``roles`` / ``reasons`` / ``scores`` / ``overlap_tokens`` and the
        anchor event ids and subject signatures.
    """
    runtime_nodes = _runtime_node_by_id(runtime_graph)
    query_tokens = set(_path_utility_tokens(query))
    # Deduplicate once; the same list drives the loop and the report.
    unique_candidate_ids = list(_dedupe(candidate_path_ids))

    anchor_candidates = [
        *[_clean_text(runtime_paths.get(path_id, {}).get("event_id", "")) for path_id in selected_path_ids],
        *[_clean_text(event_id) for event_id in selected_event_ids_from_model],
    ]
    anchor_event_ids = _dedupe([event_id for event_id in anchor_candidates if event_id])
    anchor_subject_signatures = {
        signature
        for signature in (
            _runtime_event_subject_signature(runtime_nodes, event_id)
            for event_id in anchor_event_ids
        )
        if signature
    }

    role_buckets: Dict[str, List[str]] = {
        "direct_support": [],
        "contrast_support": [],
        "latent_context": [],
        "drift_noise": [],
    }
    utility_scores: Dict[str, float] = {}
    utility_roles: Dict[str, str] = {}
    utility_reasons: Dict[str, str] = {}
    utility_overlap_tokens: Dict[str, List[str]] = {}

    for path_id in unique_candidate_ids:
        path = dict(runtime_paths.get(path_id, {}) or {})
        if not path:
            # Unknown path ids stay in the candidate list but get no role.
            continue
        path_type = _clean_text(path.get("type", ""))
        event_id = _clean_text(path.get("event_id", ""))
        candidate_text = _path_utility_candidate_text(
            path,
            runtime_nodes=runtime_nodes,
            grouped_hits=grouped_hits,
        )
        candidate_tokens = set(_path_utility_tokens(candidate_text))
        overlap_tokens = sorted(query_tokens & candidate_tokens)
        overlap_ratio = float(len(overlap_tokens)) / float(max(1, min(len(query_tokens), len(candidate_tokens))))
        support_score = float(path_tunnel_support_scores.get(path_id, 0.0) or 0.0)
        decision_score = float(path_scores.get(path_id, support_score) or 0.0)
        path_score = max(support_score, decision_score)
        subject_signature = _runtime_event_subject_signature(runtime_nodes, event_id)
        same_subject_chain = bool(subject_signature and subject_signature in anchor_subject_signatures)
        focus_compatible = _path_type_is_focus_compatible(
            path_type,
            focused_answer_type=focused_answer_type,
            question_analysis=question_analysis,
        )
        utility_score = path_score + (0.20 * overlap_ratio) + (0.04 if same_subject_chain else 0.0)
        utility_scores[path_id] = round(float(utility_score), 6)
        utility_overlap_tokens[path_id] = overlap_tokens[:12]

        meets_threshold = path_score >= score_threshold
        if not focus_compatible:
            role, reason = "drift_noise", "focus_incompatible"
        elif overlap_ratio >= 0.18 and meets_threshold:
            role, reason = "direct_support", "query_overlap_and_tunnel_score"
        elif same_subject_chain and meets_threshold:
            role, reason = "contrast_support", "same_subject_deep_chain"
        elif meets_threshold:
            role, reason = "latent_context", "tunnel_score_without_current_turn_utility"
        else:
            role, reason = "drift_noise", "below_utility_threshold"
        utility_roles[path_id] = role
        utility_reasons[path_id] = reason
        role_buckets[role].append(path_id)

    direct_path_ids = role_buckets["direct_support"]
    contrast_path_ids = role_buckets["contrast_support"]
    supported_path_ids = [*direct_path_ids, *contrast_path_ids]
    injected_path_ids = _dedupe(supported_path_ids, max_items=max(0, int(limit)))
    # Build the membership set once instead of once per element.
    injected_lookup = set(injected_path_ids)
    overflow_latent_path_ids = [path_id for path_id in supported_path_ids if path_id not in injected_lookup]
    latent_path_ids = _dedupe([*role_buckets["latent_context"], *overflow_latent_path_ids])

    return {
        "enabled": True,
        "candidate_path_ids": list(unique_candidate_ids),
        "direct_support_path_ids": list(direct_path_ids),
        "contrast_support_path_ids": list(contrast_path_ids),
        "latent_context_path_ids": list(latent_path_ids),
        "drift_noise_path_ids": list(role_buckets["drift_noise"]),
        "injected_path_ids": list(injected_path_ids),
        "roles": dict(utility_roles),
        "reasons": dict(utility_reasons),
        "scores": dict(utility_scores),
        "overlap_tokens": dict(utility_overlap_tokens),
        "anchor_event_ids": list(anchor_event_ids),
        "anchor_subject_signatures": sorted(anchor_subject_signatures),
    }
def _calibrated_path_score(
    *,
    path: Dict[str, Any],
    base_score: float,
    temporal_scores: Dict[str, Any],
    question_analysis: Dict[str, Any],
    answer_type_scores: Dict[str, Any],
) -> float:
    """Adjust a path's base score according to the kind of answer expected.

    Three regimes are distinguished by the dominant answer type:

    * time questions (dominant type ``"time"`` or ``is_temporal`` set): time
      paths are boosted by the support node's temporal score and the ``time``
      answer weight; source-turn, status and profile paths are penalised;
    * profile questions: profile paths are boosted by the ``profile`` answer
      weight, source-turn paths are untouched, all other types lose 0.08;
    * anything else: source-turn paths are boosted by the larger of the
      ``event_text`` / ``multi_evidence`` weights, profile paths lose 0.04.

    Path types not mentioned in a regime keep ``base_score`` unchanged.
    """
    path_kind = _clean_text(path.get("type", ""))
    temporal_score = float(temporal_scores.get(_path_support_node_id(path), 0.0) or 0.0)

    answer_weights: Dict[str, float] = {}
    for raw_type, raw_value in dict(answer_type_scores or {}).items():
        cleaned_type = _clean_text(raw_type)
        if cleaned_type:
            answer_weights[cleaned_type] = float(raw_value or 0.0)

    dominant = _dominant_answer_type(question_analysis, answer_weights)
    score = float(base_score)

    if dominant == "time" or bool(question_analysis.get("is_temporal", False)):
        time_weight = answer_weights.get("time", 0.0)
        if path_kind == "speaker_event_time":
            score += (0.30 * temporal_score) + (0.12 * time_weight)
        elif path_kind == "speaker_event_source_turn":
            score -= 0.08 + (0.04 * time_weight)
        elif path_kind == "speaker_event_status":
            score -= 0.12
        elif path_kind == "speaker_event_profile":
            score -= 0.22 + (0.06 * time_weight)
        return score

    if dominant == "profile":
        if path_kind == "speaker_event_profile":
            score += 0.12 * answer_weights.get("profile", 0.0)
        elif path_kind != "speaker_event_source_turn":
            score -= 0.08
        return score

    if path_kind == "speaker_event_source_turn":
        score += 0.08 * max(
            answer_weights.get("event_text", 0.0),
            answer_weights.get("multi_evidence", 0.0),
        )
    elif path_kind == "speaker_event_profile":
        score -= 0.04
    return score
def _group_metadata_value(
    group_hits: Sequence[MemoryHit],
    key: str,
    *,
    source_kinds: Sequence[str] = (),
) -> str:
    """Return the first non-empty metadata value for ``key`` within a hit group.

    Hits are examined from highest to lowest ``(score, turn_index)``. When
    ``source_kinds`` contains at least one non-empty entry, only hits whose
    cleaned ``source_kind`` is listed are considered. Returns ``""`` when no
    hit carries a value.
    """
    allowed_kinds = set()
    for kind in list(source_kinds):
        cleaned_kind = _clean_text(kind)
        if cleaned_kind:
            allowed_kinds.add(cleaned_kind)

    ranked_hits = sorted(
        (hit for hit in group_hits if not allowed_kinds or _clean_text(hit.source_kind) in allowed_kinds),
        key=lambda item: (float(item.score), int(item.turn_index)),
        reverse=True,
    )
    for hit in ranked_hits:
        found = _clean_text(dict(hit.metadata or {}).get(key, ""))
        if found:
            return found
    return ""
def _runtime_event_sequence_key(session_name: str, turn_index: int, event_id: str) -> tuple[Any, ...]:
    """Build the sort key that orders runtime events.

    Sessions whose cleaned name ends in digits sort first (leading ``0``), by
    that trailing number; all other sessions follow (leading ``1``), by name.
    Turn index and cleaned event id break ties. The leading flag keeps the
    numeric and textual second elements from ever being compared.
    """
    session_label = _clean_text(session_name)
    trailing_number = re.search(r"(\d+)$", session_label)
    if trailing_number is None:
        return (1, session_label, int(turn_index), _clean_text(event_id))
    return (0, int(trailing_number.group(1)), int(turn_index), _clean_text(event_id))
def _runtime_source_turn_text(hit: MemoryHit | None, *, speaker: str) -> str:
    """Pick the best available source-turn text for a hit.

    Sources are tried in this order, the first non-empty one winning:

    1. ``metadata["source_turn_text"]``;
    2. ``metadata["raw_text"]``, with ``auxiliary_evidence_text`` appended on
       a new line unless it already occurs (case-insensitively) in the raw
       text;
    3. ``metadata["origin_query"]``;
    4. the hit value. For ``public_dialog_turn`` hits a leading ``[...]``
       tag, the given speaker's ``Name:`` prefix and any remaining short
       ``Label:`` prefix are stripped.

    Returns ``""`` when ``hit`` is ``None``.
    """
    if hit is None:
        return ""
    metadata = dict(hit.metadata or {})

    verbatim_turn = _clean_text(metadata.get("source_turn_text", ""))
    if verbatim_turn:
        return verbatim_turn

    raw_text = _clean_text(metadata.get("raw_text", ""))
    if raw_text:
        auxiliary_text = _clean_text(metadata.get("auxiliary_evidence_text", ""))
        if not auxiliary_text or auxiliary_text.lower() in raw_text.lower():
            return raw_text
        return f"{raw_text}\nAuxiliary evidence: {auxiliary_text}"

    origin_query = _clean_text(metadata.get("origin_query", ""))
    if origin_query:
        return origin_query

    text = _clean_text(hit.value)
    if _clean_text(hit.source_kind) != "public_dialog_turn":
        return text
    # Drop a leading "[...]" tag, then the known speaker label, then any
    # other short "Label:" prefix.
    text = re.sub(r"^\[[^\]]+\]\s*", "", text)
    if _clean_text(speaker):
        text = re.sub(rf"^{re.escape(_clean_text(speaker))}\s*:\s*", "", text, flags=re.IGNORECASE)
    return re.sub(r"^[A-Za-z][A-Za-z0-9_' -]{0,40}:\s*", "", text)
def _runtime_event_signature(
    *,
    group_hits: Sequence[MemoryHit],
    representative: MemoryHit,
    speaker: str,
    semantic_slot: str,
    source_turn_hit: MemoryHit | None,
) -> str:
    """Resolve the event signature for a group of hits.

    A signature already stored in the group's (or the representative's)
    metadata is returned as is. Otherwise one is computed with
    ``compute_public_event_signature`` from the first available text among
    the event phrase, the representative value and the source turn text,
    falling back to that cleaned text if the computation yields nothing.
    Returns ``""`` when no text is available at all.
    """

    def group_or_representative(key: str) -> str:
        return _group_metadata_value(group_hits, key) or _clean_text(dict(representative.metadata or {}).get(key, ""))

    stored_signature = group_or_representative("event_signature")
    if stored_signature:
        return stored_signature

    event_phrase = group_or_representative("event_phrase")
    source_turn_text = _runtime_source_turn_text(source_turn_hit, speaker=speaker)
    base_text = event_phrase or _clean_text(representative.value) or source_turn_text
    if not base_text:
        return ""

    computed_signature = compute_public_event_signature(
        base_text,
        speaker=_clean_text(speaker),
        semantic_slot=_clean_text(semantic_slot),
    )
    return computed_signature or _clean_text(base_text)
def _build_runtime_graph_from_hits(query: str, hits: Sequence[MemoryHit]) -> Dict[str, Any]:
    """Build an in-memory runtime graph from retrieved memory hits.

    Hits are grouped by ``_runtime_event_key``. Each group becomes one event
    node plus a speaker node and, when the data is present, time, profile,
    source-turn and status nodes, each linked to the event by a typed edge.
    Path templates for the event come from ``build_default_path_templates``
    and are tagged with the group's typed-tunnel metadata. Events are then
    ordered with ``_runtime_event_sequence_key`` to add ``same_session_next``
    edges between consecutive events and up to 64 ``typed_tunnel_candidate``
    edges between events at most 11 positions apart.

    Args:
        query: The query text; stored verbatim in the result.
        hits: Memory hits to turn into graph nodes.

    Returns:
        A dict with ``conversation_id`` (always ``"runtime"``), ``query``,
        ``nodes``, ``edges``, ``typed_tunnel_edges``, ``paths`` and
        ``grouped_hits`` (event id -> list of hits).
    """
    nodes: List[Dict[str, Any]] = []
    edges: List[Dict[str, Any]] = []
    paths: List[Dict[str, Any]] = []
    node_ids = set()
    grouped_hits: Dict[str, List[MemoryHit]] = {}
    ordered_events: List[tuple[Any, ...]] = []
    event_typed_metadata_by_id: Dict[str, Dict[str, Any]] = {}
    typed_tunnel_edges: List[Dict[str, Any]] = []
    # Appends a node once; nodes with an empty or already-seen id are ignored.
    def add_node(node: Dict[str, Any]) -> None:
        node_id = _clean_text(node.get("id", ""))
        if not node_id or node_id in node_ids:
            return
        node_ids.add(node_id)
        nodes.append(node)
    # Group hits by event key, preserving first-seen order of the keys.
    for hit in hits:
        event_id = _runtime_event_key(hit)
        grouped_hits.setdefault(event_id, []).append(hit)
    for event_id, group_hits in grouped_hits.items():
        representative = _representative_event_hit(group_hits)
        if representative is None:
            continue
        metadata = dict(representative.metadata or {})
        event_time_hit = _support_hit_for_path("speaker_event_time", group_hits)
        event_profile_hit = _support_hit_for_path("speaker_event_profile", group_hits)
        source_turn_hit = _support_hit_for_path("speaker_event_source_turn", group_hits)
        # Time fields: values from public_dialog_time hits are preferred,
        # then any hit in the group.
        event_time_display_value = (
            _group_metadata_value(group_hits, "time_display_value", source_kinds=("public_dialog_time",))
            or _group_metadata_value(group_hits, "resolved_date")
            or _group_metadata_value(group_hits, "time_display_value")
        )
        event_time_value = (
            _group_metadata_value(group_hits, "resolved_time_value", source_kinds=("public_dialog_time",))
            or _group_metadata_value(group_hits, "resolved_time_value")
            or _group_metadata_value(group_hits, "time_value")
        )
        event_time_granularity = (
            _group_metadata_value(group_hits, "time_granularity", source_kinds=("public_dialog_time",))
            or _group_metadata_value(group_hits, "time_granularity")
        )
        # Profile type: public_dialog_profile hits first, then any hit.
        event_profile_type = (
            _group_metadata_value(group_hits, "profile_type", source_kinds=("public_dialog_profile",))
            or _group_metadata_value(group_hits, "semantic_slot", source_kinds=("public_dialog_profile",))
            or _group_metadata_value(group_hits, "profile_type")
            or _group_metadata_value(group_hits, "semantic_slot")
        )
        # NOTE(review): same lookup as event_profile_hit above; one of the
        # two calls looks redundant.
        profile_hit = _support_hit_for_path("speaker_event_profile", group_hits)
        event_profile_value = (
            _group_metadata_value(group_hits, "profile_value")
            or _clean_text(profile_hit.value if profile_hit is not None else "")
            or (_clean_text(representative.value) if event_profile_type else "")
        )
        event_target_status = _group_metadata_value(group_hits, "target_status")
        # Depth layer / subject signature: the representative's metadata wins
        # over the group, and the plain key wins over the memory_chain_* key.
        depth_layer = (
            _clean_text(metadata.get("depth_layer", ""))
            or _group_metadata_value(group_hits, "depth_layer")
            or _clean_text(metadata.get("memory_chain_depth_layer", ""))
            or _group_metadata_value(group_hits, "memory_chain_depth_layer")
        )
        subject_signature = (
            _clean_text(metadata.get("subject_signature", ""))
            or _group_metadata_value(group_hits, "subject_signature")
            or _clean_text(metadata.get("memory_chain_subject_signature", ""))
            or _group_metadata_value(group_hits, "memory_chain_subject_signature")
        )
        session_name = (
            _group_metadata_value(group_hits, "session_name")
            or _group_metadata_value(group_hits, "session_key")
            or _group_metadata_value(group_hits, "scope_id")
            or "runtime_session"
        )
        # Speaker falls back to the subject signature, then the first anchor,
        # then the literal "speaker".
        speaker = (
            _clean_text(metadata.get("speaker", ""))
            or _group_metadata_value(group_hits, "speaker")
            or _clean_text(metadata.get("subject_signature", ""))
            or (_clean_text(representative.anchors[0]) if representative.anchors else "")
            or "speaker"
        )
        event_turn_index = int(getattr(representative, "turn_index", 0) or 0)
        semantic_slot = (
            _group_metadata_value(group_hits, "semantic_slot")
            or _clean_text(metadata.get("semantic_slot", ""))
            or _clean_text(metadata.get("profile_type", ""))
            or ("profile" if event_profile_type else _clean_text(representative.category))
            or "event"
        )
        teacher_fields = {
            "event_phrase": _clean_text(metadata.get("event_phrase", "")) or _clean_text(representative.value),
            "semantic_slot": semantic_slot,
            "target_status": event_target_status,
            "time_expression_span": event_time_display_value,
            "time_granularity": event_time_granularity,
            "profile_type": event_profile_type,
        }
        # Fallback signature, used only if the one computed from the runtime
        # event text below comes back empty.
        base_event_signature = _runtime_event_signature(
            group_hits=group_hits,
            representative=representative,
            speaker=speaker,
            semantic_slot=event_profile_type or teacher_fields["semantic_slot"],
            source_turn_hit=source_turn_hit,
        )
        typed_tunnel_metadata = merge_typed_metadata([metadata, *[dict(hit.metadata or {}) for hit in group_hits]])
        event_typed_metadata_by_id[event_id] = typed_tunnel_metadata
        typed_signature = typed_tunnel_signature_text(typed_tunnel_metadata)
        event_text = _clean_text(metadata.get("event_phrase", "")) or representative.value
        # Prefix the event text with "subject ..." / "depth layer ..." when
        # those fields exist (underscores shown as spaces).
        depth_prefix = " ".join(
            item
            for item in (
                f"subject {subject_signature.replace('_', ' ')}" if subject_signature else "",
                f"depth layer {depth_layer.replace('_', ' ')}" if depth_layer else "",
            )
            if item
        )
        runtime_event_text = f"{depth_prefix} {event_text}".strip() if depth_prefix else event_text
        event_signature = (
            compute_public_event_signature(
                runtime_event_text,
                speaker=_clean_text(speaker),
                semantic_slot=_clean_text(event_profile_type or teacher_fields["semantic_slot"]),
            )
            or base_event_signature
        )
        # Append the typed-tunnel signature text unless it is already part of
        # the event signature.
        if typed_signature and typed_signature not in event_signature:
            event_signature = f"{event_signature} {typed_signature}".strip()
        # The speaker node id is scoped by event id, so each event gets its
        # own speaker node.
        speaker_node_id = f"{event_id}:speaker:{_normalize(speaker).replace(' ', '_') or 'speaker'}"
        add_node({"id": speaker_node_id, "type": "speaker", "text": speaker, "metadata": {"speaker": speaker}})
        # Event node: most fields are stored both at the top level and again
        # under "metadata".
        add_node(
            {
                "id": event_id,
                "type": "event",
                "text": runtime_event_text,
                "speaker": speaker,
                "turn_index": event_turn_index,
                "session_name": session_name,
                "dia_id": _clean_text(metadata.get("dia_id", "")),
                "event_signature": event_signature,
                "slot_key": _clean_text(representative.slot_key),
                "state_signature": _clean_text(metadata.get("state_signature", "")),
                "memory_signature": _clean_text(metadata.get("memory_signature", "")),
                "target_status": event_target_status,
                "time_granularity": event_time_granularity,
                "time_value": event_time_value,
                "time_display_value": event_time_display_value,
                "profile_type": event_profile_type,
                "profile_value": event_profile_value,
                "depth_layer": depth_layer,
                "subject_signature": subject_signature,
                "tmcra_node_tags": list(typed_tunnel_metadata.get("tmcra_node_tags", []) or []),
                "tmcra_path_tags": list(typed_tunnel_metadata.get("tmcra_path_tags", []) or []),
                "tmcra_tunnel_roles": list(typed_tunnel_metadata.get("tmcra_tunnel_roles", []) or []),
                "tmcra_tunnel_group_key": _clean_text(typed_tunnel_metadata.get("tmcra_tunnel_group_key", "")),
                "teacher_fields": teacher_fields,
                "metadata": {
                    "speaker": speaker,
                    "session_name": session_name,
                    "dia_id": _clean_text(metadata.get("dia_id", "")),
                    "slot_key": _clean_text(representative.slot_key),
                    "state_signature": _clean_text(metadata.get("state_signature", "")),
                    "memory_signature": _clean_text(metadata.get("memory_signature", "")),
                    "target_status": event_target_status,
                    "time_granularity": event_time_granularity,
                    "time_value": event_time_value,
                    "time_display_value": event_time_display_value,
                    "profile_type": event_profile_type,
                    "profile_value": event_profile_value,
                    "event_signature": event_signature,
                    "depth_layer": depth_layer,
                    "subject_signature": subject_signature,
                    **typed_tunnel_metadata,
                },
            }
        )
        edges.append({"id": f"{speaker_node_id}->{event_id}:speaker_of", "source": speaker_node_id, "target": event_id, "type": "speaker_of"})
        time_node_ids: List[str] = []
        profile_node_ids: List[str] = []
        status_node_ids: List[str] = []
        source_turn_node_ids: List[str] = []
        if event_time_display_value or event_time_value:
            time_node_id = f"{event_id}:time"
            time_hit_metadata = dict((event_time_hit or representative).metadata or {})
            add_node(
                {
                    "id": time_node_id,
                    "type": "time",
                    "text": event_time_display_value or event_time_value,
                    # NOTE(review): unlike the profile and source-turn nodes,
                    # this does not fall back to the representative hit, so it
                    # is 0 when no time hit exists - confirm this is intended.
                    "turn_index": int(getattr(event_time_hit, "turn_index", 0) or 0),
                    "time_display_value": event_time_display_value,
                    "time_value": event_time_value,
                    "time_granularity": event_time_granularity,
                    "metadata": {
                        "time_display_value": event_time_display_value,
                        "time_value": event_time_value,
                        "time_granularity": event_time_granularity,
                        "resolved_date": _clean_text(time_hit_metadata.get("resolved_date", "")),
                    },
                }
            )
            edges.append({"id": f"{event_id}->{time_node_id}:time_of", "source": event_id, "target": time_node_id, "type": "time_of"})
            time_node_ids.append(time_node_id)
        if event_profile_value:
            profile_node_id = f"{event_id}:profile:{_normalize(event_profile_type).replace(' ', '_') or 'profile'}"
            add_node(
                {
                    "id": profile_node_id,
                    "type": "profile",
                    "text": event_profile_value,
                    "turn_index": int(getattr(event_profile_hit or representative, "turn_index", 0) or 0),
                    "profile_type": event_profile_type,
                    "profile_value": event_profile_value,
                    "metadata": {
                        "profile_type": event_profile_type,
                        "profile_value": event_profile_value,
                    },
                }
            )
            edges.append({"id": f"{event_id}->{profile_node_id}:profile_of", "source": event_id, "target": profile_node_id, "type": "profile_of"})
            profile_node_ids.append(profile_node_id)
        source_turn_text = _runtime_source_turn_text(source_turn_hit or representative, speaker=speaker)
        if source_turn_text:
            source_turn_node_id = f"{event_id}:source_turn"
            add_node(
                {
                    "id": source_turn_node_id,
                    "type": "source_turn",
                    "text": source_turn_text,
                    "turn_index": int(getattr(source_turn_hit or representative, "turn_index", 0) or 0),
                    "metadata": {
                        "speaker": speaker,
                        "dia_id": _clean_text(dict((source_turn_hit or representative).metadata or {}).get("dia_id", "")),
                    },
                }
            )
            edges.append({"id": f"{event_id}->{source_turn_node_id}:supported_by_turn", "source": event_id, "target": source_turn_node_id, "type": "supported_by_turn"})
            source_turn_node_ids.append(source_turn_node_id)
        if event_target_status:
            status_node_id = f"{event_id}:status"
            add_node({"id": status_node_id, "type": "status", "text": event_target_status, "metadata": {"target_status": event_target_status}})
            edges.append({"id": f"{event_id}->{status_node_id}:status_of", "source": event_id, "target": status_node_id, "type": "status_of"})
            status_node_ids.append(status_node_id)
        event_paths = build_default_path_templates(
            event_id=event_id,
            speaker_node_id=speaker_node_id,
            time_node_ids=time_node_ids,
            profile_node_ids=profile_node_ids,
            status_node_ids=status_node_ids,
            source_turn_node_ids=source_turn_node_ids,
        )
        # Copy the group's typed-tunnel tags onto every path, both in the
        # path's metadata and at the top level.
        for path in event_paths:
            path_metadata = dict(path.get("metadata", {}) or {})
            path_metadata["tmcra_path_tags"] = list(typed_tunnel_metadata.get("tmcra_path_tags", []) or [])
            path_metadata["tmcra_tunnel_group_key"] = _clean_text(typed_tunnel_metadata.get("tmcra_tunnel_group_key", ""))
            path["metadata"] = path_metadata
            path["tmcra_path_tags"] = path_metadata["tmcra_path_tags"]
        paths.extend(event_paths)
        ordered_events.append(_runtime_event_sequence_key(session_name, event_turn_index, event_id))
    ordered_event_ids = [event_id for _, _, _, event_id in sorted(ordered_events)]
    # NOTE(review): this links every consecutive pair in the global ordering,
    # including pairs that belong to different sessions, despite the
    # "same_session_next" edge type - confirm this is intended.
    for previous_event_id, next_event_id in zip(ordered_event_ids, ordered_event_ids[1:]):
        typed_edge_tags = typed_edge_tags_between(
            event_typed_metadata_by_id.get(previous_event_id, {}),
            event_typed_metadata_by_id.get(next_event_id, {}),
        )
        edges.append(
            {
                "id": f"{previous_event_id}->{next_event_id}:same_session_next",
                "source": previous_event_id,
                "target": next_event_id,
                "type": "same_session_next",
                "metadata": {
                    "tmcra_edge_tags": typed_edge_tags,
                    "typed_tunnel_edge": bool(typed_edge_tags),
                },
            }
        )
    # Typed tunnel candidates: each event is paired with up to the next 11
    # events in the ordering; collection stops once 64 edges exist.
    for index, source_event_id in enumerate(ordered_event_ids):
        for target_event_id in ordered_event_ids[index + 1 : index + 12]:
            typed_edge_tags = typed_edge_tags_between(
                event_typed_metadata_by_id.get(source_event_id, {}),
                event_typed_metadata_by_id.get(target_event_id, {}),
            )
            if not typed_edge_tags:
                continue
            typed_tunnel_edges.append(
                {
                    "id": f"{source_event_id}->{target_event_id}:typed_tunnel",
                    "source": source_event_id,
                    "target": target_event_id,
                    "type": "typed_tunnel_candidate",
                    "metadata": {
                        "tmcra_edge_tags": typed_edge_tags,
                        "typed_tunnel_edge": True,
                    },
                }
            )
            if len(typed_tunnel_edges) >= 64:
                break
        if len(typed_tunnel_edges) >= 64:
            break
    return {
        "conversation_id": "runtime",
        "query": query,
        "nodes": nodes,
        "edges": edges,
        "typed_tunnel_edges": typed_tunnel_edges,
        "paths": paths,
        "grouped_hits": grouped_hits,
    }
def _public_graph_hits(graph: SessionMemoryGraphV2) -> List[MemoryHit]:
    """Convert the graph's active public-dialog records into memory hits.

    Only records whose state is exactly ``"active"`` and whose cleaned source
    kind starts with ``"public_dialog"`` are kept. Each hit is scored with
    the largest of the record's confidence, its salience and 0.01. The result
    is ordered by descending ``(turn_index, score)``.
    """

    def to_hit(record: Any) -> MemoryHit:
        return MemoryHit(
            memory_id=record.memory_id,
            category=record.category,
            value=record.value,
            relation=record.relation,
            anchors=list(record.anchor_concepts),
            score=max(float(record.confidence), float(record.salience), 0.01),
            source_kind=record.source_kind,
            slot_key=record.slot_key,
            state=record.state,
            turn_index=int(record.turn_index),
            metadata=dict(record.metadata or {}),
        )

    collected = [
        to_hit(record)
        for record in graph.records_by_id.values()
        if record.state == "active" and _clean_text(record.source_kind).startswith("public_dialog")
    ]
    return sorted(collected, key=lambda item: (int(item.turn_index), float(item.score)), reverse=True)
# Detects queries that point back at earlier conversation. The first
# alternative matches whole English cue words/phrases case-insensitively
# ("remember", "earlier", "last time", "turn 12", ...). The second matches
# Chinese cues with no word-boundary requirement: "round/time number N",
# "N rounds", "before", "previously", "just now", "original words",
# "that time", "old topic", "go back to", "bring up", "remember".
_AUDIT_ANCHOR_QUERY_RE = re.compile(
    r"(?i)\b(?:remember|recall|earlier|previous|previously|old|before|mentioned|said|quote|verbatim|original|"
    r"that time|last time|bring back|return to|go back to|turn\s*\d+)\b|"
    r"(?:\u7b2c\s*\d+\s*[\u8f6e\u6b21]|\d+\s*\u8f6e|\u4e4b\u524d|\u4ee5\u524d|\u521a\u521a|"
    r"\u539f\u8bdd|\u90a3\u6b21|\u65e7\u8bdd\u9898|\u56de\u5230|\u63d0\u8d77|\u8bb0\u5f97)"
)
_AUDIT_TURN_ANCHOR_RE = re.compile(
r"(?i)\bturn\s*#?\s*(\d{1,6})\b|"
r"(?:\u7b2c\s*(\d{1,6})\s*[\u8f6e\u6b21]|\b(\d{1,6})\s*\u8f6e\b)"
)
# Tokens ignored when extracting content tokens for audit-anchor matching:
# the shared symbolic stopwords plus recall/question filler words.
_AUDIT_ANCHOR_STOPWORDS = set(_HYBRID_SYMBOLIC_STOPWORDS) | {
    "about",
    "again",
    "back",
    "bring",
    "can",
    "could",
    "did",
    "discuss",
    "discussed",
    "earlier",
    "from",
    "just",
    "keep",
    "mentioned",
    "old",
    "one",
    "previous",
    "previously",
    "really",
    "recall",
    "remember",
    "return",
    "said",
    "still",
    "that",
    "the",
    "thing",
    "think",
    "this",
    "turn",
    "what",
    "when",
    "where",
    "with",
    "you",
}
# Tokens that still count as overlap but carry a reduced weight and do not
# qualify as a distinctive match in _audit_anchor_protected_hits.
# NOTE(review): "really" is also a stopword above, so it can never reach this
# set through _audit_anchor_content_tokens.
_AUDIT_ANCHOR_GENERIC_MATCH_TOKENS = {
    "body",
    "coherence",
    "continuity",
    "fiction",
    "fray",
    "memories",
    "memory",
    "narrative",
    "ourselves",
    "physical",
    "really",
    "stories",
    "story",
    "tell",
    "trust",
}
# Token pairs that earn a phrase bonus when every token of a pair occurs in
# both the query and the hit (see _audit_anchor_phrase_bonus).
_AUDIT_ANCHOR_PHRASE_TOKEN_SETS = [
    {"spine", "book"},
    {"body", "spine"},
    {"stories", "pages"},
    {"story", "pages"},
    {"islands", "self"},
    {"archipelago", "self"},
    {"loom", "body"},
    {"braid", "body"},
]
def _audit_anchor_query(query: str) -> bool:
    """Tell whether the cleaned query contains a look-back cue.

    A cue is anything matched by ``_AUDIT_ANCHOR_QUERY_RE``.
    """
    cue_match = _AUDIT_ANCHOR_QUERY_RE.search(_clean_text(query))
    return cue_match is not None
def _audit_anchor_turn_numbers(query: str) -> List[int]:
    """Return the distinct positive turn numbers mentioned in ``query``.

    Numbers are taken from every capture group of every
    ``_AUDIT_TURN_ANCHOR_RE`` match and kept in order of first appearance.
    Zero is dropped.
    """
    numbers: List[int] = []
    for match in _AUDIT_TURN_ANCHOR_RE.finditer(_clean_text(query)):
        for captured in match.groups():
            if not captured:
                # Only one alternative of the pattern captures per match.
                continue
            try:
                value = int(captured)
            except ValueError:
                # Skip text that is not a valid integer; any other exception
                # is a real bug and should surface.
                continue
            if value > 0 and value not in numbers:
                numbers.append(value)
    return numbers
def _hit_event_id(hit: MemoryHit) -> str:
    """Derive an event id for a hit.

    Order of preference: the metadata ``event_id``; ``event::<dia_id>`` built
    from the metadata ``dia_id``; ``event::realchat:<turn>`` when the hit has
    a positive turn index. Returns ``""`` when none of these is available.
    """
    metadata = dict(hit.metadata or {})
    explicit_id = _clean_text(metadata.get("event_id", ""))
    if explicit_id:
        return explicit_id
    dialog_id = _clean_text(metadata.get("dia_id", ""))
    if dialog_id:
        return f"event::{dialog_id}"
    if int(hit.turn_index or 0) <= 0:
        return ""
    return f"event::realchat:{int(hit.turn_index)}"
def _audit_anchor_hit_text(hit: MemoryHit) -> str:
    """Concatenate the hit's searchable text fields into one string.

    Fields, in order: value, category, source kind, slot key, then the
    metadata entries ``event_text``, ``source_span`` and ``raw_text``. Each is
    cleaned; empty results are skipped and the rest joined with spaces.
    """
    metadata = dict(hit.metadata or {})
    raw_fields = (
        hit.value,
        hit.category,
        hit.source_kind,
        hit.slot_key,
        metadata.get("event_text", ""),
        metadata.get("source_span", ""),
        metadata.get("raw_text", ""),
    )
    cleaned_fields = (_clean_text(field_value) for field_value in raw_fields)
    return " ".join(piece for piece in cleaned_fields if piece)
def _audit_anchor_content_tokens(text: str) -> set[str]:
    """Extract the normalized content tokens of ``text``.

    A token is kept when its normalized form is non-empty, is not an
    audit-anchor stopword, and is either at least three characters long or
    contains a character in the U+4E00..U+9FFF range.
    """

    def has_cjk(term: str) -> bool:
        return any("\u4e00" <= ch <= "\u9fff" for ch in term)

    kept = set()
    for raw_token in _tokenize(text):
        term = _normalize(raw_token)
        if term and term not in _AUDIT_ANCHOR_STOPWORDS and (len(term) >= 3 or has_cjk(term)):
            kept.add(term)
    return kept
def _audit_anchor_phrase_bonus(query_tokens: set[str], hit_tokens: set[str]) -> tuple[float, List[str]]:
    """Score distinctive phrases shared by the query and a hit.

    A phrase from ``_AUDIT_ANCHOR_PHRASE_TOKEN_SETS`` counts when all of its
    tokens occur in both token sets. The ``{"spine", "book"}`` phrase is
    worth 5.0, every other phrase 3.0.

    Returns:
        ``(total_bonus, labels)`` where each label is the phrase's tokens,
        sorted and joined with ``"+"``.
    """
    shared_phrases = [
        phrase
        for phrase in _AUDIT_ANCHOR_PHRASE_TOKEN_SETS
        if phrase <= query_tokens and phrase <= hit_tokens
    ]
    total_bonus = 0.0
    for phrase in shared_phrases:
        total_bonus += 5.0 if phrase == {"spine", "book"} else 3.0
    labels = ["+".join(sorted(phrase)) for phrase in shared_phrases]
    return total_bonus, labels
def _copy_hit_with_audit_anchor_boost(hit: MemoryHit, *, score: float, reason: str, matched_tokens: Sequence[str]) -> MemoryHit:
    """Return a copy of ``hit`` marked and re-scored as an audit anchor.

    The copy's metadata records the protection flag, the reason, the original
    and anchor scores (rounded to six decimals) and up to 20 matched tokens.
    Its score is the larger of the original score and ``1.0 + score``
    (rounded to six decimals). The input hit is not modified.
    """
    anchor_metadata = {
        **dict(hit.metadata or {}),
        "audit_anchor_protected": True,
        "audit_anchor_reason": reason,
        "audit_anchor_original_score": round(float(hit.score), 6),
        "audit_anchor_score": round(float(score), 6),
        "audit_anchor_matched_tokens": list(matched_tokens)[:20],
    }
    boosted_score = max(float(hit.score), round(1.0 + float(score), 6))
    return MemoryHit(
        memory_id=hit.memory_id,
        category=hit.category,
        value=hit.value,
        relation=hit.relation,
        anchors=list(hit.anchors),
        score=boosted_score,
        source_kind=hit.source_kind,
        slot_key=hit.slot_key,
        state=hit.state,
        turn_index=int(hit.turn_index),
        metadata=anchor_metadata,
    )
def _audit_anchor_protected_hits(
    *,
    query: str,
    final_hits: Sequence[MemoryHit],
    candidate_hits: Sequence[MemoryHit],
    metadata: Mapping[str, Any],
    top_k: int,
) -> Dict[str, Any]:
    """Promote old-topic / explicit-turn hits to the front of the final hits.

    The function only acts on queries that contain a look-back cue
    (``_audit_anchor_query``) and at least one content token. Every distinct
    hit from ``final_hits`` and ``candidate_hits`` is scored from:

    * +8.0 when its turn index lies within one turn of an explicitly named
      turn;
    * lexical overlap with the query (1.8 per distinctive token, 0.35 per
      generic token, capped at 8.0) plus the overlap ratio;
    * the distinctive phrase bonus;
    * its rank in ``symbolic_recall_event_ids`` / ``learned_recall_event_ids``
      from ``metadata``;
    * -0.25 when its event is already in ``selected_event_ids``.

    Without an explicit turn number, a hit needs at least one distinctive
    token or phrase match. The best hits (at most 3 with explicit turns,
    otherwise 2, one per event id, skipping hits already within the first
    ``top_k`` final hits) are copied with a boosted score and placed ahead of
    ``final_hits``; the merged list is cut to ``top_k`` (at least 1).

    Returns:
        A dict with ``enabled``, ``hits``, ``promoted_hits`` and a
        ``metadata`` dict of ``audit_anchor_*`` diagnostics.
    """
    if not _audit_anchor_query(query):
        return {"enabled": False, "hits": list(final_hits), "promoted_hits": [], "metadata": {"audit_anchor_enabled": False}}
    query_tokens = _audit_anchor_content_tokens(query)
    if not query_tokens:
        return {"enabled": True, "hits": list(final_hits), "promoted_hits": [], "metadata": {"audit_anchor_enabled": True, "audit_anchor_reason": "no_content_tokens"}}

    explicit_turns = _audit_anchor_turn_numbers(query)
    explicit_turn_window: set[int] = set()
    for number in explicit_turns:
        # Keep only positive neighbours. Hits without a turn index report 0
        # (this module treats only positive turn indices as real turns), so
        # leaving 0 in the window would make "turn 1" match every such hit.
        explicit_turn_window.update(turn for turn in (number - 1, number, number + 1) if turn > 0)

    recall_metadata = dict(metadata or {})
    symbolic_ids = list(dict.fromkeys(str(item) for item in recall_metadata.get("symbolic_recall_event_ids", []) or []))
    learned_ids = list(dict.fromkeys(str(item) for item in recall_metadata.get("learned_recall_event_ids", []) or []))
    selected_ids = set(str(item) for item in recall_metadata.get("selected_event_ids", []) or [])
    symbolic_rank = {event_id: index for index, event_id in enumerate(symbolic_ids, start=1)}
    learned_rank = {event_id: index for index, event_id in enumerate(learned_ids, start=1)}

    def dedupe_key(hit: MemoryHit) -> str:
        # Memory id when present, otherwise an event/slot/value-prefix key.
        return hit.memory_id or f"{_hit_event_id(hit)}::{hit.slot_key}::{hit.value[:80]}"

    pool: Dict[str, MemoryHit] = {}
    for hit in list(final_hits) + list(candidate_hits):
        key = dedupe_key(hit)
        if key and key not in pool:
            pool[key] = hit

    scored: List[tuple[float, str, List[str], MemoryHit]] = []
    for hit in pool.values():
        event_id = _hit_event_id(hit)
        hit_tokens = _audit_anchor_content_tokens(_audit_anchor_hit_text(hit))
        matched = sorted(query_tokens & hit_tokens)
        distinctive_matched = [token for token in matched if token not in _AUDIT_ANCHOR_GENERIC_MATCH_TOKENS]
        phrase_bonus, matched_phrases = _audit_anchor_phrase_bonus(query_tokens, hit_tokens)
        in_turn_window = int(hit.turn_index or 0) in explicit_turn_window
        if not matched and not in_turn_window:
            continue
        score = 0.0
        if in_turn_window:
            score += 8.0
        if matched:
            weighted_overlap = 0.0
            for token in matched:
                weighted_overlap += 1.8 if token not in _AUDIT_ANCHOR_GENERIC_MATCH_TOKENS else 0.35
            score += min(8.0, weighted_overlap)
            score += len(matched) / max(1.0, float(min(len(query_tokens), len(hit_tokens))))
        if phrase_bonus > 0:
            score += phrase_bonus
        if event_id in symbolic_rank:
            score += max(0.0, 2.0 - (symbolic_rank[event_id] - 1) * 0.08)
        if event_id in learned_rank:
            score += max(0.0, 1.0 - (learned_rank[event_id] - 1) * 0.02)
        if event_id in selected_ids:
            score -= 0.25
        # For non-numeric old-topic probes, require a real content overlap so generic
        # "earlier" phrasing does not promote unrelated old memories.
        if not explicit_turns and len(distinctive_matched) < 1 and phrase_bonus <= 0:
            continue
        if score > 0:
            scored.append((score, event_id, sorted(set(matched + matched_phrases)), hit))
    # Highest score first; among equal scores the larger turn index wins.
    scored.sort(key=lambda item: (item[0], -abs(int(item[3].turn_index or 0))), reverse=True)

    max_promoted = 2 if not explicit_turns else 3
    # Memory ids already inside the final top-k need no promotion.
    already_final_ids = {current.memory_id for current in final_hits[: max(1, int(top_k))] if current.memory_id}
    promoted: List[MemoryHit] = []
    promoted_event_ids: set[str] = set()
    for score, event_id, matched, hit in scored:
        if event_id in promoted_event_ids:
            continue
        if hit.memory_id and hit.memory_id in already_final_ids:
            continue
        reason = "explicit_turn_anchor" if int(hit.turn_index or 0) in explicit_turn_window else "old_topic_anchor"
        promoted.append(_copy_hit_with_audit_anchor_boost(hit, score=score, reason=reason, matched_tokens=matched))
        promoted_event_ids.add(event_id)
        if len(promoted) >= max_promoted:
            break

    if not promoted:
        return {
            "enabled": True,
            "hits": list(final_hits),
            "promoted_hits": [],
            "metadata": {
                "audit_anchor_enabled": True,
                "audit_anchor_turn_numbers": explicit_turns,
                "audit_anchor_query_tokens": sorted(query_tokens)[:50],
                "audit_anchor_promoted_event_ids": [],
            },
        }

    merged: List[MemoryHit] = []
    seen_keys: set[str] = set()
    for hit in promoted + list(final_hits):
        key = dedupe_key(hit)
        if key in seen_keys:
            continue
        seen_keys.add(key)
        merged.append(hit)
        if len(merged) >= max(1, int(top_k)):
            break
    return {
        "enabled": True,
        "hits": merged,
        "promoted_hits": promoted,
        "metadata": {
            "audit_anchor_enabled": True,
            "audit_anchor_turn_numbers": explicit_turns,
            "audit_anchor_query_tokens": sorted(query_tokens)[:50],
            "audit_anchor_promoted_event_ids": [_hit_event_id(hit) for hit in promoted],
            "audit_anchor_promoted_turns": [int(hit.turn_index) for hit in promoted],
            "audit_anchor_promoted_hit_count": len(promoted),
        },
    }
def _learnable_graph_hits(graph: SessionMemoryGraphV2) -> List[MemoryHit]:
    """Convert the graph records usable for learning into memory hits.

    A record qualifies when its normalized state is ``active`` or
    ``parallel_active``, or when it is ``evidence`` grounded in a source:
    a ``public_dialog*`` source kind, a ``source_turn`` /
    ``llm_semantic_write`` content variant, or the
    ``llm_semantic_writer_gate`` write path. Records whose slot key starts
    with ``"noise."`` are always skipped.

    Hits are scored with the largest of confidence, salience and 0.01, and
    returned public-dialog hits first, then by descending turn index and
    score.
    """
    always_learnable_states = {"active", "parallel_active"}
    grounded_variants = {"source_turn", "llm_semantic_write"}

    collected: List[MemoryHit] = []
    for record in graph.records_by_id.values():
        normalized_state = _normalize(record.state)
        record_metadata = dict(record.metadata or {})
        if normalized_state not in always_learnable_states:
            source_grounded = normalized_state == "evidence" and (
                _clean_text(record.source_kind).startswith("public_dialog")
                or _clean_text(record_metadata.get("content_variant", "")) in grounded_variants
                or _clean_text(record_metadata.get("write_path", "")) == "llm_semantic_writer_gate"
            )
            if not source_grounded:
                continue
        if _clean_text(record.slot_key).startswith("noise."):
            continue
        collected.append(
            MemoryHit(
                memory_id=record.memory_id,
                category=record.category,
                value=record.value,
                relation=record.relation,
                anchors=list(record.anchor_concepts),
                score=max(float(record.confidence), float(record.salience), 0.01),
                source_kind=record.source_kind,
                slot_key=record.slot_key,
                state=record.state,
                turn_index=int(record.turn_index),
                metadata=record_metadata,
            )
        )

    def priority(item: MemoryHit) -> tuple[Any, ...]:
        return (_is_public_dialog_hit(item), int(item.turn_index), float(item.score))

    collected.sort(key=priority, reverse=True)
    return collected
def _parse_structured_records(
    payload: Dict[str, Any] | None,
    *,
    turn_index: int,
    profile: TMCRAProfile | None = None,
) -> List[SessionMemoryRecordV2]:
    """Build memory records from the structured rows carried by *payload*.

    Rows are read from ``replacement_memory_records`` (buffer state
    ``formal``) followed by ``suspect_memory_records`` (buffer state
    ``suspect``). Rows that are not dicts or that have an empty cleaned
    ``value`` are skipped.

    Args:
        payload: Answer payload; ``None`` is treated as empty.
        turn_index: Fallback turn index, also embedded in each memory id.
        profile: Profile providing ``stable_slot_key``; a default
            ``TMCRAProfile`` is created when omitted.

    Returns:
        One ``SessionMemoryRecordV2`` per accepted row, in input order.
    """
    active_profile = profile or TMCRAProfile()
    source = payload or {}
    tagged_rows: List[tuple[str, Mapping[str, Any]]] = [
        (buffer_tag, row)
        for buffer_tag, payload_key in (
            ("formal", "replacement_memory_records"),
            ("suspect", "suspect_memory_records"),
        )
        for row in (source.get(payload_key, []) or [])
        if isinstance(row, Mapping)
    ]

    parsed: List[SessionMemoryRecordV2] = []
    # The position counts every tagged row (including ones skipped below) and
    # becomes the trailing component of the memory id.
    for position, (buffer_tag, row) in enumerate(tagged_rows):
        if not isinstance(row, dict):
            continue
        category = _clean_text(row.get("category", "memory")) or "memory"
        value = _clean_text(row.get("value", ""))
        if not value:
            continue
        anchors = _dedupe(row.get("anchors", []) or [], max_items=8)
        row_metadata = dict(row.get("metadata", {}) or {})
        relation_text = _clean_text(row.get("relation", ""))
        slot_key = active_profile.stable_slot_key(
            category=category,
            value=value,
            anchors=anchors,
            slot_key=_clean_text(row.get("slot_key", "")) or _clean_text(row.get("slot", "")),
            relation=relation_text,
            metadata=row_metadata,
        )

        def _meta_text(key: str) -> str:
            return _clean_text(row_metadata.get(key, ""))

        def _meta_list(key: str) -> Any:
            return _dedupe(row_metadata.get(key, []) or [])

        enriched_metadata = {
            **row_metadata,
            "memory_role": _meta_text("memory_role") or "user",
            "authority": _meta_text("authority") or "source",
            "canonical_slot_key": _meta_text("canonical_slot_key") or slot_key,
            "writeback_class": _meta_text("writeback_class"),
            "origin_query": _meta_text("origin_query"),
            "origin_answer_id": _meta_text("origin_answer_id"),
            "support_memory_ids": _meta_list("support_memory_ids"),
            "support_fact_refs": _meta_list("support_fact_refs"),
            "support_path_refs": _meta_list("support_path_refs"),
            "promotion_state": _meta_text("promotion_state") or "none",
            "memory_buffer_state": buffer_tag,
        }

        if buffer_tag == "suspect":
            fallback_source_kind = "suspect_memory"
            fallback_state = "suspect"
        else:
            fallback_source_kind = "replacement_memory"
            fallback_state = "active" if bool(row.get("active", True)) else "historical"

        # Goals and constraints get a higher salience when the row has none.
        salience_default = 0.88 if category in {"goal", "constraint"} else 0.74
        parsed.append(
            SessionMemoryRecordV2(
                memory_id=f"{slot_key}:{turn_index}:{position}",
                category=category,
                slot_key=slot_key,
                value=value,
                relation=relation_text or f"{category}_memory",
                anchor_concepts=anchors,
                evidence_anchors=anchors,
                salience=float(row.get("salience", salience_default) or 0.74),
                confidence=float(row.get("confidence", 0.82) or 0.82),
                source_kind=_clean_text(row.get("source_kind", "")) or fallback_source_kind,
                turn_index=int(row.get("turn_index", turn_index) or turn_index),
                state=_clean_text(row.get("state", "")) or fallback_state,
                metadata=enriched_metadata,
            )
        )
    return parsed
# Marker phrases naming memory-write directives, grouped by the kind of memory
# they mention (goal, constraint, preference, terminology, stage, fact).
# NOTE(review): the code that matches against these is outside this section.
_WRITE_MARKERS = (
    "goal update", "goal seed",
    "constraint update", "constraint overwrite", "constraint seed",
    "preference update", "preference overwrite", "preference seed",
    "terminology:", "term seed", "term overwrite",
    "stage update", "stage overwrite",
    "path fact", "fact:",
    "memory update",
)
# Overwrite/replace wording: English markers first, then Chinese ones.
# NOTE(review): consumers of this tuple are outside this section.
_OVERWRITE_MARKERS = (
    "overwrite", "replace", "supersede",
    "覆盖", "替换", "改成", "更新为",
)
# Each entry is (cues, anchors). `_topic_bucket_keywords` adds the anchors (and
# the cues of length >= 2) to the keyword list when any cue occurs in the
# normalized text.
_TOPIC_BUCKET_CUE_GROUPS: tuple[tuple[tuple[str, ...], tuple[str, ...]], ...] = (
    # Food, drink and allergy wording.
    (
        (
            "食物", "饮品", "甜点", "早餐", "餐饮", "口味", "酸辣", "热食", "过敏",
            "芒果", "配料", "推荐", "吃",
            "food", "breakfast", "allergy", "mango",
        ),
        ("餐饮偏好安全", "食物偏好", "过敏约束", "口味偏好"),
    ),
    # UI, page and visual design wording.
    (
        (
            "界面", "首页", "首屏", "配色", "营销", "横幅", "电商", "网站", "移动端",
            "视觉", "工具型",
            "ui", "homepage", "ecommerce", "mobile", "design",
        ),
        ("界面产品设计", "电商页面", "视觉布局", "移动端体验"),
    ),
    # API, runtime and evaluation wording.
    (
        (
            "api", "writer", "延迟", "算法", "耗时", "评估", "指标", "调用", "服务",
            "key", "runtime", "latency", "metric",
        ),
        ("api评估运行", "算法服务", "调用指标", "writer延迟"),
    ),
)
# Tokens that `_topic_bucket_keywords` discards (compared after normalization).
_TOPIC_BUCKET_STOPWORDS = {
    "不要", "不用", "复述", "原话", "只说", "关键", "最关键", "以后",
    "需要", "必须", "可以", "时候", "这个", "那个", "现在",
    "做", "说", "的", "了", "和", "与",
    "to", "the", "and", "for",
}
def _topic_bucket_keywords(text: str, *, max_items: int = 18) -> List[str]:
    """Extract topic keywords from *text*.

    Candidates come from, in order: the anchors and (length >= 2) cues of every
    cue group with at least one cue present in the normalized text, runs of
    2-8 CJK characters found in the raw text, and ``_tokenize(text)``.
    Empty values, stopwords and values shorter than two characters are
    removed before de-duplicating down to *max_items*.
    """
    haystack = _normalize(text)
    candidates: List[str] = []
    for cues, anchors in _TOPIC_BUCKET_CUE_GROUPS:
        group_matched = False
        for cue in cues:
            if cue and _normalize(cue) in haystack:
                group_matched = True
                break
        if group_matched:
            candidates.extend(anchors)
            candidates.extend([cue for cue in cues if len(cue) >= 2])
    candidates.extend(re.findall(r"[\u4e00-\u9fff]{2,8}", str(text or "")))
    candidates.extend(_tokenize(text))

    kept = []
    for candidate in candidates:
        cleaned = _clean_text(candidate)
        if not cleaned:
            continue
        is_stopword = _normalize(cleaned) in _TOPIC_BUCKET_STOPWORDS
        if is_stopword or len(cleaned) < 2:
            continue
        kept.append(cleaned)
    return _dedupe(kept, max_items=max_items)
def _topic_bucket_id_from_keywords(keywords: Sequence[str]) -> str:
    """Derive a deterministic bucket id from the first eight keywords.

    The normalized, non-empty keywords are joined with ``|`` (falling back to
    ``general``) and hashed with ``uuid5``; the id is ``topic-`` plus the first
    twelve characters of that UUID string.
    """
    normalized_parts = [_normalize(item) for item in keywords[:8] if _clean_text(item)]
    basis = "|".join(normalized_parts) or "general"
    digest = str(uuid.uuid5(uuid.NAMESPACE_URL, f"tmcra-topic-bucket:{basis}"))
    return "topic-" + digest[:12]
def _topic_bucket_label_from_keywords(keywords: Sequence[str]) -> str:
    """Build a display label from up to three keywords.

    Empty keywords and ones whose normalized form starts with ``topic:`` are
    ignored; when nothing is left the fallback label ``动态话题`` is returned.
    """
    visible = []
    for item in keywords:
        cleaned = _clean_text(item)
        if cleaned and not _normalize(item).startswith("topic:"):
            visible.append(cleaned)
    if not visible:
        return "动态话题"
    return " / ".join(visible[:3])
def _topic_bucket_overlap_score(left_keywords: Sequence[str], right_keywords: Sequence[str]) -> float:
    """Return the share of the smaller keyword set that both sets have in common.

    Keywords are compared in normalized form; the result is ``0.0`` when
    either set is empty or nothing overlaps.
    """

    def _key_set(items: Sequence[str]) -> set:
        return {_normalize(item) for item in items if _clean_text(item)}

    smaller, larger = sorted((_key_set(left_keywords), _key_set(right_keywords)), key=len)
    if not smaller:
        # The smaller set being empty covers "either side is empty".
        return 0.0
    shared = len(smaller & larger)
    if shared == 0:
        return 0.0
    return shared / max(1, len(smaller))
def _explicit_topic_bucket_from_payload(answer_payload: Dict[str, Any] | None) -> Dict[str, Any]:
    """Read a caller-supplied topic bucket from ``payload["metadata"]["topic_bucket"]``.

    Returns ``{}`` when that entry is not a mapping. Otherwise keywords, label
    and id are read with their short-name fallbacks (``keywords``, ``label``,
    ``id``); missing keywords are derived from the label and a missing id is
    derived from the keywords (or the label).
    """
    payload_metadata = dict((answer_payload or {}).get("metadata", {}) or {})
    declared = payload_metadata.get("topic_bucket")
    if not isinstance(declared, Mapping):
        return {}

    def _pick(primary: str, fallback: str, default: Any) -> Any:
        # Prefer the primary key, then the fallback key, then the default.
        return declared.get(primary, declared.get(fallback, default))

    keywords = _dedupe(_pick("topic_keywords", "keywords", []) or [], max_items=24)
    label = _clean_text(_pick("topic_label", "label", ""))
    if label and not keywords:
        keywords = _topic_bucket_keywords(label, max_items=12)
    bucket_id = _clean_text(_pick("topic_bucket_id", "id", ""))
    if not bucket_id:
        bucket_id = _topic_bucket_id_from_keywords(keywords or [label])
    confidence = float(_pick("confidence", "topic_confidence", 0.92) or 0.92)
    return {
        "topic_bucket_id": bucket_id,
        "topic_label": label or _topic_bucket_label_from_keywords(keywords),
        "topic_keywords": keywords,
        "topic_confidence": confidence,
        "topic_assignment_source": "explicit_payload",
    }
def _coerce_topic_bucket(metadata: Mapping[str, Any]) -> Dict[str, Any]:
    """Normalize topic-bucket fields stored flat and/or nested in *metadata*.

    The flat ``topic_bucket_id`` / ``topic_label`` win over the nested
    ``topic_bucket`` mapping; keywords from both places are merged. Returns
    ``{}`` when no bucket id can be found.
    """
    nested = metadata.get("topic_bucket")
    has_nested = isinstance(nested, Mapping)

    bucket_id = _clean_text(metadata.get("topic_bucket_id", ""))
    if not bucket_id and has_nested:
        bucket_id = _clean_text(nested.get("topic_bucket_id", nested.get("id", "")))
    if not bucket_id:
        return {}

    keywords = _dedupe(metadata.get("topic_keywords", []) or [], max_items=32)
    if has_nested:
        nested_keywords = nested.get("topic_keywords", nested.get("keywords", [])) or []
        keywords = _dedupe([*keywords, *nested_keywords], max_items=32)

    label = _clean_text(metadata.get("topic_label", ""))
    if has_nested and not label:
        label = _clean_text(nested.get("topic_label", nested.get("label", "")))

    return {
        "topic_bucket_id": bucket_id,
        "topic_label": label or _topic_bucket_label_from_keywords(keywords),
        "topic_keywords": keywords,
    }
def _collect_topic_buckets(graph: SessionMemoryGraphV2) -> Dict[str, Dict[str, Any]]:
    """Aggregate every topic bucket referenced by the graph.

    Buckets are gathered from the metadata of each record in
    ``graph.records_by_id`` and of each entry in ``graph.turn_log`` (entries
    may be mappings or attribute-style objects).

    Returns:
        Mapping of bucket id to a dict with ``topic_bucket_id``,
        ``topic_label``, merged ``topic_keywords`` (max 32), ``record_ids``
        (max 200) and ``last_turn_index``.
    """
    buckets: Dict[str, Dict[str, Any]] = {}

    def merge(bucket: Mapping[str, Any], *, record_id: str = "", turn_index: int = 0) -> None:
        bucket_id = _clean_text(bucket.get("topic_bucket_id", bucket.get("id", "")))
        if not bucket_id:
            return
        incoming_label = _clean_text(bucket.get("topic_label", bucket.get("label", "")))
        current = buckets.setdefault(
            bucket_id,
            {
                "topic_bucket_id": bucket_id,
                "topic_label": incoming_label,
                "topic_keywords": [],
                "record_ids": [],
                "last_turn_index": 0,
            },
        )
        current["topic_label"] = current.get("topic_label") or incoming_label
        current["topic_keywords"] = _dedupe(
            [*list(current.get("topic_keywords", []) or []), *(bucket.get("topic_keywords", bucket.get("keywords", [])) or [])],
            max_items=32,
        )
        if record_id:
            current["record_ids"] = _dedupe([*list(current.get("record_ids", []) or []), record_id], max_items=200)
        # NOTE(review): assumed to run for every merge (turn-log entries pass a
        # turn index but no record id) — confirm against the original nesting.
        current["last_turn_index"] = max(int(current.get("last_turn_index", 0) or 0), int(turn_index or 0))

    for record in getattr(graph, "records_by_id", {}).values():
        bucket = _coerce_topic_bucket(dict(record.metadata or {}))
        if bucket:
            merge(bucket, record_id=record.memory_id, turn_index=int(record.turn_index))

    for turn in list(getattr(graph, "turn_log", []) or []):
        # Read both fields first and apply the empty fallback afterwards, so a
        # mapping entry holding None no longer raises TypeError (the original
        # `a if cond else b or {}` applied the fallback to the else branch only).
        if isinstance(turn, Mapping):
            raw_metadata = turn.get("metadata")
            raw_turn_index = turn.get("turn_index")
        else:
            raw_metadata = getattr(turn, "metadata", None)
            raw_turn_index = getattr(turn, "turn_index", None)
        bucket = _coerce_topic_bucket(dict(raw_metadata or {}))
        if bucket:
            merge(bucket, turn_index=int(raw_turn_index or 0))

    for bucket in buckets.values():
        if not bucket.get("topic_label"):
            bucket["topic_label"] = _topic_bucket_label_from_keywords(bucket.get("topic_keywords", []) or [])
    return buckets
def _assign_topic_bucket_for_text(
    graph: SessionMemoryGraphV2,
    text: str,
    *,
    answer_payload: Dict[str, Any] | None = None,
    turn_index: int = 0,
    create: bool = True,
) -> Dict[str, Any]:
    """Choose the topic bucket for *text*.

    Resolution order:
      1. an explicit bucket carried by *answer_payload*;
      2. the existing bucket with the highest keyword overlap, if that
         overlap is at least 0.22;
      3. a new bucket derived from the text's own keywords (labelled
         ``created_by_dialog`` or, when *create* is false, ``query_probe``).
    """
    turn_number = int(turn_index or 0)

    explicit = _explicit_topic_bucket_from_payload(answer_payload)
    if explicit:
        explicit["turn_index"] = turn_number
        return explicit

    keywords = _topic_bucket_keywords(text, max_items=24)
    if not keywords:
        keywords = _dedupe([_clean_text(text)[:40]], max_items=1)

    scored = [
        (_topic_bucket_overlap_score(keywords, bucket.get("topic_keywords", []) or []), bucket)
        for bucket in _collect_topic_buckets(graph).values()
    ]
    best_score = 0.0
    best_bucket: Dict[str, Any] | None = None
    if scored:
        # max() returns the first maximal entry, i.e. the earliest best bucket.
        top_score, top_bucket = max(scored, key=lambda pair: pair[0])
        if top_score > 0.0:
            best_score, best_bucket = top_score, top_bucket

    if best_bucket and best_score >= 0.22:
        merged_keywords = _dedupe([*list(best_bucket.get("topic_keywords", []) or []), *keywords], max_items=32)
        return {
            "topic_bucket_id": best_bucket["topic_bucket_id"],
            "topic_label": best_bucket.get("topic_label") or _topic_bucket_label_from_keywords(merged_keywords),
            "topic_keywords": merged_keywords,
            "topic_confidence": round(min(0.97, 0.68 + best_score * 0.24), 6),
            "topic_assignment_source": "reused_by_overlap",
            "topic_match_score": round(best_score, 6),
            "turn_index": turn_number,
        }

    if create:
        confidence, assignment_source = 0.72, "created_by_dialog"
    else:
        confidence, assignment_source = 0.58, "query_probe"
    return {
        "topic_bucket_id": _topic_bucket_id_from_keywords(keywords),
        "topic_label": _topic_bucket_label_from_keywords(keywords),
        "topic_keywords": keywords,
        "topic_confidence": confidence,
        "topic_assignment_source": assignment_source,
        "topic_match_score": round(best_score, 6),
        "turn_index": turn_number,
    }
def _apply_topic_bucket_to_records(records: List[SessionMemoryRecordV2], topic_bucket: Mapping[str, Any]) -> None:
    """Stamp *topic_bucket* onto every record, in place.

    Each record gets the bucket fields written into a copy of its metadata,
    and ``topic:<label>`` plus the first eight keywords appended to both its
    ``anchor_concepts`` and ``metadata["evidence_anchors"]`` (each capped at
    24). Nothing happens when the records, the bucket or the bucket id is
    empty.
    """
    if not records or not topic_bucket:
        return
    bucket_id = _clean_text(topic_bucket.get("topic_bucket_id", ""))
    if not bucket_id:
        return

    label = _clean_text(topic_bucket.get("topic_label", "")) or "动态话题"
    keywords = _dedupe(topic_bucket.get("topic_keywords", []) or [], max_items=32)
    confidence = float(topic_bucket.get("topic_confidence", 0.72) or 0.72)
    assignment_source = _clean_text(topic_bucket.get("topic_assignment_source", "")) or "created_by_dialog"
    topic_anchors = [f"topic:{label}", *keywords[:8]]

    for record in records:
        updated = dict(record.metadata or {})
        updated["topic_bucket_id"] = bucket_id
        updated["topic_label"] = label
        updated["topic_keywords"] = keywords
        updated["topic_confidence"] = confidence
        updated["topic_assignment_source"] = assignment_source
        record.anchor_concepts = _dedupe([*list(record.anchor_concepts or []), *topic_anchors], max_items=24)
        # The evidence anchors live in the metadata dict here, not on the
        # record attribute.
        prior_evidence = list(updated.get("evidence_anchors", []) or [])
        updated["evidence_anchors"] = _dedupe([*prior_evidence, *topic_anchors], max_items=24)
        record.metadata = updated
def _last_topic_turn(graph: SessionMemoryGraphV2) -> Dict[str, Any]:
    """Return the topic bucket of the most recent turn-log entry that has one.

    Entries may be mappings or attribute-style objects. The returned bucket
    also carries that entry's ``turn_index``; ``{}`` is returned when no entry
    has a bucket.
    """
    for turn in reversed(list(getattr(graph, "turn_log", []) or [])):
        # Read both fields first and apply the empty fallback afterwards, so a
        # mapping entry holding None no longer raises TypeError (the original
        # `a if cond else b or {}` applied the fallback to the else branch only).
        if isinstance(turn, Mapping):
            raw_metadata = turn.get("metadata")
            raw_turn_index = turn.get("turn_index")
        else:
            raw_metadata = getattr(turn, "metadata", None)
            raw_turn_index = getattr(turn, "turn_index", None)
        bucket = _coerce_topic_bucket(dict(raw_metadata or {}))
        if bucket:
            bucket["turn_index"] = int(raw_turn_index or 0)
            return bucket
    return {}
def _add_topic_bridge_edges(
    graph: SessionMemoryGraphV2,
    *,
    previous_topic: Mapping[str, Any],
    current_topic: Mapping[str, Any],
    current_record_ids: Sequence[str],
    turn_index: int,
    evidence: str,
) -> Dict[str, Any]:
    """Link the previous topic's records to the current turn's records.

    Up to the last four record ids of the previous bucket are connected to up
    to the first four *current_record_ids* with ``topic_bridge`` edges
    (score 0.54). No edges are added when either bucket id is missing, both
    ids are equal, or either side has no records.

    Returns:
        A dict with ``topic_bridge_edge_count`` and, when edges were
        attempted, ``topic_bridge_from`` / ``topic_bridge_to``.
    """
    nothing_added = {"topic_bridge_edge_count": 0}
    previous_id = _clean_text(previous_topic.get("topic_bucket_id", ""))
    current_id = _clean_text(current_topic.get("topic_bucket_id", ""))
    if not (previous_id and current_id) or previous_id == current_id or not current_record_ids:
        return nothing_added

    previous_bucket = _collect_topic_buckets(graph).get(previous_id, {}) or {}
    source_ids = list(previous_bucket.get("record_ids", []) or [])[-4:]
    if not source_ids:
        return nothing_added

    target_ids = list(current_record_ids)[:4]
    evidence_text = _clean_text(evidence)[:240]
    from_label = _clean_text(previous_topic.get("topic_label", ""))
    to_label = _clean_text(current_topic.get("topic_label", ""))
    evidence_turn = int(turn_index or 0)

    pairs = [(source_id, target_id) for source_id in source_ids for target_id in target_ids if source_id != target_id]
    for source_id, target_id in pairs:
        graph._upsert_memory_edge(
            SessionMemoryEdgeV2(
                edge_id=f"{source_id}->{target_id}:topic_bridge:{previous_id}->{current_id}",
                source_memory_id=source_id,
                target_memory_id=target_id,
                edge_type="topic_bridge",
                score=0.54,
                model_score=0.0,
                evidence_turn=evidence_turn,
                evidence=evidence_text,
                metadata={
                    "from_topic_bucket_id": previous_id,
                    "to_topic_bucket_id": current_id,
                    "from_topic_label": from_label,
                    "to_topic_label": to_label,
                    "bridge_reason": "adjacent_dialog_topic_transition",
                },
            )
        )
    return {
        "topic_bridge_edge_count": len(pairs),
        "topic_bridge_from": previous_id,
        "topic_bridge_to": current_id,
    }
def _add_dialogue_tunnel_edges(
    graph: SessionMemoryGraphV2,
    *,
    current_topic: Mapping[str, Any],
    current_record_ids: Sequence[str],
    turn_index: int,
    evidence: str,
) -> Dict[str, Any]:
    """Add low-score ``dialogue_tunnel`` edges from other buckets to this turn.

    Other buckets are visited most-recent first (by ``last_turn_index``); the
    last two record ids of each are collected until six sources are found.
    Every source is linked to up to the first two *current_record_ids* with
    an edge of score 0.24.

    Returns:
        ``{"dialogue_tunnel_edge_count": <number of edges upserted>}``.
    """
    current_id = _clean_text(current_topic.get("topic_bucket_id", ""))
    if not current_id or not current_record_ids:
        return {"dialogue_tunnel_edge_count": 0}

    ordered_buckets = sorted(
        _collect_topic_buckets(graph).items(),
        key=lambda item: int(item[1].get("last_turn_index", 0) or 0),
        reverse=True,
    )
    evidence_text = _clean_text(evidence)[:240]

    sources: List[tuple[str, str, str]] = []
    for bucket_id, bucket in ordered_buckets:
        if len(sources) >= 6:
            break
        if bucket_id == current_id:
            continue
        bucket_label = _clean_text(bucket.get("topic_label", ""))
        for record_id in list(bucket.get("record_ids", []) or [])[-2:]:
            if record_id not in current_record_ids:
                sources.append((bucket_id, bucket_label, record_id))
    sources = sources[:6]

    target_ids = list(current_record_ids)[:2]
    to_label = _clean_text(current_topic.get("topic_label", ""))
    evidence_turn = int(turn_index or 0)
    created = 0
    for source_bucket_id, source_label, source_id in sources:
        for target_id in target_ids:
            if target_id == source_id:
                continue
            graph._upsert_memory_edge(
                SessionMemoryEdgeV2(
                    edge_id=f"{source_id}->{target_id}:dialogue_tunnel:{source_bucket_id}->{current_id}",
                    source_memory_id=source_id,
                    target_memory_id=target_id,
                    edge_type="dialogue_tunnel",
                    score=0.24,
                    model_score=0.0,
                    evidence_turn=evidence_turn,
                    evidence=evidence_text,
                    metadata={
                        "from_topic_bucket_id": source_bucket_id,
                        "to_topic_bucket_id": current_id,
                        "from_topic_label": source_label,
                        "to_topic_label": to_label,
                        "bridge_reason": "high_resistance_dialogue_level_tunnel",
                    },
                )
            )
            created += 1
    return {"dialogue_tunnel_edge_count": created}
def _topic_adjacent_bucket_ids(graph: SessionMemoryGraphV2, bucket_id: str) -> set[str]:
    """Return the bucket ids joined to *bucket_id* by ``topic_bridge`` edges.

    Edges are followed in both directions using the ``from_topic_bucket_id``
    / ``to_topic_bucket_id`` values in the edge metadata. An empty
    *bucket_id* yields an empty set.
    """
    neighbours: set[str] = set()
    if not bucket_id:
        return neighbours
    for edge in getattr(graph, "memory_edges", {}).values():
        if _normalize(edge.edge_type) != "topic_bridge":
            continue
        edge_metadata = dict(edge.metadata or {})
        source_bucket = _clean_text(edge_metadata.get("from_topic_bucket_id", ""))
        target_bucket = _clean_text(edge_metadata.get("to_topic_bucket_id", ""))
        for near, far in ((source_bucket, target_bucket), (target_bucket, source_bucket)):
            if near == bucket_id and far:
                neighbours.add(far)
    return neighbours
def _dialogue_tunnel_bucket_ids(graph: SessionMemoryGraphV2, bucket_id: str) -> set[str]:
    """Return the bucket ids joined to *bucket_id* by ``dialogue_tunnel`` edges.

    Edges are followed in both directions using the ``from_topic_bucket_id``
    / ``to_topic_bucket_id`` values in the edge metadata. An empty
    *bucket_id* yields an empty set.
    """
    neighbours: set[str] = set()
    if not bucket_id:
        return neighbours
    for edge in getattr(graph, "memory_edges", {}).values():
        if _normalize(edge.edge_type) != "dialogue_tunnel":
            continue
        edge_metadata = dict(edge.metadata or {})
        source_bucket = _clean_text(edge_metadata.get("from_topic_bucket_id", ""))
        target_bucket = _clean_text(edge_metadata.get("to_topic_bucket_id", ""))
        for near, far in ((source_bucket, target_bucket), (target_bucket, source_bucket)):
            if near == bucket_id and far:
                neighbours.add(far)
    return neighbours
def _topic_bridge_requested(query: str) -> bool:
    """Report whether the normalized *query* contains a topic-bridge marker.

    Markers are Chinese and English words about relating, comparing or
    referring back to earlier context; matching is by substring.
    """
    normalized_query = _normalize(query)
    if not normalized_query:
        return False
    for marker in (
        "关联", "联系", "链条", "脉络", "延展", "深入", "对比", "整合", "整体",
        "上下文", "刚才", "之前", "上面", "跨话题", "隧穿",
        "related", "connect", "compare", "context", "chain",
    ):
        if marker in normalized_query:
            return True
    return False
def _dialogue_tunnel_requested(query: str) -> bool:
    """Report whether the normalized *query* contains a dialogue-tunnel marker.

    Markers are Chinese and English phrases about cross-topic, cross-dialogue
    or global/long-range context; matching is by substring.
    """
    normalized_query = _normalize(query)
    if not normalized_query:
        return False
    for marker in (
        "跨话题", "跨对话", "不同话题", "不同对话", "历史对话", "长期脉络", "整体脉络",
        "所有相关", "全局", "全局记忆", "远一点", "更深", "深层关联", "对话级隧穿",
        "记忆隧穿",
        "cross topic", "cross-topic", "cross dialogue", "cross-session",
        "global context", "long range",
    ):
        if marker in normalized_query:
            return True
    return False
def _topic_bucket_record_to_hit(
    record: SessionMemoryRecordV2,
    *,
    query_topic: Mapping[str, Any],
    rank: int,
    rescue_kind: str = "topic_bucket",
) -> MemoryHit:
    """Wrap *record* as a rescued ``MemoryHit`` with a boosted score.

    The base score is ``max(confidence, salience, 0.62)``. A bonus is added by
    tier — hard constraint, preference, or anything else — and the bonuses
    are smaller when *rescue_kind* normalizes to ``dialogue_tunnel``
    (0.42 / 0.22 / 0.12) than for a same-bucket rescue (1.35 / 0.72 / 0.38).
    Rescue bookkeeping flags and tier defaults are written to a copy of the
    record metadata.
    """
    is_tunnel = _normalize(rescue_kind) == "dialogue_tunnel"
    metadata = dict(record.metadata or {})
    metadata.update(
        {
            "topic_bucket_rescue": not is_tunnel,
            "dialogue_tunnel_rescue": is_tunnel,
            "topic_bucket_rescue_rank": int(rank),
            "topic_bucket_query_id": _clean_text(query_topic.get("topic_bucket_id", "")),
            "topic_bucket_query_label": _clean_text(query_topic.get("topic_label", "")),
            "topic_bucket_same": not is_tunnel,
            "topic_bucket_bridge": False,
            "topic_bucket_bridge_allowed": False,
            "topic_bucket_dialogue_tunnel_allowed": is_tunnel,
            "topic_bucket_overlap": 0.0 if is_tunnel else 1.0,
        }
    )

    category = _normalize(record.category)
    value_text = _normalize(record.value)
    hard_markers = ("过敏", "必须", "避开", "禁止", "不能", "must", "avoid", "allergy")
    hardish = (
        category == "constraint"
        or _normalize(metadata.get("memory_type", "")) == "hard_constraint"
        or _normalize(metadata.get("durability", "")) == "hard"
        or _normalize(metadata.get("conflict_policy", "")) == "must_preserve"
        or any(marker in value_text for marker in hard_markers)
    )

    if hardish:
        bonus = 0.42 if is_tunnel else 1.35
        tier_defaults = (
            ("memory_type", "hard_constraint"),
            ("durability", "hard"),
            ("conflict_policy", "must_preserve"),
        )
    elif category == "preference":
        bonus = 0.22 if is_tunnel else 0.72
        tier_defaults = (
            ("memory_type", "durable_preference"),
            ("durability", "long_term"),
        )
    else:
        bonus = 0.12 if is_tunnel else 0.38
        tier_defaults = ()
    for key, default in tier_defaults:
        metadata.setdefault(key, default)
    if is_tunnel:
        # NOTE(review): applied to every dialogue-tunnel hit; the original
        # nesting is ambiguous (it may have been limited to the fallback tier).
        metadata["dialogue_tunnel_resistance"] = "high"

    boosted_score = max(float(record.confidence), float(record.salience), 0.62) + bonus
    return MemoryHit(
        memory_id=record.memory_id,
        category=record.category,
        value=record.value,
        relation=record.relation,
        anchors=list(record.anchor_concepts),
        score=boosted_score,
        source_kind=record.source_kind,
        slot_key=record.slot_key,
        state=record.state,
        turn_index=int(record.turn_index),
        metadata=metadata,
    )
def _profile_query_rescue_hits(
    graph: SessionMemoryGraphV2,
    query: str,
    *,
    top_k: int,
) -> List[MemoryHit]:
    """Rescue active profile-layer records that match a profile-style query.

    Returns ``[]`` unless ``infer_profile_query_intent(query)`` is enabled.
    A record qualifies when it is ``active``, belongs to the profile layer,
    has a positive ``profile_query_score_delta``, and a positive match score
    that is either at least 0.34 or backed by raw token overlap.

    Returns:
        At most ``max(1, min(24, top_k * 3))`` hits ordered descending by
        (match score, hit score, turn index).
    """
    query_raw_tokens = set(_path_utility_tokens(query))
    query_tokens = _profile_query_expanded_tokens(query)
    intent = infer_profile_query_intent(query)
    if not bool(intent.get("enabled")):
        return []

    def _build_hit(record: Any, record_metadata: Dict[str, Any], score: float) -> MemoryHit:
        return MemoryHit(
            memory_id=record.memory_id,
            category=record.category,
            value=record.value,
            relation=record.relation,
            anchors=list(record.anchor_concepts),
            score=score,
            source_kind=record.source_kind,
            slot_key=record.slot_key,
            state=record.state,
            turn_index=int(record.turn_index),
            metadata=record_metadata,
        )

    scored_hits: List[tuple[float, MemoryHit]] = []
    for record in getattr(graph, "records_by_id", {}).values():
        metadata = dict(record.metadata or {})
        if record.state != "active":
            continue
        in_profile_layer = is_profile_layer_record(
            category=record.category,
            source_kind=record.source_kind,
            semantic_slot=metadata.get("semantic_slot", ""),
            metadata=metadata,
        )
        if not in_profile_layer:
            continue
        delta, reason = profile_query_score_delta(
            query=query,
            query_tokens=query_tokens,
            category=record.category,
            source_kind=record.source_kind,
            semantic_slot=metadata.get("semantic_slot", ""),
            value=record.value,
            anchors=record.anchor_concepts,
            metadata=metadata,
        )
        if delta <= 0:
            continue

        # A throwaway hit is needed only to compute the match score.
        probe_hit = _build_hit(record, metadata, max(float(record.confidence), float(record.salience), 0.01))
        match_score, overlap_tokens, raw_overlap_tokens = _profile_hit_match_score(
            query_raw_tokens,
            query_tokens,
            probe_hit,
        )
        if match_score <= 0.0:
            continue
        if not raw_overlap_tokens and match_score < 0.34:
            continue

        route_reason = reason or "profile_route"
        metadata.update(
            {
                "profile_query_rescue": True,
                "profile_query_rescue_reason": route_reason,
                "profile_query_match_score": round(match_score, 6),
                "profile_query_overlap_tokens": list(overlap_tokens),
                "profile_query_raw_overlap_tokens": list(raw_overlap_tokens),
                "topic_bucket_profile_route_preserved": True,
                "match_reason": ",".join(_dedupe([metadata.get("match_reason", ""), route_reason], max_items=4)),
            }
        )
        rescued_score = max(float(record.confidence), float(record.salience), 0.62) + float(delta) + float(match_score)
        scored_hits.append((match_score, _build_hit(record, metadata, rescued_score)))

    scored_hits.sort(
        key=lambda pair: (float(pair[0]), float(pair[1].score), int(pair[1].turn_index)),
        reverse=True,
    )
    limit = max(1, min(24, int(top_k or 1) * 3))
    return [hit for _, hit in scored_hits[:limit]]
def _memory_hit_from_record(record: SessionMemoryRecordV2, *, score: float | None = None, metadata: Mapping[str, Any] | None = None) -> MemoryHit:
    """Create a ``MemoryHit`` mirroring *record*.

    Args:
        record: Source memory record.
        score: Explicit hit score; when ``None`` the score is
            ``max(confidence, salience, 0.01)``.
        metadata: Extra entries layered over a copy of the record metadata.
    """
    combined_metadata = dict(record.metadata or {})
    combined_metadata.update(dict(metadata or {}))
    if score is None:
        hit_score = max(float(record.confidence), float(record.salience), 0.01)
    else:
        hit_score = float(score)
    return MemoryHit(
        memory_id=record.memory_id,
        category=record.category,
        value=record.value,
        relation=record.relation,
        anchors=list(record.anchor_concepts),
        score=hit_score,
        source_kind=record.source_kind,
        slot_key=record.slot_key,
        state=record.state,
        turn_index=int(record.turn_index),
        metadata=combined_metadata,
    )
# Query tokens that make `_facet_query_pack_hits` treat a query as numeric.
_FACET_NUMERIC_QUERY_TOKENS = {
    "amount", "count", "counts", "duration", "durations",
    "many", "much", "number", "quantity",
    "sum", "total", "totals",
    "weeks", "week", "hours", "hour", "dollars", "dollar",
    "tenants", "tickets",
}
# Query tokens that make `_facet_query_pack_hits` treat a query as temporal.
_FACET_TEMPORAL_QUERY_TOKENS = {
    "after", "before", "date", "deadline",
    "end", "finish", "finished",
    "start", "started",
    "time", "when",
}
def _facet_query_pack_hits(
    graph: SessionMemoryGraphV2,
    query: str,
    final_hits: Sequence[MemoryHit],
    *,
    top_k: int,
) -> Dict[str, Any]:
    """Put matching event-facet records (and their parents) ahead of *final_hits*.

    The pack only runs when the query has numeric or temporal intent, or one
    of its tokens contains ``facet``. Facet records are those whose
    ``content_variant`` is ``event_facet_write`` and whose state is
    ``active`` / ``parallel_active`` / ``evidence``. Each is scored from token
    overlap with the query (its own text and its parent's text) plus
    facet-type bonuses, and kept when the score reaches 0.58.

    Args:
        graph: Memory graph providing ``records_by_id``.
        query: User query text.
        final_hits: Hits already selected; they follow the packed hits.
        top_k: Scales how many facet candidates are packed
            (``max(4, min(18, top_k * 2))``).

    Returns:
        ``{"hits": [...], "metadata": {...}}`` where ``hits`` is de-duplicated
        by memory id and ``metadata`` holds ``facet_query_pack_*`` diagnostics.
    """
    query_tokens = set(_path_utility_tokens(query))
    if not query_tokens:
        return {"hits": list(final_hits), "metadata": {"facet_query_pack_enabled": False, "facet_query_pack_reason": "empty_query_tokens"}}
    numeric_query = bool(query_tokens & _FACET_NUMERIC_QUERY_TOKENS)
    temporal_query = bool(query_tokens & _FACET_TEMPORAL_QUERY_TOKENS)
    if not numeric_query and not temporal_query and not any("facet" in _normalize(token) for token in query_tokens):
        return {"hits": list(final_hits), "metadata": {"facet_query_pack_enabled": False, "facet_query_pack_reason": "no_facet_intent"}}

    records = list(getattr(graph, "records_by_id", {}).values())
    # Index the first record per lower-cased slot key once, instead of
    # rescanning every record for each facet (the original lookup was O(n^2)).
    parent_by_slot: Dict[str, SessionMemoryRecordV2] = {}
    for candidate in records:
        parent_by_slot.setdefault(_clean_text(candidate.slot_key).lower(), candidate)

    query_size = max(1.0, len(query_tokens))
    candidate_rows: List[tuple[float, SessionMemoryRecordV2, SessionMemoryRecordV2 | None, Dict[str, Any]]] = []
    for record in records:
        metadata = dict(record.metadata or {})
        if _normalize(metadata.get("content_variant", "")) != "event_facet_write":
            continue
        if record.state not in {"active", "parallel_active", "evidence"}:
            continue
        facet_type = _normalize(metadata.get("facet_type", ""))
        parent_slot_key = _clean_text(metadata.get("facet_parent_slot_key", "")).lower()
        # A facet that names no parent slot gets no parent, rather than being
        # attached to whichever record happens to have an empty slot key.
        parent = parent_by_slot.get(parent_slot_key) if parent_slot_key else None
        parent_metadata = dict(parent.metadata or {}) if parent is not None else {}
        parent_text = " ".join(
            [
                _clean_text(parent.value if parent is not None else ""),
                " ".join(parent.anchor_concepts if parent is not None else []),
                _clean_text(parent_metadata.get("source_span", "")),
            ]
        )
        facet_text = " ".join(
            [
                record.value,
                " ".join(record.anchor_concepts or []),
                _clean_text(metadata.get("facet_type", "")),
                _clean_text(metadata.get("facet_role", "")),
                _clean_text(metadata.get("facet_source_span", "")),
                parent_text,
            ]
        )
        overlap_tokens = query_tokens & set(_path_utility_tokens(facet_text))
        parent_overlap_tokens = query_tokens & set(_path_utility_tokens(parent_text))
        unit_overlap = bool(query_tokens & set(_path_utility_tokens(record.value)))

        score = 0.0
        if overlap_tokens:
            score += min(0.72, len(overlap_tokens) / query_size * 1.2)
        if parent_overlap_tokens:
            score += min(0.72, len(parent_overlap_tokens) / query_size * 1.35)
        if numeric_query and facet_type == "numeric":
            score += 0.42
            # NOTE(review): the unit bonus is assumed to apply only to numeric
            # facets on numeric queries — confirm against the original nesting.
            if unit_overlap:
                score += 0.42
        if temporal_query and facet_type == "temporal":
            score += 0.32
        if facet_type == "entity" and parent_overlap_tokens:
            score += 0.26
        if parent is not None and _normalize(parent_metadata.get("content_variant", "")) == "llm_semantic_write":
            score += 0.08
        if score < 0.58:
            continue
        candidate_rows.append(
            (
                round(min(2.75, 1.18 + score), 6),
                record,
                parent,
                {
                    "facet_query_pack_overlap_tokens": sorted(overlap_tokens)[:12],
                    "facet_query_pack_parent_overlap_tokens": sorted(parent_overlap_tokens)[:12],
                    "facet_query_pack_unit_overlap": bool(unit_overlap),
                    "facet_query_pack_score": round(score, 6),
                },
            )
        )

    if not candidate_rows:
        return {
            "hits": list(final_hits),
            "metadata": {
                "facet_query_pack_enabled": True,
                "facet_query_pack_inserted_hit_count": 0,
                "facet_query_pack_candidate_count": 0,
            },
        }

    candidate_rows.sort(key=lambda item: (float(item[0]), int(item[1].turn_index)), reverse=True)
    selected = candidate_rows[: max(4, min(18, int(top_k or 1) * 2))]

    packed_hits: List[MemoryHit] = []
    for score, facet_record, parent, extra_metadata in selected:
        facet_metadata = {
            **extra_metadata,
            "facet_query_pack": True,
            "evidence_snippet_role": "facet_query_attribute",
        }
        packed_hits.append(_memory_hit_from_record(facet_record, score=score, metadata=facet_metadata))
        if parent is not None:
            # The parent event follows its facet at a slightly lower score.
            packed_hits.append(
                _memory_hit_from_record(
                    parent,
                    score=max(1.12, score - 0.04),
                    metadata={
                        **extra_metadata,
                        "facet_query_pack": True,
                        "evidence_snippet_role": "facet_parent_event",
                        "facet_query_pack_child_id": facet_record.memory_id,
                    },
                )
            )

    merged: List[MemoryHit] = []
    seen_ids: set[str] = set()
    for hit in [*packed_hits, *list(final_hits)]:
        if hit.memory_id:
            if hit.memory_id in seen_ids:
                continue
            seen_ids.add(hit.memory_id)
        merged.append(hit)
    return {
        "hits": merged,
        "metadata": {
            "facet_query_pack_enabled": True,
            "facet_query_pack_candidate_count": len(candidate_rows),
            "facet_query_pack_inserted_hit_count": len(packed_hits),
            "facet_query_pack_numeric_query": bool(numeric_query),
            "facet_query_pack_temporal_query": bool(temporal_query),
        },
    }
# Query tokens that signal count/amount intent in `_unit_coverage_pack_hits`.
_UNIT_COVERAGE_COUNT_TOKENS = {
    "amount", "amounts", "count", "counts", "cost", "costs",
    "dollar", "dollars", "how", "many", "much", "number",
    "minimum", "maximum", "paid", "percent", "percentage",
    "price", "prices", "sale", "sales", "sell", "sold",
    "total", "totals", "sum", "value", "valued", "values", "worth",
    "items", "projects", "events", "things",
}
# Temporal wording; `_multi_unit_chain_temporal_anchor_signal` checks text
# tokens against this set.
_UNIT_COVERAGE_TEMPORAL_TOKENS = {
    "ago", "after", "before", "between", "consecutive", "date",
    "days", "day", "first", "last", "months", "month",
    "order", "passed", "since", "weeks", "week",
}
# Ordering/comparison words; `_unit_coverage_pack_hits` uses them to detect
# temporal intent in the raw query tokens.
_MULTI_UNIT_CHAIN_TEMPORAL_COMPARISON_TOKENS = {
    "after", "before", "between", "consecutive", "earlier",
    "first", "later", "last", "order", "since",
}
# Tokens removed from the query token set by `_unit_coverage_pack_hits`:
# weekday names and abbreviations plus "date" and "question".
_UNIT_COVERAGE_QUERY_DROP_TOKENS = {
    "date", "question",
    "mon", "monday", "tue", "tuesday", "wed", "wednesday",
    "thu", "thursday", "fri", "friday",
    "sat", "saturday", "sun", "sunday",
}
# Mode strings listed as "disabled" values, including the empty string.
# NOTE(review): the code that reads the mode is outside this section.
_MULTI_UNIT_CHAIN_DISABLED_MODES = {
    "",
    "off",
    "disabled",
    "none",
    "false",
    "0",
}
# The unit-coverage count tokens extended with four selector words.
_MULTI_UNIT_CHAIN_COUNT_TOKENS = {
    *_UNIT_COVERAGE_COUNT_TOKENS,
    "which", "each", "all", "both",
}
# Union of the count tokens, the temporal tokens, and a list of pronouns,
# auxiliaries, question words and weekday names.
# NOTE(review): the code that applies this set is outside this section.
_MULTI_UNIT_CHAIN_FOCUS_DROP_TOKENS = {
    *_MULTI_UNIT_CHAIN_COUNT_TOKENS,
    *_UNIT_COVERAGE_TEMPORAL_TOKENS,
    "i", "me", "my", "mine",
    "am", "is", "are", "was", "were",
    "need", "needs", "needed", "currently", "did", "does",
    "fri", "friday", "mon", "monday",
    "or", "question",
    "sat", "saturday", "sun", "sunday",
    "thu", "thursday", "tue", "tuesday", "wed", "wednesday",
    "what", "which", "who",
}
# Facet type names and unit kind names listed for the multi-unit chain.
# NOTE(review): the code that checks membership is outside this section.
_MULTI_UNIT_CHAIN_FACET_TYPES = {
    "action", "entity", "numeric", "role",
    "evidence_role", "state", "temporal",
}
_MULTI_UNIT_CHAIN_UNIT_KINDS = {
    "action_unit", "target_entity", "numeric_quantity",
    "leadership_role", "participation_role", "state_status",
    "temporal_anchor", "evidence_role", "profile_shadow_unit",
}
# Money/value wording; `_multi_unit_chain_numeric_signal` adds weight when a
# unit's tokens intersect this set.
_MULTI_UNIT_CHAIN_NUMERIC_VALUE_TOKENS = {
    "amount", "appraisal", "appraised",
    "cost", "costs", "dollar", "dollars",
    "minimum", "paid", "price", "prices",
    "sale", "sell", "sold", "total",
    "value", "valued", "values", "worth",
}
def _multi_unit_chain_numeric_signal(
    unit_kind: str,
    facet_type: str,
    text: str,
    tokens: set[str],
) -> float:
    """Score how strongly a unit looks like a numeric quantity (0.0 to 1.1).

    Text that, once cleaned, is nothing but a date-like number yields ``0.0``.
    Otherwise weight is added for a numeric unit kind / facet type (0.55),
    value-related tokens (0.28), and either a currency amount (0.72) or a
    counted quantity such as "3 weeks" (0.42).
    """
    date_only_pattern = r"(?:\d{4}[/-]\d{1,2}(?:[/-]\d{1,2})?|\d{1,2}[/-]\d{1,2}(?:[/-]\d{2,4})?|\d{4}/\d{2})"
    if re.fullmatch(date_only_pattern, _clean_text(text)):
        return 0.0

    currency_pattern = r"(?:[$€£¥]\s*\d|\b\d+(?:,\d{3})*(?:\.\d+)?\s*(?:dollars?|usd|bucks?|yuan|rmb)\b)"
    counted_pattern = r"\b\d+(?:,\d{3})*(?:\.\d+)?\s*(?:comments?|items?|pieces?|kits?|projects?|doctors?|weddings?|hours?|weeks?|days?|months?|years?|miles?)\b"

    total = 0.0
    if facet_type == "numeric" or unit_kind == "numeric_quantity":
        total += 0.55
    if tokens & _MULTI_UNIT_CHAIN_NUMERIC_VALUE_TOKENS:
        total += 0.28
    # A currency match takes precedence over a plain counted quantity.
    if re.search(currency_pattern, text, flags=re.IGNORECASE):
        total += 0.72
    elif re.search(counted_pattern, text, flags=re.IGNORECASE):
        total += 0.42
    return min(1.1, total)
def _multi_unit_chain_date_like_numeric(text: str) -> bool:
    """Tell whether the cleaned *text* consists solely of a date-like number.

    Accepted shapes are year-first (``2024-05``, ``2024/05/17``) and
    day/month-first (``5/17``, ``17-05-2024``) forms separated by ``/`` or ``-``.
    """
    date_only_pattern = r"(?:\d{4}[/-]\d{1,2}(?:[/-]\d{1,2})?|\d{1,2}[/-]\d{1,2}(?:[/-]\d{2,4})?|\d{4}/\d{2})"
    match = re.fullmatch(date_only_pattern, _clean_text(text))
    return match is not None
def _multi_unit_chain_temporal_anchor_signal(text: str, tokens: set[str]) -> float:
    """Score how strongly a unit looks like a temporal anchor (0.0 to 1.0).

    Adds 0.22 when *tokens* intersect the temporal token set and 0.72 when
    *text* contains a relative-time phrase such as "a few months ago",
    "last week" or "yesterday".
    """
    relative_time_pattern = (
        r"\b(?:about|around|roughly|few|several|couple|last|previous|earlier)\s+"
        r"(?:a\s+)?(?:day|days|week|weeks|month|months|year|years)\s+ago\b|"
        r"\b(?:yesterday|today|tomorrow|last\s+week|last\s+month|a\s+few\s+months\s+ago)\b"
    )
    total = 0.0
    if tokens & _UNIT_COVERAGE_TEMPORAL_TOKENS:
        total += 0.22
    if re.search(relative_time_pattern, text, flags=re.IGNORECASE):
        total += 0.72
    return min(1.0, total)
# Substrings searched for in a profile-shadow record's slot text by
# `_profile_shadow_eventlike_record`.
_PROFILE_SHADOW_EVENTLIKE_SLOT_HINTS = {
    "action", "appointment", "constraint", "deadline",
    "exchange", "obligation", "pickup", "plan",
    "preference", "return", "status", "task",
}
# Substrings searched for in a profile-shadow record's value text by
# `_profile_shadow_eventlike_record`.
_PROFILE_SHADOW_EVENTLIKE_TEXT_HINTS = {
    "bought", "buy", "completed", "did",
    "exchange", "exchanged", "finish", "finished",
    "got", "have to", "need", "needed", "needs",
    "paid", "pick", "picked", "return", "returned",
    "should", "still", "took", "went",
}
def _profile_shadow_eventlike_record(record: SessionMemoryRecordV2, metadata: Mapping[str, Any]) -> bool:
    """Decide whether a writer-produced profile shadow reads like an event.

    Only records whose ``content_variant`` is ``profile_shadow_from_writer``
    and whose state is active / parallel_active / evidence qualify.  Such a
    record counts as event-like when any slot hint occurs as a substring of
    its slot-level text, or any text hint occurs in its value-level text.
    """
    variant = _normalize(metadata.get("content_variant", ""))
    if variant != "profile_shadow_from_writer":
        return False
    if record.state not in {"active", "parallel_active", "evidence"}:
        return False

    slot_fields = (
        record.category,
        record.relation,
        record.slot_key,
        metadata.get("semantic_slot", ""),
        metadata.get("profile_type", ""),
    )
    value_fields = (
        record.value,
        metadata.get("source_span", ""),
        metadata.get("raw_text", ""),
    )
    slot_text = " ".join([_clean_text(field_value) for field_value in slot_fields]).lower()
    value_text = " ".join([_clean_text(field_value) for field_value in value_fields]).lower()

    # Plain substring containment on purpose: hints may span words ("have to").
    for hint in _PROFILE_SHADOW_EVENTLIKE_SLOT_HINTS:
        if hint in slot_text:
            return True
    for hint in _PROFILE_SHADOW_EVENTLIKE_TEXT_HINTS:
        if hint in value_text:
            return True
    return False
def _profile_shadow_unit_text(record: SessionMemoryRecordV2, metadata: Mapping[str, Any]) -> str:
    """Flatten a profile-shadow record into one space-joined text blob.

    Record fields come first (value, category, relation, slot key), followed
    by the metadata fields semantic_slot, profile_type, profile_domain,
    source_span and raw_text; each piece is passed through ``_clean_text``.
    """
    pieces = [record.value, record.category, record.relation, record.slot_key]
    for key in ("semantic_slot", "profile_type", "profile_domain", "source_span", "raw_text"):
        pieces.append(metadata.get(key, ""))
    return " ".join(_clean_text(piece) for piece in pieces)
def _unit_coverage_pack_hits(
    graph: SessionMemoryGraphV2,
    query: str,
    final_hits: Sequence[MemoryHit],
    *,
    top_k: int,
) -> Dict[str, Any]:
    """Splice query-matching event-unit records (and their parents) into ``final_hits``.

    Every record in ``graph.records_by_id`` that is either an event-facet unit
    or an event-like profile shadow is scored against the query tokens.  The
    best-scoring, de-duplicated units are converted to hits and inserted into
    a copy of ``final_hits``.

    Args:
        graph: Memory graph; only its ``records_by_id`` mapping is read here.
        query: Raw query text, tokenised with ``_path_utility_tokens``.
        final_hits: Already-ranked hits.  The sequence itself is not modified.
        top_k: Not read by this function.

    Returns:
        A dict with ``"hits"`` (the merged hit list, or a plain copy of
        ``final_hits`` when nothing is packed) and ``"metadata"`` (diagnostic
        ``unit_coverage_*`` flags and counters).

    Environment:
        TMCRA_UNIT_COVERAGE_PACK_MAX_UNITS: cap on units kept by the main
            selection pass (default 6, floor 1).
        TMCRA_UNIT_COVERAGE_PROFILE_SHADOW_MAX_UNITS: extra profile-shadow
            units appended for count queries (default 2, floor 0).
        TMCRA_UNIT_COVERAGE_PACK_INSERT_AFTER: index in ``final_hits`` at
            which the packed hits are inserted (default 2, clamped to the
            list length).
        Malformed values for any of these fall back to the default.
    """
    raw_query_tokens = set(_path_utility_tokens(query))
    query_tokens = {token for token in raw_query_tokens if token not in _UNIT_COVERAGE_QUERY_DROP_TOKENS}
    if not query_tokens:
        return {"hits": list(final_hits), "metadata": {"unit_coverage_pack_enabled": False, "unit_coverage_reason": "empty_query"}}
    # Intent flags.  Temporal intent is checked on the unfiltered tokens; the
    # others on the tokens that survived the drop list.
    count_intent = bool(query_tokens & _UNIT_COVERAGE_COUNT_TOKENS)
    percentage_intent = bool(query_tokens & {"percent", "percentage"})
    temporal_intent = bool(raw_query_tokens & _MULTI_UNIT_CHAIN_TEMPORAL_COMPARISON_TOKENS)
    # "Direct" intent is the fallback: neither count nor temporal, but at
    # least two content tokens to match on.
    direct_unit_intent = not count_intent and not temporal_intent and len(query_tokens) >= 2
    if not count_intent and not temporal_intent and not direct_unit_intent:
        return {
            "hits": list(final_hits),
            "metadata": {"unit_coverage_pack_enabled": False, "unit_coverage_reason": "no_unit_intent"},
        }
    records = list(getattr(graph, "records_by_id", {}).values())
    # Lower-cased slot key -> record.  When several records share a slot key
    # the last one iterated wins.
    parent_by_slot = {_clean_text(record.slot_key).lower(): record for record in records}
    # Each candidate: (hit score, unit record, parent record or None, diagnostics).
    candidates: List[tuple[float, SessionMemoryRecordV2, SessionMemoryRecordV2 | None, Dict[str, Any]]] = []
    for record in records:
        metadata = dict(record.metadata or {})
        if record.state not in {"active", "parallel_active", "evidence"}:
            continue
        profile_shadow_unit = _profile_shadow_eventlike_record(record, metadata)
        if profile_shadow_unit:
            # Profile shadows are treated as parent-less action units.
            unit_kind = "profile_shadow_unit"
            facet_type = "action"
            parent_slot = _clean_text(record.slot_key).lower()
            parent = None
            parent_text = ""
            unit_text = _profile_shadow_unit_text(record, metadata)
        else:
            # Otherwise only event-facet writes of a known layer version qualify.
            if _normalize(metadata.get("content_variant", "")) != "event_facet_write":
                continue
            if _normalize(metadata.get("facet_layer_version", "")) not in {"event_unit_v1", "event_facet_v1"}:
                continue
            unit_kind = _normalize(metadata.get("unit_kind", ""))
            facet_type = _normalize(metadata.get("facet_type", ""))
            parent_slot = _clean_text(metadata.get("facet_parent_slot_key", "")).lower()
            parent = parent_by_slot.get(parent_slot)
            parent_text = " ".join(
                [
                    _clean_text(parent.value if parent else ""),
                    _clean_text(dict(parent.metadata or {}).get("source_span", "") if parent else ""),
                    " ".join(parent.anchor_concepts if parent else []),
                ]
            )
            # The unit is matched on its own fields plus its parent's text.
            unit_text = " ".join(
                [
                    _clean_text(record.value),
                    _clean_text(metadata.get("facet_role", "")),
                    _clean_text(metadata.get("unit_kind", "")),
                    _clean_text(metadata.get("action", "")),
                    _clean_text(metadata.get("target", "")),
                    _clean_text(metadata.get("quantity", "")),
                    _clean_text(metadata.get("unit", "")),
                    _clean_text(metadata.get("normalized_time", "")),
                    _clean_text(metadata.get("status", "")),
                    _clean_text(metadata.get("facet_source_span", "")),
                    parent_text,
                ]
            )
        unit_tokens = set(_path_utility_tokens(unit_text))
        semantic_event_unit = _normalize(record.memory_id).startswith("tmcra.event.") or _normalize(parent_slot).startswith("tmcra.event.")
        overlap = query_tokens & unit_tokens
        # Additive score: lexical overlap first, then intent/facet/kind bonuses.
        score = 0.0
        if overlap:
            score += min(0.92, len(overlap) / max(1.0, len(query_tokens)) * 1.55)
        if count_intent and facet_type in {"action", "entity", "numeric", "role", "evidence_role"}:
            score += 0.42
        if temporal_intent and facet_type in {"temporal", "action", "role", "entity"}:
            score += 0.38
        if direct_unit_intent and facet_type in {"action", "entity", "state", "numeric"}:
            score += 0.30
        if direct_unit_intent and len(overlap) >= 2:
            score += 0.36
        if unit_kind in {"action_unit", "target_entity", "leadership_role", "participation_role"}:
            score += 0.16
        if unit_kind in {"numeric_quantity", "temporal_anchor", "state_status"}:
            score += 0.12
        if profile_shadow_unit:
            score += 0.28
            if count_intent and facet_type in {"action", "entity"}:
                score += 0.24
        semantic_event_priority = bool(semantic_event_unit and (count_intent or percentage_intent))
        if semantic_event_priority:
            score += 0.46
        # Percentage queries favour numeric units or text with an explicit "N%" / "N percent".
        if percentage_intent and (
            unit_kind == "numeric_quantity"
            or facet_type == "numeric"
            or re.search(r"\b\d+(?:\.\d+)?\s*%|\b\d+(?:\.\d+)?\s*percent\b", unit_text, flags=re.IGNORECASE)
        ):
            score += 0.72
        if parent is not None:
            score += 0.08
        # Units below the admission threshold are dropped.
        if score < 0.55:
            continue
        candidates.append(
            (
                # Hit score: base offset 1.28 plus the unit score, capped at 3.2.
                round(min(3.2, 1.28 + score), 6),
                record,
                parent,
                {
                    "unit_coverage_overlap_tokens": sorted(overlap)[:14],
                    "unit_coverage_score": round(score, 6),
                    "unit_kind": unit_kind,
                    "facet_type": facet_type,
                    "unit_coverage_count_intent": bool(count_intent),
                    "unit_coverage_percentage_intent": bool(percentage_intent),
                    "unit_coverage_temporal_intent": bool(temporal_intent),
                    "unit_coverage_direct_unit_intent": bool(direct_unit_intent),
                    "unit_coverage_profile_shadow_unit": bool(profile_shadow_unit),
                    "unit_coverage_semantic_event_unit": bool(semantic_event_unit),
                    "unit_coverage_semantic_event_priority": bool(semantic_event_priority),
                },
            )
        )
    if not candidates:
        return {
            "hits": list(final_hits),
            "metadata": {
                "unit_coverage_pack_enabled": True,
                "unit_coverage_candidate_count": 0,
                "unit_coverage_inserted_hit_count": 0,
                "unit_coverage_count_intent": bool(count_intent),
                "unit_coverage_percentage_intent": bool(percentage_intent),
                "unit_coverage_temporal_intent": bool(temporal_intent),
                "unit_coverage_direct_unit_intent": bool(direct_unit_intent),
            },
        }
    # Descending order by: semantic-event priority, numeric unit under a
    # percentage query, hit score, then turn index.
    candidates.sort(
        key=lambda item: (
            1 if bool(item[3].get("unit_coverage_semantic_event_priority", False)) else 0,
            1
            if bool(item[3].get("unit_coverage_percentage_intent"))
            and (
                _normalize(item[3].get("unit_kind", "")) == "numeric_quantity"
                or _normalize(item[3].get("facet_type", "")) == "numeric"
            )
            else 0,
            float(item[0]),
            int(item[1].turn_index),
        ),
        reverse=True,
    )
    selected: List[tuple[float, SessionMemoryRecordV2, SessionMemoryRecordV2 | None, Dict[str, Any]]] = []
    seen_unit_values: set[str] = set()
    try:
        max_selected_units = max(1, int(os.getenv("TMCRA_UNIT_COVERAGE_PACK_MAX_UNITS", "6") or 6))
    except (TypeError, ValueError):
        max_selected_units = 6
    # Main pass: keep the best candidate per (unit kind, value, parent slot) key.
    for item in candidates:
        _, record, parent, metadata = item
        key = "|".join(
            [
                _normalize(metadata.get("unit_kind", "")),
                _normalize(record.value),
                _normalize(dict(record.metadata or {}).get("facet_parent_slot_key", "")),
            ]
        )
        if key in seen_unit_values:
            continue
        seen_unit_values.add(key)
        selected.append(item)
        if len(selected) >= max_selected_units:
            break
    if count_intent:
        # Count queries may append a few more profile-shadow units on top of
        # the main cap; this pass de-duplicates by memory id only.
        try:
            max_profile_shadow_units = max(0, int(os.getenv("TMCRA_UNIT_COVERAGE_PROFILE_SHADOW_MAX_UNITS", "2") or 2))
        except (TypeError, ValueError):
            max_profile_shadow_units = 2
        selected_ids = {item[1].memory_id for item in selected}
        added_profile_shadow_units = 0
        for item in candidates:
            _, record, _parent, metadata = item
            if added_profile_shadow_units >= max_profile_shadow_units:
                break
            if record.memory_id in selected_ids:
                continue
            if not bool(metadata.get("unit_coverage_profile_shadow_unit", False)):
                continue
            selected.append(item)
            selected_ids.add(record.memory_id)
            added_profile_shadow_units += 1
    packed_hits: List[MemoryHit] = []
    for score, unit_record, parent, extra in selected:
        packed_hits.append(
            _memory_hit_from_record(
                unit_record,
                score=score,
                metadata={
                    **extra,
                    "unit_coverage_pack": True,
                    "evidence_snippet_role": "unit_coverage_evidence_unit",
                },
            )
        )
        if parent is not None:
            # The parent event follows its unit, scored slightly lower.
            unit_metadata = dict(unit_record.metadata or {})
            parent_hit = _memory_hit_from_record(
                parent,
                score=max(1.12, score - 0.06),
                metadata={
                    **extra,
                    "unit_coverage_pack": True,
                    "evidence_snippet_role": "unit_coverage_parent_event",
                    "unit_coverage_parent_memory_id": parent.memory_id,
                    "unit_coverage_child_id": unit_record.memory_id,
                    "unit_coverage_child_value": unit_record.value,
                    "unit_coverage_child_source_span": _clean_text(
                        unit_metadata.get("facet_source_span", "") or unit_metadata.get("source_span", "")
                    ),
                },
            )
            # Synthetic id: the same parent can appear once per child unit
            # without being removed by the memory-id de-duplication below.
            parent_hit.memory_id = f"{parent.memory_id}#unit_parent:{unit_record.memory_id}"
            packed_hits.append(parent_hit)
    merged: List[MemoryHit] = []
    seen_ids: set[str] = set()
    try:
        insertion_index = max(0, min(len(final_hits), int(os.getenv("TMCRA_UNIT_COVERAGE_PACK_INSERT_AFTER", "2") or 2)))
    except (TypeError, ValueError):
        insertion_index = min(len(final_hits), 2)
    ordered_hits = [*list(final_hits[:insertion_index]), *packed_hits, *list(final_hits[insertion_index:])]
    # First occurrence of a memory id wins; hits without an id are always kept.
    for hit in ordered_hits:
        if hit.memory_id and hit.memory_id in seen_ids:
            continue
        if hit.memory_id:
            seen_ids.add(hit.memory_id)
        merged.append(hit)
    return {
        "hits": merged,
        "metadata": {
            "unit_coverage_pack_enabled": True,
            "unit_coverage_candidate_count": len(candidates),
            "unit_coverage_selected_unit_count": len(selected),
            "unit_coverage_inserted_hit_count": len(packed_hits),
            "unit_coverage_count_intent": bool(count_intent),
            "unit_coverage_percentage_intent": bool(percentage_intent),
            "unit_coverage_temporal_intent": bool(temporal_intent),
            "unit_coverage_direct_unit_intent": bool(direct_unit_intent),
        },
    }
def _multi_unit_chain_focus_tokens(query: str) -> set[str]:
    """Return the normalised content tokens of *query* used to match units.

    Tokens are normalised with ``_multi_unit_chain_normalize_token``; empty
    results, entries of the focus drop list and forms of two characters or
    fewer are discarded.
    """
    focus: set[str] = set()
    for token in set(_path_utility_tokens(query)):
        stem = _multi_unit_chain_normalize_token(token)
        if not stem:
            continue
        if stem in _MULTI_UNIT_CHAIN_FOCUS_DROP_TOKENS:
            continue
        if len(stem) > 2:
            focus.add(stem)
    return focus
def _multi_unit_chain_normalize_token(token: str) -> str:
    """Reduce *token* to a crude stem so inflected forms can match.

    Resolution order: explicit irregular/alias table, ``-ies`` -> ``-y``,
    ``-ing`` removal, ``-ed`` removal, then plural ``-s`` removal (but not
    ``-ss``).  After removing ``-ing`` / ``-ed`` a doubled final consonant is
    collapsed when the stem is longer than three characters.  Returns ``""``
    when the token normalises to nothing.
    """
    word = _normalize(token)
    if not word:
        return ""

    irregular_forms = {
        "bought": "buy",
        "buying": "buy",
        "purchased": "buy",
        "purchasing": "buy",
        "worked": "work",
        "working": "work",
        "started": "start",
        "starting": "start",
        "finished": "finish",
        "finishing": "finish",
        "completed": "complete",
        "completing": "complete",
        "returned": "return",
        "returning": "return",
        "met": "meet",
        "meeting": "meet",
        "picked": "pick",
        "picking": "pick",
        "led": "lead",
        "leading": "lead",
        "kits": "kit",
        "items": "item",
        "projects": "project",
        "events": "event",
        "clothes": "clothing",
    }
    canonical = irregular_forms.get(word)
    if canonical is not None:
        return canonical

    size = len(word)
    if size > 4 and word.endswith("ies"):
        return f"{word[:-3]}y"

    # (suffix, length the word must exceed) -- tried in this order.
    for suffix, floor in (("ing", 4), ("ed", 3)):
        if size > floor and word.endswith(suffix):
            base = word[: -len(suffix)]
            if len(base) > 3 and base[-1] == base[-2]:
                # "stopped" -> "stopp" -> "stop"
                base = base[:-1]
            return base

    if size > 3 and word.endswith("s") and not word.endswith("ss"):
        return word[:-1]
    return word
def _multi_unit_chain_normalized_tokens(text: str) -> set[str]:
    """Tokenise *text* and return the set of non-empty normalised tokens."""
    stems: set[str] = set()
    for token in _path_utility_tokens(text):
        stem = _multi_unit_chain_normalize_token(token)
        if stem:
            stems.add(stem)
    return stems
def _multi_unit_chain_hit_text(record: SessionMemoryRecordV2, parent: SessionMemoryRecordV2 | None) -> str:
    """Build the full matching text for a unit, including its parent's raw text.

    The result joins, in order: the record's value / category / relation, its
    facet metadata fields, then the parent's value, source span and raw text
    (empty strings when there is no parent).
    """
    unit_meta = dict(record.metadata or {})
    parent_meta = {} if parent is None else dict(parent.metadata or {})

    pieces = [record.value, record.category, record.relation]
    for key in (
        "facet_type",
        "unit_kind",
        "facet_role",
        "facet_value",
        "facet_source_span",
        "action",
        "target",
        "quantity",
        "status",
    ):
        pieces.append(unit_meta.get(key, ""))
    pieces.append(parent.value if parent else "")
    pieces.append(parent_meta.get("source_span", ""))
    pieces.append(parent_meta.get("raw_text", ""))
    return " ".join(_clean_text(piece) for piece in pieces)
def _multi_unit_chain_local_hit_text(record: SessionMemoryRecordV2, parent: SessionMemoryRecordV2 | None) -> str:
    """Build the "local" matching text for a unit.

    Joins the record's value / category / relation, its facet metadata
    fields, and the parent's value and source span.  The parent's
    ``raw_text`` is deliberately not part of this text.
    """
    unit_meta = dict(record.metadata or {})
    parent_meta = {} if parent is None else dict(parent.metadata or {})

    pieces = [record.value, record.category, record.relation]
    for key in (
        "facet_type",
        "unit_kind",
        "facet_role",
        "facet_value",
        "facet_source_span",
        "action",
        "target",
        "quantity",
        "status",
    ):
        pieces.append(unit_meta.get(key, ""))
    pieces.append(parent.value if parent else "")
    pieces.append(parent_meta.get("source_span", ""))
    return " ".join(_clean_text(piece) for piece in pieces)
def _multi_unit_chain_env_int(name: str, default: int) -> int:
    """Read integer environment knob *name*; use *default* when unset, blank or malformed."""
    try:
        return int(os.getenv(name, str(default)) or default)
    except (TypeError, ValueError):
        return default


def _multi_unit_chain_slot_hits(
    graph: SessionMemoryGraphV2,
    query: str,
    final_hits: Sequence[MemoryHit],
    *,
    top_k: int,
) -> Dict[str, Any]:
    """Insert a bundle of event units spanning several parent events into ``final_hits``.

    The bundle ("chain slot") is only attempted for multi-unit queries: either
    a count/aggregation query or a temporal comparison with at least two focus
    tokens.  Matching units are scored, at most one unit per parent slot is
    kept in the main pass, and the bundle is only formed when enough distinct
    parent slots are covered.

    Args:
        graph: Memory graph; only its ``records_by_id`` mapping is read here.
        query: Raw query text.
        final_hits: Already-ranked hits.  The sequence itself is not modified.
        top_k: Not read by this function.

    Returns:
        A dict with ``"hits"`` (merged list, or a copy of ``final_hits`` when
        no bundle is formed) and ``"metadata"`` (``multi_unit_chain_*``
        diagnostics, including a ``multi_unit_chain_slot_reason``).

    Environment:
        TMCRA_MULTI_UNIT_CHAIN_SLOT_MODE: feature switch (default ``"on"``).
        TMCRA_MULTI_UNIT_CHAIN_SLOT_MAX_UNITS: main-pass cap, clamped to 2..8
            (default 6).
        TMCRA_MULTI_UNIT_CHAIN_PROFILE_SHADOW_MAX_UNITS: extra profile-shadow
            units for count/aggregation queries (default 2, floor 0).
        TMCRA_MULTI_UNIT_CHAIN_SLOT_MIN_PARENTS: distinct parent slots
            required, clamped to 2..4 (default 2).
        TMCRA_MULTI_UNIT_CHAIN_SLOT_INSERT_AFTER: insertion index into
            ``final_hits`` (default 1, clamped to the list length).
        Malformed integer values fall back to the default instead of raising.
    """
    mode = _normalize(os.getenv("TMCRA_MULTI_UNIT_CHAIN_SLOT_MODE", "on"))
    if mode in _MULTI_UNIT_CHAIN_DISABLED_MODES:
        return {
            "hits": list(final_hits),
            "metadata": {"multi_unit_chain_slot_enabled": False, "multi_unit_chain_slot_reason": "disabled"},
        }
    query_tokens = set(_path_utility_tokens(query))
    focus_tokens = _multi_unit_chain_focus_tokens(query)
    temporal_comparison_intent = bool(query_tokens & _MULTI_UNIT_CHAIN_TEMPORAL_COMPARISON_TOKENS) and len(focus_tokens) >= 2
    aggregation_or_joiner_intent = bool(
        re.search(r"\b(?:and|or|both|each|total|sum|minimum|maximum|amount|count|many|number)\b", str(query), flags=re.IGNORECASE)
    )
    numeric_aggregation_intent = bool(
        query_tokens
        & (
            _MULTI_UNIT_CHAIN_NUMERIC_VALUE_TOKENS
            | {"comment", "comments", "number", "percent", "percentage", "sum", "total"}
        )
    )
    # A count token alone is not enough: it needs a joiner/aggregation word
    # or at least three focus tokens.
    count_or_aggregation_intent = bool(query_tokens & _MULTI_UNIT_CHAIN_COUNT_TOKENS) and (aggregation_or_joiner_intent or len(focus_tokens) >= 3)
    multi_intent = count_or_aggregation_intent or temporal_comparison_intent
    if not query_tokens or not multi_intent:
        return {
            "hits": list(final_hits),
            "metadata": {
                "multi_unit_chain_slot_enabled": False,
                "multi_unit_chain_slot_reason": "no_multi_intent",
            },
        }
    if not focus_tokens:
        return {
            "hits": list(final_hits),
            "metadata": {
                "multi_unit_chain_slot_enabled": True,
                "multi_unit_chain_slot_formed": False,
                "multi_unit_chain_slot_reason": "no_focus_tokens",
            },
        }
    records = list(getattr(graph, "records_by_id", {}).values())
    # Lower-cased slot key -> record; the last record iterated wins on clashes.
    parent_by_slot = {_clean_text(record.slot_key).lower(): record for record in records}
    # Each candidate: (hit score, unit record, parent record or None, diagnostics).
    candidates: List[tuple[float, SessionMemoryRecordV2, SessionMemoryRecordV2 | None, Dict[str, Any]]] = []
    for record in records:
        metadata = dict(record.metadata or {})
        if record.state not in {"active", "parallel_active", "evidence"}:
            continue
        profile_shadow_unit = _profile_shadow_eventlike_record(record, metadata)
        if profile_shadow_unit:
            # Profile shadows act as parent-less action units keyed by their own slot.
            facet_type = "action"
            unit_kind = "profile_shadow_unit"
            parent_slot = _clean_text(record.slot_key).lower()
            parent = None
            text = _profile_shadow_unit_text(record, metadata)
            local_text = text
        else:
            if _normalize(metadata.get("content_variant", "")) != "event_facet_write":
                continue
            if _normalize(metadata.get("facet_layer_version", "")) != "event_unit_v1":
                continue
            facet_type = _normalize(metadata.get("facet_type", ""))
            unit_kind = _normalize(metadata.get("unit_kind", ""))
            if facet_type not in _MULTI_UNIT_CHAIN_FACET_TYPES and unit_kind not in _MULTI_UNIT_CHAIN_UNIT_KINDS:
                continue
            parent_slot = _clean_text(metadata.get("facet_parent_slot_key", "")).lower()
            parent = parent_by_slot.get(parent_slot)
            text = _multi_unit_chain_hit_text(record, parent)
            local_text = _multi_unit_chain_local_hit_text(record, parent)
        semantic_event_unit = _normalize(record.memory_id).startswith("tmcra.event.") or _normalize(parent_slot).startswith("tmcra.event.")
        unit_tokens = _multi_unit_chain_normalized_tokens(text)
        local_tokens = _multi_unit_chain_normalized_tokens(local_text)
        focus_overlap = focus_tokens & unit_tokens
        # A unit with no focus-token overlap is never a candidate.
        if not focus_overlap:
            continue
        score = min(1.2, len(focus_overlap) / max(1.0, len(focus_tokens)) * 1.6)
        local_numeric_signal = _multi_unit_chain_numeric_signal(unit_kind, facet_type, local_text, local_tokens)
        if _multi_unit_chain_date_like_numeric(local_text):
            # A date-shaped local text never counts as a quantity.
            numeric_signal = 0.0
        else:
            numeric_signal = local_numeric_signal
            if numeric_signal <= 0.0:
                # Fall back to the wider text (which includes the parent's raw text).
                numeric_signal = _multi_unit_chain_numeric_signal(unit_kind, facet_type, text, unit_tokens)
        temporal_anchor_signal = _multi_unit_chain_temporal_anchor_signal(text, unit_tokens)
        local_temporal_anchor_signal = _multi_unit_chain_temporal_anchor_signal(local_text, local_tokens)
        if facet_type in {"action", "entity", "role", "numeric"}:
            score += 0.34
        if unit_kind in {"action_unit", "target_entity", "leadership_role", "participation_role", "numeric_quantity"}:
            score += 0.24
        if profile_shadow_unit:
            score += 0.30
        semantic_event_priority = bool(semantic_event_unit and (count_or_aggregation_intent or numeric_aggregation_intent))
        if semantic_event_priority:
            score += 0.48
        if numeric_signal and numeric_aggregation_intent:
            score += numeric_signal
        if count_or_aggregation_intent and not numeric_aggregation_intent and facet_type in {"action", "entity", "role"}:
            score += 0.36
        if count_or_aggregation_intent and not numeric_aggregation_intent and unit_kind in {"action_unit", "target_entity", "leadership_role", "participation_role"}:
            score += 0.22
        if parent is not None:
            score += 0.12
        if score < 0.72:
            continue
        candidates.append(
            (
                # Hit score: base offset 2.15 plus the unit score, capped at 4.0.
                round(min(4.0, 2.15 + score), 6),
                record,
                parent,
                {
                    "multi_unit_chain_focus_overlap_tokens": sorted(focus_overlap)[:16],
                    "multi_unit_chain_score": round(score, 6),
                    "multi_unit_chain_numeric_signal": round(numeric_signal, 6),
                    "multi_unit_chain_temporal_anchor_signal": round(temporal_anchor_signal, 6),
                    "multi_unit_chain_local_temporal_anchor_signal": round(local_temporal_anchor_signal, 6),
                    "multi_unit_chain_temporal_comparison": bool(temporal_comparison_intent),
                    "multi_unit_chain_numeric_aggregation": bool(numeric_aggregation_intent),
                    "unit_kind": unit_kind,
                    "facet_type": facet_type,
                    "multi_unit_chain_parent_slot_key": parent_slot,
                    "multi_unit_chain_profile_shadow_unit": bool(profile_shadow_unit),
                    "multi_unit_chain_semantic_event_unit": bool(semantic_event_unit),
                    "multi_unit_chain_semantic_event_priority": bool(semantic_event_priority),
                    "multi_unit_chain_count_or_aggregation": bool(count_or_aggregation_intent),
                },
            )
        )
    if not candidates:
        return {
            "hits": list(final_hits),
            "metadata": {
                "multi_unit_chain_slot_enabled": True,
                "multi_unit_chain_slot_formed": False,
                "multi_unit_chain_slot_reason": "no_matching_units",
                "multi_unit_chain_candidate_count": 0,
                "multi_unit_chain_focus_tokens": sorted(focus_tokens)[:24],
            },
        }
    if temporal_comparison_intent:
        # Temporal comparisons rank by temporal anchoring first and push
        # numeric units behind non-numeric ones.
        candidates.sort(
            key=lambda item: (
                float(item[3].get("multi_unit_chain_local_temporal_anchor_signal", 0.0) or 0.0),
                float(item[3].get("multi_unit_chain_temporal_anchor_signal", 0.0) or 0.0),
                1 if bool(item[3].get("multi_unit_chain_semantic_event_priority", False)) else 0,
                0
                if _normalize(item[3].get("unit_kind", "")) == "numeric_quantity"
                or _normalize(item[3].get("facet_type", "")) == "numeric"
                else 1,
                float(item[0]),
                int(item[1].turn_index),
            ),
            reverse=True,
        )
    else:
        # Otherwise the leading key is the numeric signal (numeric aggregation
        # queries) or a tier: semantic-event priority 1.2, non-numeric 1.0,
        # numeric 0.0.
        candidates.sort(
            key=lambda item: (
                float(item[3].get("multi_unit_chain_numeric_signal", 0.0) or 0.0)
                if bool(item[3].get("multi_unit_chain_numeric_aggregation", False))
                else (
                    1.2
                    if bool(item[3].get("multi_unit_chain_semantic_event_priority", False))
                    else 0.0
                    if _normalize(item[3].get("unit_kind", "")) == "numeric_quantity"
                    or _normalize(item[3].get("facet_type", "")) == "numeric"
                    else 1.0
                ),
                float(item[0]),
                int(item[1].turn_index),
            ),
            reverse=True,
        )
    selected: List[tuple[float, SessionMemoryRecordV2, SessionMemoryRecordV2 | None, Dict[str, Any]]] = []
    seen_parent_slots: set[str] = set()
    seen_values: set[str] = set()
    max_units = max(2, min(8, _multi_unit_chain_env_int("TMCRA_MULTI_UNIT_CHAIN_SLOT_MAX_UNITS", 6)))
    for item in candidates:
        _, record, parent, metadata = item
        parent_slot = _clean_text(metadata.get("multi_unit_chain_parent_slot_key", ""))
        value_key = "|".join(
            [
                _normalize(record.value),
                _normalize(metadata.get("unit_kind", "")),
                parent_slot,
            ]
        )
        if value_key in seen_values:
            continue
        if parent_slot and parent_slot in seen_parent_slots:
            # Keep the chain broad: one strongest unit per parent event first.
            continue
        seen_values.add(value_key)
        if parent_slot:
            seen_parent_slots.add(parent_slot)
        selected.append(item)
        if len(selected) >= max_units:
            break
    if count_or_aggregation_intent:
        # Count/aggregation queries may append a few more profile-shadow units
        # on top of the main cap; this pass de-duplicates by memory id only.
        max_profile_shadow_units = max(0, _multi_unit_chain_env_int("TMCRA_MULTI_UNIT_CHAIN_PROFILE_SHADOW_MAX_UNITS", 2))
        selected_ids = {item[1].memory_id for item in selected}
        added_profile_shadow_units = 0
        for item in candidates:
            _, record, _parent, metadata = item
            if added_profile_shadow_units >= max_profile_shadow_units:
                break
            if record.memory_id in selected_ids:
                continue
            if not bool(metadata.get("multi_unit_chain_profile_shadow_unit", False)):
                continue
            selected.append(item)
            selected_ids.add(record.memory_id)
            parent_slot = _clean_text(metadata.get("multi_unit_chain_parent_slot_key", ""))
            if parent_slot:
                seen_parent_slots.add(parent_slot)
            added_profile_shadow_units += 1
    min_parents = max(2, min(4, _multi_unit_chain_env_int("TMCRA_MULTI_UNIT_CHAIN_SLOT_MIN_PARENTS", 2)))
    # A chain over fewer distinct parent slots than required is not formed.
    if len(seen_parent_slots) < min_parents:
        return {
            "hits": list(final_hits),
            "metadata": {
                "multi_unit_chain_slot_enabled": True,
                "multi_unit_chain_slot_formed": False,
                "multi_unit_chain_slot_reason": "insufficient_parent_coverage",
                "multi_unit_chain_candidate_count": len(candidates),
                "multi_unit_chain_selected_unit_count": len(selected),
                "multi_unit_chain_parent_count": len(seen_parent_slots),
                "multi_unit_chain_focus_tokens": sorted(focus_tokens)[:24],
            },
        }
    chain_memory_ids: List[str] = []
    packed_hits: List[MemoryHit] = []
    for score, unit_record, parent, extra in selected:
        chain_memory_ids.append(unit_record.memory_id)
        packed_hits.append(
            _memory_hit_from_record(
                unit_record,
                score=score,
                metadata={
                    **extra,
                    "multi_unit_chain_slot": True,
                    "multi_unit_chain_bundle": True,
                    "evidence_snippet_role": "multi_unit_chain_evidence_unit",
                },
            )
        )
        if parent is not None:
            # The parent event follows its unit, scored slightly lower.
            unit_metadata = dict(unit_record.metadata or {})
            chain_memory_ids.append(parent.memory_id)
            parent_hit = _memory_hit_from_record(
                parent,
                score=max(1.9, score - 0.08),
                metadata={
                    **extra,
                    "multi_unit_chain_slot": True,
                    "multi_unit_chain_bundle": True,
                    "evidence_snippet_role": "multi_unit_chain_parent_event",
                    "multi_unit_chain_child_id": unit_record.memory_id,
                    "multi_unit_chain_child_value": unit_record.value,
                    "multi_unit_chain_child_source_span": _clean_text(
                        unit_metadata.get("facet_source_span", "") or unit_metadata.get("source_span", "")
                    ),
                },
            )
            # Synthetic id so the parent survives the memory-id de-duplication below.
            parent_hit.memory_id = f"{parent.memory_id}#multi_parent:{unit_record.memory_id}"
            packed_hits.append(parent_hit)
    seen_ids: set[str] = set()
    merged: List[MemoryHit] = []
    insertion_index = max(0, min(len(final_hits), _multi_unit_chain_env_int("TMCRA_MULTI_UNIT_CHAIN_SLOT_INSERT_AFTER", 1)))
    ordered = [*list(final_hits[:insertion_index]), *packed_hits, *list(final_hits[insertion_index:])]
    # First occurrence of a memory id wins; hits without an id are always kept.
    for hit in ordered:
        if hit.memory_id and hit.memory_id in seen_ids:
            continue
        if hit.memory_id:
            seen_ids.add(hit.memory_id)
        merged.append(hit)
    return {
        "hits": merged,
        "metadata": {
            "multi_unit_chain_slot_enabled": True,
            "multi_unit_chain_slot_formed": True,
            "multi_unit_chain_slot_reason": "formed",
            "multi_unit_chain_candidate_count": len(candidates),
            "multi_unit_chain_selected_unit_count": len(selected),
            "multi_unit_chain_parent_count": len(seen_parent_slots),
            "multi_unit_chain_inserted_hit_count": len(packed_hits),
            "multi_unit_chain_memory_ids": _dedupe(chain_memory_ids, max_items=32),
            "multi_unit_chain_focus_tokens": sorted(focus_tokens)[:24],
        },
    }
def _profile_support_source_hits(
    graph: SessionMemoryGraphV2,
    profile_hit: MemoryHit,
    *,
    grouped_hits: Mapping[str, Sequence[MemoryHit]],
    query: str,
    limit: int,
) -> List[MemoryHit]:
    """Turn the support records listed on a profile hit into ranked source hits.

    Support ids are read from the profile hit's ``profile_support_ids`` and
    ``support_memory_ids`` metadata.  A support record is kept only when it is
    in a live state and its runtime event key is non-empty, differs from the
    profile hit's own event and is present in ``grouped_hits``.

    Returns:
        At most ``limit`` new ``MemoryHit`` objects carrying profile context
        in their metadata, ordered by raw-token overlap, score and turn index
        (all descending).  Empty when the profile hit lists no support ids.
    """
    profile_metadata = dict(profile_hit.metadata or {})
    support_ids = _dedupe(
        [
            *list(profile_metadata.get("profile_support_ids", []) or []),
            *list(profile_metadata.get("support_memory_ids", []) or []),
        ],
        max_items=24,
    )
    if not support_ids:
        return []

    def first_text(*values: Any) -> str:
        # Cleaned form of the first value that is non-empty after cleaning;
        # otherwise the cleaned form of the last value.
        cleaned = ""
        for value in values:
            cleaned = _clean_text(value)
            if cleaned:
                break
        return cleaned

    raw_tokens = set(_path_utility_tokens(query))
    expanded_tokens = _profile_query_expanded_tokens(query)
    own_event_id = first_text(profile_metadata.get("profile_first_hybrid_event_id", "")) or _runtime_event_key(profile_hit)
    live_states = {"active", "parallel_active", "evidence"}

    ranked: List[tuple[float, MemoryHit]] = []
    for support_id in support_ids:
        record = getattr(graph, "records_by_id", {}).get(support_id)
        if record is None or record.state not in live_states:
            continue
        base_hit = _memory_hit_from_record(record)
        event_id = _runtime_event_key(base_hit)
        if not event_id or event_id == own_event_id or event_id not in grouped_hits:
            continue

        match_score, overlap_tokens, raw_overlap_tokens = _profile_hit_match_score(raw_tokens, expanded_tokens, base_hit)
        # Just under the profile hit's score (floor 0.80) plus a small match bonus.
        source_score = max(0.80, float(profile_hit.score) - 0.02) + min(0.28, match_score * 0.12)

        enriched = dict(base_hit.metadata or {})
        # The overlay is computed from the pre-update metadata, then applied in one go.
        overlay = {
            "profile_first_hybrid_rescue": True,
            "profile_first_source_support": True,
            "profile_first_parent_memory_id": profile_hit.memory_id,
            "profile_first_hybrid_event_id": event_id,
            "profile_first_parent_summary": first_text(
                profile_metadata.get("profile_summary", ""),
                profile_metadata.get("profile_value", ""),
                profile_hit.value,
            ),
            "profile_type": first_text(enriched.get("profile_type", ""), profile_metadata.get("profile_type", "")),
            "profile_domain": first_text(enriched.get("profile_domain", ""), profile_metadata.get("profile_domain", "")),
            "profile_domain_label": first_text(
                enriched.get("profile_domain_label", ""),
                profile_metadata.get("profile_domain_label", ""),
            ),
            "profile_value": first_text(
                enriched.get("profile_value", ""),
                profile_metadata.get("profile_value", ""),
                profile_hit.value,
            ),
            "profile_summary": first_text(
                enriched.get("profile_summary", ""),
                profile_metadata.get("profile_summary", ""),
                profile_hit.value,
            ),
            "profile_query_match_score": round(match_score, 6),
            "profile_query_overlap_tokens": list(overlap_tokens),
            "profile_query_raw_overlap_tokens": list(raw_overlap_tokens),
            "profile_source_pack_role": "source_event_support",
        }
        enriched.update(overlay)

        rebuilt = MemoryHit(
            memory_id=base_hit.memory_id,
            category=base_hit.category,
            value=base_hit.value,
            relation=base_hit.relation,
            anchors=list(base_hit.anchors),
            score=round(source_score, 6),
            source_kind=base_hit.source_kind,
            slot_key=base_hit.slot_key,
            state=base_hit.state,
            turn_index=int(base_hit.turn_index),
            metadata=enriched,
        )
        ranked.append((source_score, rebuilt))

    def rank_key(entry: tuple[float, MemoryHit]) -> tuple[bool, float, int]:
        unrounded_score, hit = entry
        has_raw_overlap = bool((hit.metadata or {}).get("profile_query_raw_overlap_tokens"))
        return (has_raw_overlap, float(unrounded_score), int(hit.turn_index))

    ranked.sort(key=rank_key, reverse=True)
    keep = max(0, int(limit))
    return [hit for _, hit in ranked[:keep]]
def _profile_same_event_source_hit(profile_hit: MemoryHit, group_hits: Sequence[MemoryHit], *, event_id: str) -> MemoryHit | None:
    """Pick a source hit from the profile hit's own event group and tag it as support.

    The ``speaker_event_source_turn`` support hit is preferred; otherwise the
    group's representative hit is used.  Returns ``None`` when no such hit
    exists or when it is the profile hit itself.
    """
    source_hit = _support_hit_for_path("speaker_event_source_turn", group_hits) or _representative_event_hit(group_hits)
    if source_hit is None:
        return None
    if source_hit.memory_id == profile_hit.memory_id:
        return None

    def first_text(*values: Any) -> str:
        # Cleaned form of the first value that is non-empty after cleaning;
        # otherwise the cleaned form of the last value.
        cleaned = ""
        for value in values:
            cleaned = _clean_text(value)
            if cleaned:
                break
        return cleaned

    source_metadata = dict(source_hit.metadata or {})
    profile_metadata = dict(profile_hit.metadata or {})
    # The overlay is computed from the pre-update metadata, then applied in one go.
    overlay = {
        "profile_first_hybrid_rescue": True,
        "profile_first_source_support": True,
        "profile_first_same_event_support": True,
        "profile_first_parent_memory_id": profile_hit.memory_id,
        "profile_first_hybrid_event_id": event_id,
        "profile_first_parent_summary": first_text(
            profile_metadata.get("profile_summary", ""),
            profile_metadata.get("profile_value", ""),
            profile_hit.value,
        ),
        "profile_type": first_text(source_metadata.get("profile_type", ""), profile_metadata.get("profile_type", "")),
        "profile_domain": first_text(source_metadata.get("profile_domain", ""), profile_metadata.get("profile_domain", "")),
        "profile_domain_label": first_text(
            source_metadata.get("profile_domain_label", ""),
            profile_metadata.get("profile_domain_label", ""),
        ),
        "profile_value": first_text(
            source_metadata.get("profile_value", ""),
            profile_metadata.get("profile_value", ""),
            profile_hit.value,
        ),
        "profile_summary": first_text(
            source_metadata.get("profile_summary", ""),
            profile_metadata.get("profile_summary", ""),
            profile_hit.value,
        ),
        "profile_source_pack_role": "same_event_source_turn",
        "evidence_snippet_role": "profile_source_support",
    }
    source_metadata.update(overlay)

    # Never score the source below (profile score - 0.01).
    boosted_score = max(float(source_hit.score), float(profile_hit.score) - 0.01)
    return MemoryHit(
        memory_id=source_hit.memory_id,
        category=source_hit.category,
        value=source_hit.value,
        relation=source_hit.relation,
        anchors=list(source_hit.anchors),
        score=boosted_score,
        source_kind=source_hit.source_kind,
        slot_key=source_hit.slot_key,
        state=source_hit.state,
        turn_index=int(source_hit.turn_index),
        metadata=source_metadata,
    )
def _profile_first_hybrid_rescue(
    graph: SessionMemoryGraphV2,
    query: str,
    *,
    grouped_hits: Mapping[str, Sequence[MemoryHit]],
    top_k: int,
) -> Dict[str, Any]:
    """Select profile hits, plus their source support, for a profile-intent query.

    Profile rescue hits are walked in rank order.  A hit is taken when its
    runtime event key is present in ``grouped_hits`` and has not been used
    yet.  It is followed by a same-event source hit when one exists, and then
    by support-source hits from other events, until the limit is reached.

    Args:
        graph: Memory graph handed to the rescue and support helpers.
        query: Raw query text.
        grouped_hits: Hits grouped by runtime event key.
        top_k: Caller's requested hit count; falsy values count as 1.  The
            number of selected hits is capped at ``max(2, min(4, top_k))``.

    Returns:
        A dict with ``"hits"``, ``"event_ids"`` and ``"memory_ids"``.  All
        three are empty lists when the query has no profile intent.

    Note:
        Selected rescue hits have their ``metadata`` and ``score`` updated in
        place.
    """
    intent = infer_profile_query_intent(query)
    if not bool(intent.get("enabled")):
        # Same key set as the populated payload below, so callers can index
        # "memory_ids" regardless of which path produced the result.
        return {"hits": [], "event_ids": [], "memory_ids": []}
    raw_hits = _profile_query_rescue_hits(graph, query, top_k=max(4, int(top_k or 1) * 3))
    limit = max(2, min(4, int(top_k or 1)))
    selected_hits: List[MemoryHit] = []
    selected_event_ids: List[str] = []
    seen_event_ids: set[str] = set()
    for rank, hit in enumerate(raw_hits, start=1):
        event_id = _runtime_event_key(hit)
        if not event_id or event_id not in grouped_hits or event_id in seen_event_ids:
            continue
        metadata = dict(hit.metadata or {})
        metadata.update(
            {
                "profile_first_hybrid_rescue": True,
                "profile_first_hybrid_rank": rank,
                "profile_first_hybrid_event_id": event_id,
            }
        )
        hit.metadata = metadata
        # Floor the score at 0.88 and add a bonus that shrinks by 0.01 per
        # already-selected hit (0.08 for the first).
        hit.score = round(max(float(hit.score), 0.88) + max(0.0, 0.08 - (0.01 * len(selected_hits))), 6)
        selected_hits.append(hit)
        selected_event_ids.append(event_id)
        seen_event_ids.add(event_id)
        same_event_source = _profile_same_event_source_hit(hit, grouped_hits.get(event_id, []), event_id=event_id)
        if same_event_source is not None:
            selected_hits.append(same_event_source)
        if len(selected_hits) >= limit:
            break
        support_hits = _profile_support_source_hits(
            graph,
            hit,
            grouped_hits=grouped_hits,
            query=query,
            limit=max(1, min(3, limit - len(selected_hits) + 1)),
        )
        for support_hit in support_hits:
            support_event_id = _clean_text((support_hit.metadata or {}).get("profile_first_hybrid_event_id", "")) or _runtime_event_key(support_hit)
            if not support_event_id or support_event_id in seen_event_ids:
                continue
            selected_hits.append(support_hit)
            selected_event_ids.append(support_event_id)
            seen_event_ids.add(support_event_id)
            if len(selected_hits) >= limit:
                break
        if len(selected_hits) >= limit:
            break
    return {
        "hits": selected_hits,
        "event_ids": selected_event_ids,
        "memory_ids": [hit.memory_id for hit in selected_hits],
    }
def _inject_profile_first_hits(
    final_hits: Sequence[MemoryHit],
    profile_first_hits: Sequence[MemoryHit],
    *,
    selected_event_ids: Sequence[str],
    selected_path_ids: Sequence[str],
) -> List[MemoryHit]:
    """Place profile-first hits ahead of the final hits, de-duplicated by memory id.

    Hits flagged ``profile_first_hybrid_rescue`` in their metadata are rebuilt as
    new ``MemoryHit`` objects annotated as selected-event representatives. All
    other hits are passed through unchanged. Hits that are falsy or have no
    memory id are dropped.
    """

    def _as_event_representative(source: MemoryHit, source_metadata: Dict[str, Any]) -> MemoryHit:
        # Prefer the event id stamped by the rescue step, else derive it from the hit.
        event_id = _clean_text(source_metadata.get("profile_first_hybrid_event_id", "")) or _runtime_event_key(source)
        annotated = {
            **source_metadata,
            "event_id": event_id,
            "path_id": "",
            "evidence_snippet_role": "selected_event_representative",
            "hybrid_score_source": "profile_first_hybrid_rescue",
            "selected_event_ids": list(selected_event_ids),
            "selected_path_ids": list(selected_path_ids),
        }
        return MemoryHit(
            memory_id=source.memory_id,
            category=source.category,
            value=source.value,
            relation=source.relation,
            anchors=list(source.anchors),
            score=float(source.score),
            source_kind=source.source_kind,
            slot_key=source.slot_key,
            state=source.state,
            turn_index=int(source.turn_index),
            metadata=annotated,
        )

    combined: List[MemoryHit] = []
    emitted_ids: set[str] = set()
    # Profile-first hits come first so they win the de-duplication.
    for candidate in [*profile_first_hits, *final_hits]:
        if not candidate or not candidate.memory_id or candidate.memory_id in emitted_ids:
            continue
        candidate_metadata = dict(candidate.metadata or {})
        if bool(candidate_metadata.get("profile_first_hybrid_rescue")):
            candidate = _as_event_representative(candidate, candidate_metadata)
        combined.append(candidate)
        emitted_ids.add(candidate.memory_id)
    return combined
def _profile_focused_pack_hits(
    graph: SessionMemoryGraphV2,
    query: str,
    final_hits: Sequence[MemoryHit],
    *,
    top_k: int,
) -> Dict[str, Any]:
    """Inject a profile-first pack into ``final_hits`` when the query has profile intent.

    Returns a dict with ``"hits"`` and ``"metadata"``. When no pack can be built
    (intent not enabled, no learnable graph hits, or no profile hits) the hits
    are a plain copy of ``final_hits`` and the metadata records the reason.
    """

    def _passthrough(enabled: bool, reason: str, **extra: Any) -> Dict[str, Any]:
        # Result for every path that leaves the final hits untouched.
        return {
            "hits": list(final_hits),
            "metadata": {
                "profile_focused_pack_enabled": enabled,
                "profile_focused_pack_reason": reason,
                **extra,
            },
        }

    if not bool(infer_profile_query_intent(query).get("enabled")):
        return _passthrough(False, "profile_intent_not_requested")

    source_hits = _learnable_graph_hits(graph)
    if not source_hits:
        return _passthrough(False, "no_learnable_hits")

    runtime_graph = _build_runtime_graph_from_hits(query, source_hits)
    grouped_hits = dict(runtime_graph.get("grouped_hits", {}) or {})
    rescue_payload = _profile_first_hybrid_rescue(
        graph,
        query,
        grouped_hits=grouped_hits,
        top_k=max(2, min(max(1, int(top_k or 1)), 4)),
    )
    pack_hits = list(rescue_payload.get("hits", []) or [])
    pack_event_ids = list(rescue_payload.get("event_ids", []) or [])
    pack_memory_ids = list(rescue_payload.get("memory_ids", []) or [])
    if not pack_hits:
        return _passthrough(
            True,
            "no_profile_hits",
            profile_focused_pack_event_ids=[],
            profile_focused_pack_memory_ids=[],
        )

    # Pack events first, then the events already present in the final hits.
    selected_event_ids = _dedupe(
        [*pack_event_ids, *_event_ids_from_hits(final_hits)],
        max_items=max(1, int(top_k or 1) * 2),
    )
    final_path_ids = [_clean_text((hit.metadata or {}).get("path_id", "")) for hit in final_hits]
    selected_path_ids = _dedupe(
        [path_id for path_id in final_path_ids if path_id],
        max_items=max(1, int(top_k or 1)),
    )
    injected = _inject_profile_first_hits(
        final_hits,
        pack_hits,
        selected_event_ids=selected_event_ids,
        selected_path_ids=selected_path_ids,
    )
    injected = _coverage_preserving_final_hits(injected, selected_event_ids=selected_event_ids, top_k=top_k)
    pack_metadata = {
        "profile_focused_pack_enabled": True,
        "profile_focused_pack_reason": "profile_first_pack_injected",
        "profile_focused_pack_event_ids": list(pack_event_ids),
        "profile_focused_pack_memory_ids": list(pack_memory_ids),
        "profile_focused_pack_hit_count": len(pack_hits),
        "profile_first_hybrid_enabled": True,
        "profile_first_event_ids": list(pack_event_ids),
        "profile_first_memory_ids": list(pack_memory_ids),
    }
    return {"hits": injected, "metadata": pack_metadata}
def _topic_bucket_rerank_hits(
    graph: SessionMemoryGraphV2,
    query: str,
    hits: Sequence[MemoryHit],
    *,
    top_k: int,
) -> Dict[str, Any]:
    """Rerank candidate hits around the query's topic bucket and select the final top-k.

    Steps, in order:
      1. With no candidates, fall back to profile-query rescue hits (or an empty result).
      2. Shift every candidate's score by a delta that depends on how its topic bucket
         relates to the query's bucket, and record the decision in the hit metadata.
      3. Append rescue hits built from active graph records: same-bucket records,
         dialogue-tunnel records (when requested) and profile rescue hits (when the
         query has profile intent).
      4. Keep only "focused" hits, with successive fallbacks when none qualify.
      5. Select up to ``top_k`` hits with protected groups first, optionally reserve
         a slot per dialogue-tunnel bucket, then sort the selection.

    The incoming hits are modified in place (``metadata`` and ``score`` are reassigned).

    Returns:
        A dict with ``"hits"`` (the selected hits) and ``"metadata"`` (flags and
        counters describing what the rerank did).
    """
    if not hits:
        # Nothing to rerank: the only possible source is the profile rescue path.
        profile_rescue_hits = _profile_query_rescue_hits(graph, query, top_k=top_k)
        if profile_rescue_hits:
            return {
                "hits": profile_rescue_hits[: max(1, int(top_k or 1))],
                "metadata": {
                    "topic_bucket_rerank_enabled": True,
                    "topic_bucket_rerank_reason": "profile_query_rescue_from_empty_hits",
                    "topic_bucket_profile_query_rescue_count": len(profile_rescue_hits),
                },
            }
        return {
            "hits": [],
            "metadata": {"topic_bucket_rerank_enabled": True, "topic_bucket_rerank_reason": "no_hits"},
        }
    # Resolve the query's topic bucket without creating a new one (create=False).
    query_topic = _assign_topic_bucket_for_text(graph, query, turn_index=0, create=False)
    query_bucket_id = _clean_text(query_topic.get("topic_bucket_id", ""))
    adjacent_ids = _topic_adjacent_bucket_ids(graph, query_bucket_id)
    dialogue_tunnel_ids = _dialogue_tunnel_bucket_ids(graph, query_bucket_id)
    query_keywords = list(query_topic.get("topic_keywords", []) or [])
    bridge_requested = _topic_bridge_requested(query)
    dialogue_requested = _dialogue_tunnel_requested(query)
    profile_query_requested = bool(infer_profile_query_intent(query).get("enabled"))
    reranked: List[MemoryHit] = []
    # Per-branch counters, reported as "topic_bucket_stats" in the result metadata.
    stats = {
        "same_bucket": 0,
        "bridge_bucket": 0,
        "blocked_bridge_bucket": 0,
        "dialogue_tunnel_bucket": 0,
        "blocked_dialogue_tunnel_bucket": 0,
        "profile_route_preserved": 0,
        "profile_query_rescue": 0,
        "overlap_bucket": 0,
        "off_topic": 0,
    }
    for hit in list(hits):
        metadata = dict(hit.metadata or {})
        hit_bucket_id = _clean_text(metadata.get("topic_bucket_id", ""))
        hit_keywords = _dedupe(metadata.get("topic_keywords", []) or [], max_items=32)
        overlap = _topic_bucket_overlap_score(query_keywords, hit_keywords)
        same_bucket = bool(query_bucket_id and hit_bucket_id and query_bucket_id == hit_bucket_id)
        bridge_bucket = bool(hit_bucket_id and hit_bucket_id in adjacent_ids)
        dialogue_tunnel_bucket = bool(hit_bucket_id and hit_bucket_id in dialogue_tunnel_ids)
        overlap_bucket = overlap >= 0.22
        # Bridge / dialogue-tunnel buckets only count when the query asks for them.
        bridge_allowed = bridge_bucket and bridge_requested
        dialogue_allowed = dialogue_tunnel_bucket and dialogue_requested
        profile_route_preserved = bool(
            profile_query_requested
            and (
                bool(metadata.get("profile_layer"))
                or "profile_route" in _normalize(metadata.get("match_reason", ""))
            )
        )
        current_subject_preserved = bool(metadata.get("current_subject_resolver") or metadata.get("public_subject_match"))
        # Exactly one branch applies: the first matching condition decides the delta.
        delta = 0.0
        if same_bucket:
            delta += 1.35
            stats["same_bucket"] += 1
        elif current_subject_preserved:
            delta += 0.18
            # NOTE(review): this branch is counted under "profile_route_preserved",
            # the same counter as the profile-route branch below; confirm intended.
            stats["profile_route_preserved"] += 1
        elif bridge_allowed:
            delta += 0.42
            stats["bridge_bucket"] += 1
        elif dialogue_allowed:
            delta += 0.16
            stats["dialogue_tunnel_bucket"] += 1
        elif overlap_bucket:
            delta += 0.28 + overlap
            stats["overlap_bucket"] += 1
        elif bridge_bucket:
            # Adjacent bucket, but the query did not request a bridge: penalise.
            delta -= 1.05
            stats["blocked_bridge_bucket"] += 1
        elif dialogue_tunnel_bucket:
            # Dialogue-tunnel bucket, but the query did not request one: penalise.
            delta -= 1.35
            stats["blocked_dialogue_tunnel_bucket"] += 1
        elif profile_route_preserved:
            delta += 0.08
            stats["profile_route_preserved"] += 1
        else:
            delta -= 0.72
            stats["off_topic"] += 1
        memory_type = _normalize(metadata.get("memory_type", ""))
        durability = _normalize(metadata.get("durability", ""))
        conflict_policy = _normalize(metadata.get("conflict_policy", ""))
        # Extra boost for hard / must-preserve memories that are also on topic.
        if (memory_type == "hard_constraint" or durability == "hard" or conflict_policy == "must_preserve") and (
            same_bucket or bridge_allowed or dialogue_allowed or overlap_bucket
        ):
            delta += 1.10
        match_reason = _clean_text(metadata.get("match_reason", ""))
        if profile_route_preserved and "profile_route" not in _normalize(match_reason):
            match_reason = ",".join(_dedupe([match_reason, "profile_route"], max_items=4))
        metadata.update(
            {
                "topic_bucket_rerank": True,
                "topic_bucket_query_id": query_bucket_id,
                "topic_bucket_query_label": _clean_text(query_topic.get("topic_label", "")),
                "topic_bucket_overlap": round(overlap, 6),
                "topic_bucket_delta": round(delta, 6),
                "topic_bucket_same": same_bucket,
                "topic_bucket_bridge": bridge_bucket,
                "topic_bucket_bridge_allowed": bridge_allowed,
                "topic_bucket_dialogue_tunnel": dialogue_tunnel_bucket,
                "topic_bucket_dialogue_tunnel_allowed": dialogue_allowed,
                "topic_bucket_profile_route_preserved": profile_route_preserved,
                "topic_bucket_current_subject_preserved": current_subject_preserved,
                "match_reason": match_reason,
            }
        )
        hit.metadata = metadata
        hit.score = float(hit.score) + delta
        reranked.append(hit)
    seen_ids = {hit.memory_id for hit in reranked if hit.memory_id}
    # Same-bucket rescue: active, non-question graph records in the query bucket
    # that are not among the candidates yet.
    rescue_records = [
        record
        for record in getattr(graph, "records_by_id", {}).values()
        if record.memory_id not in seen_ids
        and record.state == "active"
        and _clean_text((record.metadata or {}).get("topic_bucket_id", "")) == query_bucket_id
        and _normalize(record.category) != "question"
    ]
    # Highest first: constraint-marker text (Chinese/English words such as
    # allergy / must / avoid), then constraint/preference category, confidence,
    # salience, and finally earlier turns (turn index is negated).
    rescue_records.sort(
        key=lambda record: (
            int(any(marker in _normalize(record.value) for marker in ("过敏", "必须", "避开", "禁止", "不能", "must", "avoid", "allergy"))),
            int(_normalize(record.category) in {"constraint", "preference"}),
            float(record.confidence),
            float(record.salience),
            -int(record.turn_index),
        ),
        reverse=True,
    )
    rescue_hits = [
        _topic_bucket_record_to_hit(record, query_topic=query_topic, rank=index)
        for index, record in enumerate(rescue_records[: max(2, min(12, int(top_k or 1) * 2))], start=1)
    ]
    reranked.extend(rescue_hits)
    dialogue_rescue_hits: List[MemoryHit] = []
    if dialogue_requested and dialogue_tunnel_ids:
        seen_ids.update(hit.memory_id for hit in rescue_hits if hit.memory_id)
        # Dialogue-tunnel rescue: same filters as above, applied to tunnel buckets.
        dialogue_records = [
            record
            for record in getattr(graph, "records_by_id", {}).values()
            if record.memory_id not in seen_ids
            and record.state == "active"
            and _clean_text((record.metadata or {}).get("topic_bucket_id", "")) in dialogue_tunnel_ids
            and _normalize(record.category) != "question"
        ]

        def dialogue_record_rank(record: SessionMemoryRecordV2) -> tuple[int, int, float, float, int]:
            # Same ranking tuple as the same-bucket rescue sort key above.
            return (
                int(any(marker in _normalize(record.value) for marker in ("过敏", "必须", "避开", "禁止", "不能", "must", "avoid", "allergy"))),
                int(_normalize(record.category) in {"constraint", "preference"}),
                float(record.confidence),
                float(record.salience),
                -int(record.turn_index),
            )

        dialogue_records.sort(key=dialogue_record_rank, reverse=True)
        selected_dialogue_records: List[SessionMemoryRecordV2] = []
        selected_dialogue_ids: set[str] = set()
        # First take the best-ranked record of every tunnel bucket ...
        for bucket_id in sorted(dialogue_tunnel_ids):
            bucket_records = [
                record
                for record in dialogue_records
                if _clean_text((record.metadata or {}).get("topic_bucket_id", "")) == bucket_id
            ]
            if not bucket_records:
                continue
            best = bucket_records[0]
            selected_dialogue_records.append(best)
            selected_dialogue_ids.add(best.memory_id)
        # ... then top up, in rank order, to a target of at least one and at most
        # four records (never fewer than the per-bucket picks already made).
        target_dialogue_rescue = max(len(selected_dialogue_records), max(1, min(4, int(top_k or 1) // 2 or 1)))
        for record in dialogue_records:
            if len(selected_dialogue_records) >= target_dialogue_rescue:
                break
            if record.memory_id in selected_dialogue_ids:
                continue
            selected_dialogue_records.append(record)
            selected_dialogue_ids.add(record.memory_id)
        dialogue_rescue_hits = [
            _topic_bucket_record_to_hit(record, query_topic=query_topic, rank=index, rescue_kind="dialogue_tunnel")
            for index, record in enumerate(selected_dialogue_records, start=1)
        ]
        reranked.extend(dialogue_rescue_hits)
    profile_rescue_hits: List[MemoryHit] = []
    if profile_query_requested:
        seen_ids.update(hit.memory_id for hit in reranked if hit.memory_id)
        # Profile rescue: add profile hits not already present, with a small boost.
        for index, hit in enumerate(_profile_query_rescue_hits(graph, query, top_k=top_k), start=1):
            if hit.memory_id in seen_ids:
                continue
            metadata = dict(hit.metadata or {})
            metadata.update(
                {
                    "profile_query_rescue_rank": index,
                    "topic_bucket_query_id": query_bucket_id,
                    "topic_bucket_query_label": _clean_text(query_topic.get("topic_label", "")),
                    "topic_bucket_profile_route_preserved": True,
                }
            )
            hit.metadata = metadata
            hit.score = float(hit.score) + 0.08
            profile_rescue_hits.append(hit)
            seen_ids.add(hit.memory_id)
        stats["profile_query_rescue"] = len(profile_rescue_hits)
        reranked.extend(profile_rescue_hits)
    # Best score first; ties broken by the later turn.
    reranked.sort(key=lambda item: (float(item.score), int(item.turn_index)), reverse=True)
    limit = max(1, int(top_k or 1))
    # "Focused" hits carry at least one on-topic / preserved flag in their metadata.
    focused_hits = [
        hit
        for hit in reranked
        if bool((hit.metadata or {}).get("topic_bucket_same"))
        or float((hit.metadata or {}).get("topic_bucket_overlap", 0.0) or 0.0) >= 0.22
        or bool((hit.metadata or {}).get("topic_bucket_bridge_allowed"))
        or bool((hit.metadata or {}).get("topic_bucket_dialogue_tunnel_allowed"))
        or bool((hit.metadata or {}).get("topic_bucket_profile_route_preserved"))
        or bool((hit.metadata or {}).get("topic_bucket_current_subject_preserved"))
        or bool((hit.metadata or {}).get("current_subject_resolver"))
    ]
    model_path_fallback = False
    no_bucket_model_fallback = False
    if not focused_hits:
        # Fallback 1: hits that carry a path id or a hybrid score source.
        model_supported_hits = [
            hit
            for hit in reranked
            if _clean_text((hit.metadata or {}).get("path_id", ""))
            or _clean_text((hit.metadata or {}).get("hybrid_score_source", ""))
        ]
        if model_supported_hits:
            focused_hits = model_supported_hits
            model_path_fallback = True
    generic_memory_fallback = False
    # Fallback 2: keep everything when at least one hit is not a public-dialog hit.
    if not focused_hits and any(not _is_public_dialog_hit(hit) for hit in reranked):
        focused_hits = list(reranked)
        generic_memory_fallback = True
    # Fallback 3: keep everything when the query has no topic bucket at all.
    if not focused_hits and not query_bucket_id:
        focused_hits = list(reranked)
        no_bucket_model_fallback = True
    time_focused_model_path = any(
        "speaker_event_time" in _clean_text((hit.metadata or {}).get("path_id", ""))
        or _clean_text((hit.metadata or {}).get("model_focused_answer_type", "")) == "time"
        for hit in focused_hits
    )
    if time_focused_model_path:
        # For time-focused selections, drop public-dialog profile hits without a path id.
        focused_hits = [
            hit
            for hit in focused_hits
            if not (
                _clean_text(hit.source_kind) == "public_dialog_profile"
                and not _clean_text((hit.metadata or {}).get("path_id", ""))
            )
        ]
    filtered_count = max(0, len(reranked) - len(focused_hits))
    query_subject_signature = _public_subject_signature(_public_query_subject(query))

    def _hit_subject_signatures(hit: MemoryHit) -> set[str]:
        # Collect every non-empty subject signature derivable from the hit's
        # metadata and from ".subject." suffixes of its slot keys.
        metadata = dict(hit.metadata or {})
        signatures = {
            _normalize(metadata.get("subject_signature", "")).replace("-", "_"),
            _public_subject_signature(metadata.get("subject", "")),
        }
        canonical_slot_key = _clean_text(metadata.get("canonical_slot_key", ""))
        if ".subject." in canonical_slot_key:
            signatures.add(_public_subject_signature(canonical_slot_key.split(".subject.", 1)[-1]))
        if ".subject." in hit.slot_key:
            signatures.add(_public_subject_signature(hit.slot_key.split(".subject.", 1)[-1]))
        signatures.discard("")
        return signatures

    # Protected groups are placed ahead of the remaining focused hits at selection time.
    protected_current_hits = [
        hit
        for hit in focused_hits
        if bool((hit.metadata or {}).get("current_subject_resolver"))
    ]
    protected_model_path_hits = [
        hit
        for hit in focused_hits
        if _clean_text((hit.metadata or {}).get("path_id", ""))
        and _clean_text((hit.metadata or {}).get("hybrid_score_source", ""))
    ]
    protected_generic_hits = [
        hit
        for hit in focused_hits
        if not _is_public_dialog_hit(hit)
        and not bool((hit.metadata or {}).get("profile_layer"))
    ]
    if protected_current_hits and query_subject_signature:
        # A current-subject hit exists: drop inactive-state hits about that same
        # subject, keeping the protected hits themselves.
        protected_ids = {hit.memory_id for hit in protected_current_hits if hit.memory_id}
        compacted_focused_hits: List[MemoryHit] = []
        for hit in focused_hits:
            if hit.memory_id in protected_ids:
                compacted_focused_hits.append(hit)
                continue
            same_current_subject = query_subject_signature in _hit_subject_signatures(hit)
            inactive_state = _normalize(hit.state) in {"superseded", "evidence", "historical", "stale", "false"}
            if same_current_subject and inactive_state:
                continue
            compacted_focused_hits.append(hit)
        focused_hits = compacted_focused_hits
        filtered_count = max(0, len(reranked) - len(focused_hits))
    selected_hits: List[MemoryHit] = []
    selected_keys: set[str] = set()
    # Fill up to ``limit``, de-duplicating by memory id (or a composite key when
    # the hit has no memory id).
    for hit in [*protected_current_hits, *protected_model_path_hits, *protected_generic_hits, *focused_hits]:
        key = hit.memory_id or f"{_hit_event_id(hit)}::{hit.slot_key}::{hit.value[:80]}"
        if key in selected_keys:
            continue
        selected_hits.append(hit)
        selected_keys.add(key)
        if len(selected_hits) >= limit:
            break
    dialogue_reserved_count = 0
    if dialogue_requested and dialogue_tunnel_ids:
        # Reserve one slot for each tunnel bucket not yet represented in the selection.
        selected_ids = {hit.memory_id for hit in selected_hits if hit.memory_id}
        selected_bucket_ids = {
            _clean_text((hit.metadata or {}).get("topic_bucket_id", ""))
            for hit in selected_hits
            if _clean_text((hit.metadata or {}).get("topic_bucket_id", ""))
        }
        for bucket_id in sorted(dialogue_tunnel_ids):
            if bucket_id in selected_bucket_ids:
                continue
            candidate = next(
                (
                    hit
                    for hit in focused_hits
                    if hit.memory_id not in selected_ids
                    and _clean_text((hit.metadata or {}).get("topic_bucket_id", "")) == bucket_id
                    and bool((hit.metadata or {}).get("topic_bucket_dialogue_tunnel_allowed"))
                ),
                None,
            )
            if candidate is None:
                continue
            if len(selected_hits) < limit:
                selected_hits.append(candidate)
            else:
                # Selection is full: replace the last entry that is neither a
                # hard constraint nor itself a dialogue-tunnel hit. If none
                # qualifies, the very last entry is replaced.
                replace_index = len(selected_hits) - 1
                for index in range(len(selected_hits) - 1, -1, -1):
                    metadata = dict(selected_hits[index].metadata or {})
                    hardish = (
                        _normalize(metadata.get("memory_type", "")) == "hard_constraint"
                        or _normalize(metadata.get("durability", "")) == "hard"
                        or _normalize(metadata.get("conflict_policy", "")) == "must_preserve"
                    )
                    if not hardish and not bool(metadata.get("topic_bucket_dialogue_tunnel_allowed")):
                        replace_index = index
                        break
                selected_hits[replace_index] = candidate
            selected_ids.add(candidate.memory_id)
            selected_bucket_ids.add(bucket_id)
            dialogue_reserved_count += 1
    # Final order: current-subject hits first, then by score, then by later turn.
    selected_hits.sort(
        key=lambda item: (
            int(bool((item.metadata or {}).get("current_subject_resolver"))),
            float(item.score),
            int(item.turn_index),
        ),
        reverse=True,
    )
    return {
        "hits": selected_hits[:limit],
        "metadata": {
            "topic_bucket_rerank_enabled": True,
            "topic_bucket_no_fill_policy": True,
            "topic_bucket_query_id": query_bucket_id,
            "topic_bucket_query_label": _clean_text(query_topic.get("topic_label", "")),
            "topic_bucket_query_keywords": query_keywords,
            "topic_bucket_adjacent_ids": sorted(adjacent_ids),
            "dialogue_tunnel_adjacent_ids": sorted(dialogue_tunnel_ids),
            "topic_bucket_bridge_requested": bridge_requested,
            "dialogue_tunnel_requested": dialogue_requested,
            "topic_bucket_model_path_fallback": model_path_fallback,
            "topic_bucket_generic_memory_fallback": generic_memory_fallback,
            "topic_bucket_no_bucket_model_fallback": no_bucket_model_fallback,
            "topic_bucket_candidate_count": len(hits),
            "topic_bucket_rescue_count": len(rescue_hits),
            "dialogue_tunnel_rescue_count": len(dialogue_rescue_hits),
            "profile_query_rescue_count": len(profile_rescue_hits),
            "dialogue_tunnel_reserved_count": dialogue_reserved_count,
            "topic_bucket_focused_count": len(focused_hits),
            "topic_bucket_final_count": len(selected_hits[:limit]),
            "topic_bucket_filtered_count": filtered_count,
            "topic_bucket_stats": stats,
        },
    }
def _looks_like_write_turn(user_text: str, *, answer_payload: Dict[str, Any] | None = None) -> bool:
    """Decide whether a user turn should be treated as a memory write.

    Empty text and text containing ``?`` never count as a write. Otherwise the
    turn is a write when it contains any of ``_WRITE_MARKERS`` or when the
    answer payload's metadata has a truthy ``memory_write`` flag.
    """
    normalized = _normalize(user_text)
    if not normalized or "?" in normalized:
        return False
    for marker in _WRITE_MARKERS:
        if marker in normalized:
            return True
    payload_metadata = dict((answer_payload or {}).get("metadata", {}) or {})
    return bool(payload_metadata.get("memory_write"))
def _explicit_overwrite_requested(user_text: str) -> bool:
    """Return True when the normalized text contains any of ``_OVERWRITE_MARKERS``."""
    normalized = _normalize(user_text)
    if not normalized:
        return False
    return any(marker in normalized for marker in _OVERWRITE_MARKERS)
def _apply_turn_write_intent(records: List[SessionMemoryRecordV2], *, user_text: str) -> List[SessionMemoryRecordV2]:
    """Stamp overwrite intent onto ``records`` when the user explicitly asked to overwrite.

    The same list is returned. Each record gets a fresh metadata dict:
    ``write_intent`` and ``memory_gate_decision`` are only filled in when
    absent, while ``allow_parallel_state`` is always forced to ``False``.
    """
    if records and _explicit_overwrite_requested(user_text):
        overwrite_defaults = (
            ("write_intent", "overwrite"),
            ("memory_gate_decision", "explicit_overwrite"),
        )
        for record in records:
            stamped = dict(record.metadata or {})
            for key, value in overwrite_defaults:
                stamped.setdefault(key, value)
            stamped["allow_parallel_state"] = False
            record.metadata = stamped
    return records
def _records_from_extractor(
    extractor: SessionMemoryExtractor,
    *,
    query: str,
    answer_payload: Dict[str, Any] | None,
    extraction_result: Dict[str, Any] | None,
    turn_index: int,
    profile: TMCRAProfile | None = None,
) -> List[SessionMemoryRecordV2]:
    """Run the extractor for one turn and convert its output to ``SessionMemoryRecordV2``.

    Each extracted record gets a slot key from ``profile.stable_slot_key``, a
    memory id of the form ``{slot_key}:{turn_index}:auto:{index}``, state
    ``"active"``, and metadata whose provenance fields are filled with defaults
    when the extractor left them empty.
    """
    profile = profile or TMCRAProfile()
    answer_mode = str((answer_payload or {}).get("answer_mode", "transparent"))
    extracted = extractor.extract(
        query=query,
        extraction_result=extraction_result,
        answer_bundle=answer_payload,
        answer_mode=answer_mode,
        turn_index=turn_index,
    )

    def _convert(index: int, raw: Any) -> SessionMemoryRecordV2:
        # One snapshot of the extractor metadata feeds every lookup below.
        source_meta = dict(raw.metadata or {})
        slot_hint = raw.metadata.get("slot_key", "") if isinstance(raw.metadata, dict) else ""
        slot_key = profile.stable_slot_key(
            category=raw.category,
            value=raw.value,
            anchors=raw.anchor_concepts,
            slot_key=slot_hint,
            relation=raw.relation,
            metadata=dict(source_meta),
        )
        enriched = dict(source_meta)
        enriched["memory_role"] = _clean_text(source_meta.get("memory_role", "")) or "user"
        enriched["authority"] = _clean_text(source_meta.get("authority", "")) or "source"
        enriched["canonical_slot_key"] = _clean_text(source_meta.get("canonical_slot_key", "")) or slot_key
        enriched["writeback_class"] = _clean_text(source_meta.get("writeback_class", ""))
        enriched["origin_query"] = _clean_text(source_meta.get("origin_query", "")) or _clean_text(query)
        enriched["origin_answer_id"] = _clean_text(source_meta.get("origin_answer_id", ""))
        enriched["support_memory_ids"] = _dedupe(source_meta.get("support_memory_ids", []) or [])
        enriched["support_fact_refs"] = _dedupe(source_meta.get("support_fact_refs", []) or [])
        enriched["support_path_refs"] = _dedupe(source_meta.get("support_path_refs", []) or [])
        enriched["promotion_state"] = _clean_text(source_meta.get("promotion_state", "")) or "none"
        return SessionMemoryRecordV2(
            memory_id=f"{slot_key}:{turn_index}:auto:{index}",
            category=raw.category,
            slot_key=slot_key,
            value=_clean_text(raw.value),
            relation=_clean_text(raw.relation) or f"{raw.category}_memory",
            anchor_concepts=_dedupe(raw.anchor_concepts, max_items=8),
            evidence_anchors=_dedupe(raw.anchor_concepts, max_items=8),
            salience=float(raw.salience),
            confidence=float(raw.confidence),
            source_kind=_clean_text(raw.source_kind) or "session_memory",
            turn_index=int(raw.turn_index),
            state="active",
            metadata=enriched,
        )

    return [_convert(index, raw) for index, raw in enumerate(extracted)]
def _build_turn_records(
    extractor: SessionMemoryExtractor,
    *,
    user_text: str,
    answer_payload: Dict[str, Any] | None,
    extraction_result: Dict[str, Any] | None,
    turn_index: int,
    allow_auto_extract: bool,
    profile: TMCRAProfile | None = None,
) -> List[SessionMemoryRecordV2]:
    """Build the memory records for one turn.

    Structured records are tried first, from ``answer_payload`` and then from
    ``extraction_result``; the first non-empty set wins and is stamped with the
    turn's write intent. Otherwise the extractor is used, but only when
    auto-extraction is allowed or the turn looks like a write. Every returned
    set goes through ``_apply_typed_tunnel_annotations``.
    """
    profile = profile or TMCRAProfile()
    for payload in (answer_payload, extraction_result):
        structured = _parse_structured_records(payload, turn_index=turn_index, profile=profile)
        if structured:
            with_intent = _apply_turn_write_intent(structured, user_text=user_text)
            return _apply_typed_tunnel_annotations(with_intent, source_text=user_text)
    if not (allow_auto_extract or _looks_like_write_turn(user_text, answer_payload=answer_payload)):
        return []
    # The extractor is called without the answer payload (answer_payload=None).
    extracted = _records_from_extractor(
        extractor,
        query=user_text,
        answer_payload=None,
        extraction_result=extraction_result,
        turn_index=turn_index,
        profile=profile,
    )
    return _apply_typed_tunnel_annotations(extracted, source_text=user_text)
class NullMemoryAdapter(MemoryAdapter):
    """Memory adapter that ignores every turn and always retrieves nothing."""

    name = "null_memory"

    def reset(self) -> None:
        """No state to clear."""
        return None

    def ingest_turn(
        self,
        user_text: str,
        assistant_text: str = "",
        *,
        answer_payload: Dict[str, Any] | None = None,
        extraction_result: Dict[str, Any] | None = None,
    ) -> None:
        """Accept a turn and discard it."""
        return None

    def retrieve(self, query: str, *, top_k: int = 6) -> MemoryRetrieval:
        """Return an empty retrieval regardless of the query."""
        return MemoryRetrieval()

    def stats(self) -> Dict[str, Any]:
        """Return state statistics with every counter at zero."""
        return _state_stats(
            storage_bytes=0,
            retrieval_context_tokens=0,
            total_state_tokens=0,
            records=0,
        )

    def storage_bytes(self) -> int:
        """This adapter stores nothing."""
        return 0

    def build_prompt_context(self, query: str, *, top_k: int = 8) -> Dict[str, Any]:
        """Return a prompt context holding only the query, an empty retrieval and zero stats."""
        empty_retrieval = MemoryRetrieval().to_dict()
        context: Dict[str, Any] = {"mode": "null_memory", "query": query}
        context["retrieval"] = empty_retrieval
        context["stats"] = self.stats()
        context["state"] = {}
        return context
class GraphSessionMemoryAdapter(MemoryAdapter):
name = "graph_session_memory_v2"
def __init__(
    self,
    *,
    auto_extract: bool = False,
    storage_backend: str = "sqlite",
    storage_path: str = "",
    scope_id: str = "",
    audit_retention: int = 256,
    lightweight_stats: bool = True,
    retrieval_mode: str = "heuristic",
    node_model_path: str = "",
    path_model_path: str = "",
    node_model_device: str = "",
    candidate_event_k: int = 24,
    support_path_k: int = 3,
    path_tunnel_rescue_k: int = 0,
    path_tunnel_rescue_score_floor: float = 0.0,
    path_tunnel_rescue_min_age: int = 0,
    path_tunnel_rescue_min_score_margin: float = 0.0,
    event_rerank_mode: str = "matrix",
    matrix_event_top_k: int = DEFAULT_MATRIX_EVENT_TOP_K,
    memory_router_mode: str = "",
    memory_router_threshold: float = _MEMORY_ROUTER_DEFAULT_THRESHOLD,
    memory_router_margin: float = _MEMORY_ROUTER_DEFAULT_MARGIN,
    injection_planner_mode: str = "",
    injection_planner_model_path: str = "",
    injection_planner_latest_path: str = "",
    injection_planner_device: str = "",
    injection_planner_selection_threshold: float = -1.0,
    injection_planner_row_threshold: float = -1.0,
    injection_planner_logic_threshold: float = -1.0,
    temporal_layer_mode: str = "",
    temporal_router_mode: str = "",
    temporal_router_dir: str = "",
    temporal_router_latest_path: str = "",
    temporal_router_device: str = "",
) -> None:
    """Configure the adapter and load or create its session memory graph.

    Settings come from the keyword arguments and from ``TMCRA_*`` environment
    variables. With the ``"sqlite"`` backend the graph for ``scope_id`` is
    loaded from a ``SQLiteSessionMemoryStore`` (a file in the system temp
    directory when ``storage_path`` is empty). With ``"memory"`` a fresh
    in-process graph is created.

    Raises:
        ValueError: if ``storage_backend`` is neither ``"sqlite"`` nor ``"memory"``.
    """

    def _bounded_env(cast, env_name, default_text, floor):
        # Numeric env var clamped from below; unset, empty or unparsable
        # values yield the default.
        fallback = cast(default_text)
        try:
            return max(floor, cast(os.getenv(env_name, default_text) or fallback))
        except (TypeError, ValueError):
            return fallback

    def _router_float(env_name, explicit, default):
        # The environment value wins over the argument; a falsy argument or
        # an unparsable value yields the default.
        try:
            return float(os.getenv(env_name, "") or explicit or default)
        except (TypeError, ValueError):
            return default

    def _mode_setting(explicit, env_name):
        # Argument first, then environment, then "observe".
        env_value = _clean_text(os.getenv(env_name, ""))
        return _normalize(explicit) or _normalize(env_value) or "observe"

    def _text_setting(explicit, env_name):
        # Argument first, then environment.
        return _clean_text(explicit or os.getenv(env_name, ""))

    # Optional embedder prewarm, driven by either embedder-index mode variable.
    prewarm_mode = (
        _normalize(os.getenv("TMCRA_EMBEDDER_INDEX_RECALL_MODE", ""))
        or _normalize(os.getenv("TMCRA_WRITE_EMBEDDER_INDEX_MODE", ""))
    )
    self._embedder_prewarm_metadata = _prewarm_embedder_dense_if_requested(mode=prewarm_mode)

    # Collaborators.
    self.extractor = SessionMemoryExtractor()
    self.profile = TMCRAProfile()
    self.temporal_organizer = TemporalOrganizer()
    self.temporal_query_planner = TemporalQueryPlanner()
    self.timeline_evidence_builder = TimelineEvidencePackBuilder()

    # Storage and graph.
    self.auto_extract = bool(auto_extract)
    self.storage_backend = _normalize(storage_backend) or "sqlite"
    self.audit_retention = max(1, int(audit_retention))
    self.lightweight_stats = bool(lightweight_stats)
    self.scope_id = _clean_text(scope_id) or f"graph-session-{uuid.uuid4().hex}"
    self._store: SQLiteSessionMemoryStore | None = None
    self.storage_path = ""
    if self.storage_backend == "sqlite":
        sqlite_path = _clean_text(storage_path) or str(
            (Path(tempfile.gettempdir()) / "tmcra_graph_session_memory.sqlite3").resolve()
        )
        self._store = SQLiteSessionMemoryStore(sqlite_path, audit_retention=self.audit_retention)
        self.storage_path = str(self._store.storage_path)
        self.graph = self._store.load_graph(self.scope_id)
    elif self.storage_backend == "memory":
        self.storage_path = _clean_text(storage_path)
        self.graph = SessionMemoryGraphV2(
            audit_retention=self.audit_retention,
            persistence_backend="memory",
            persistence_path=self.storage_path,
        )
    else:
        raise ValueError(f"Unsupported storage backend: {self.storage_backend}")
    self._last_retrieval_context_tokens = 0
    self._last_writeback_summary: Dict[str, Any] = {}

    # Retrieval and reranking knobs taken from the arguments.
    self.retrieval_mode = _normalize(retrieval_mode) or "heuristic"
    self.node_model_path = _clean_text(node_model_path)
    self.path_model_path = _clean_text(path_model_path)
    self.node_model_device = _clean_text(node_model_device)
    self.candidate_event_k = max(1, int(candidate_event_k))
    self.support_path_k = max(1, int(support_path_k))
    self.path_tunnel_rescue_k = max(0, int(path_tunnel_rescue_k or 0))
    self.path_tunnel_rescue_score_floor = max(0.0, float(path_tunnel_rescue_score_floor or 0.0))
    self.path_tunnel_rescue_min_age = max(0, int(path_tunnel_rescue_min_age or 0))
    self.path_tunnel_rescue_min_score_margin = max(0.0, float(path_tunnel_rescue_min_score_margin or 0.0))
    self.event_rerank_mode = _normalize(event_rerank_mode) or "matrix"
    self.matrix_event_top_k = max(1, int(matrix_event_top_k or DEFAULT_MATRIX_EVENT_TOP_K))

    # Embedder index, pre-recall and fusion settings (environment only).
    self.write_embedder_index_mode = _normalize(os.getenv("TMCRA_WRITE_EMBEDDER_INDEX_MODE", "")) or "off"
    self.write_embedder_index_max_terms = _bounded_env(int, "TMCRA_WRITE_EMBEDDER_INDEX_MAX_TERMS", "96", 8)
    self.embedder_index_recall_mode = (
        _normalize(os.getenv("TMCRA_EMBEDDER_INDEX_RECALL_MODE", "")) or self.write_embedder_index_mode
    )
    self.embedder_index_recall_k = _bounded_env(int, "TMCRA_EMBEDDER_INDEX_RECALL_K", "0", 0)
    self.embedder_pre_recall_mode = _normalize(os.getenv("TMCRA_EMBEDDER_PRE_RECALL_MODE", "")) or "off"
    self.embedder_pre_recall_k = _bounded_env(int, "TMCRA_EMBEDDER_PRE_RECALL_K", "0", 0)
    self.embedder_fusion_mode = _normalize(os.getenv("TMCRA_EMBEDDER_FUSION_MODE", ""))
    self.embedder_fusion_weight = _bounded_env(float, "TMCRA_EMBEDDER_FUSION_WEIGHT", "0.35", 0.0)
    self.embedder_fusion_score_floor = _bounded_env(float, "TMCRA_EMBEDDER_FUSION_SCORE_FLOOR", "0.62", 0.0)
    self.embedder_fusion_top_k = _bounded_env(int, "TMCRA_EMBEDDER_FUSION_TOP_K", "16", 0)
    self.embedder_fusion_select_k = _bounded_env(int, "TMCRA_EMBEDDER_FUSION_SELECT_K", "4", 0)
    self.embedder_fusion_max_boost = _bounded_env(float, "TMCRA_EMBEDDER_FUSION_MAX_BOOST", "0.42", 0.0)

    # Memory router.
    self.memory_router_mode = _mode_setting(memory_router_mode, "TMCRA_MEMORY_ROUTER_MODE")
    self.memory_router_threshold = _router_float(
        "TMCRA_MEMORY_ROUTER_THRESHOLD", memory_router_threshold, _MEMORY_ROUTER_DEFAULT_THRESHOLD
    )
    self.memory_router_margin = _router_float(
        "TMCRA_MEMORY_ROUTER_MARGIN", memory_router_margin, _MEMORY_ROUTER_DEFAULT_MARGIN
    )
    self._loaded_node_scorer: LoadedNodeMemoryScorer | None = None
    self._node_scorer_error = ""

    # Injection planner.
    self.injection_planner_mode = _mode_setting(injection_planner_mode, "TMCRA_INJECTION_PLANNER_MODE")
    self.injection_planner_model_path = _text_setting(
        injection_planner_model_path, "TMCRA_INJECTION_PLANNER_MODEL_PATH"
    )
    self.injection_planner_latest_path = _text_setting(
        injection_planner_latest_path, "TMCRA_INJECTION_PLANNER_LATEST"
    )
    self.injection_planner_device = _text_setting(injection_planner_device, "TMCRA_INJECTION_PLANNER_DEVICE")
    self.injection_planner_selection_threshold_override = float(injection_planner_selection_threshold)
    self.injection_planner_row_threshold_override = float(injection_planner_row_threshold)
    self.injection_planner_logic_threshold_override = float(injection_planner_logic_threshold)
    self._loaded_injection_planner: Any | None = None
    self._injection_planner_config: Any | None = None
    self._injection_planner_payload: Dict[str, Any] = {}
    self._injection_planner_thresholds: Dict[str, float] = {}
    self._injection_planner_resolved_path = ""
    self._injection_planner_error = ""
    self._injection_planner_evidence_role_supported = False

    # Temporal layer and router.
    self.temporal_layer_mode = _mode_setting(temporal_layer_mode, "TMCRA_TEMPORAL_LAYER_MODE")
    self.temporal_router_mode = _mode_setting(temporal_router_mode, "TMCRA_TEMPORAL_ROUTER_MODE")
    self.temporal_router_dir = _text_setting(temporal_router_dir, "TMCRA_TEMPORAL_ROUTER_DIR")
    self.temporal_router_latest_path = _clean_text(
        temporal_router_latest_path
        or os.getenv("TMCRA_TEMPORAL_ROUTER_LATEST", "")
        or "models/temporal_router_v1_latest.txt"
    )
    self.temporal_router_device = _text_setting(temporal_router_device, "TMCRA_TEMPORAL_ROUTER_DEVICE") or "cpu"
    self._loaded_temporal_router: Any | None = None
    self._temporal_router_resolved_dir = ""
    self._temporal_router_error = ""
def _empty_graph(self) -> SessionMemoryGraphV2:
    """Create a new ``SessionMemoryGraphV2`` configured from this adapter.

    The adapter's audit retention, storage backend and storage path are
    forwarded to the graph constructor.
    """
    graph_options = {
        "audit_retention": self.audit_retention,
        "persistence_backend": self.storage_backend,
        "persistence_path": self.storage_path,
    }
    return SessionMemoryGraphV2(**graph_options)
def _reload_graph(self) -> None:
if self._store is not None:
self.graph = self._store.load_graph(self.scope_id)
else:
self.graph.configure_persistence(
backend=self.storage_backend,
path=self.storage_path,
audit_retention=self.audit_retention,
)
def _persist_graph(self) -> None:
self.graph.configure_persistence(
backend=self.storage_backend,
path=self.storage_path,
audit_retention=self.audit_retention,
)
if self._store is not None:
self._store.save_graph(self.scope_id, self.graph)
def replace_graph(self, graph: SessionMemoryGraphV2) -> None:
    """Install *graph* as the adapter's active graph, then persist it."""
    self.graph = graph
    # Persist right away so the store (if any) reflects the replacement.
    self._persist_graph()
def _storage_breakdown(self) -> Dict[str, int]:
    """Measure the serialized size and token estimate of the graph state.

    The core payload and the full audit payload are each serialized to JSON
    exactly once; the resulting strings are reused for both the UTF-8 byte
    counts and the token estimates.

    When ``self.lightweight_stats`` is truthy the audit *token* estimate is
    computed from a compact summary (event totals, trimmed counts, retained
    log lengths and the retention setting) rather than the full audit
    payload.  The audit *byte* count always uses the full payload.

    Returns:
        A dict of integer metrics with the keys ``core_storage_bytes``,
        ``audit_storage_bytes``, ``storage_bytes``,
        ``core_state_token_estimate``, ``audit_state_token_estimate`` and
        ``total_state_token_estimate``.
    """
    core_json = json.dumps(self.graph._core_payload(), ensure_ascii=False)
    audit_json = json.dumps(self.graph._audit_payload(), ensure_ascii=False)
    if self.lightweight_stats:
        audit_token_json = json.dumps(
            {
                "totals": dict(self.graph.audit_event_totals),
                "trimmed": dict(self.graph.audit_trimmed_counts),
                "retained": {
                    "turn_log": len(self.graph.turn_log),
                    "retrieval_log": len(self.graph.retrieval_log),
                    "answer_support_log": len(self.graph.answer_support_log),
                },
                "audit_retention": int(self.graph.audit_retention),
            },
            ensure_ascii=False,
        )
    else:
        # Full mode: the token estimate covers the same text as the byte count.
        audit_token_json = audit_json
    core_storage_bytes = len(core_json.encode("utf-8"))
    audit_storage_bytes = len(audit_json.encode("utf-8"))
    core_state_token_estimate = _estimate_tokens(core_json)
    audit_state_token_estimate = _estimate_tokens(audit_token_json)
    return {
        "core_storage_bytes": int(core_storage_bytes),
        "audit_storage_bytes": int(audit_storage_bytes),
        "storage_bytes": int(core_storage_bytes + audit_storage_bytes),
        "core_state_token_estimate": int(core_state_token_estimate),
        "audit_state_token_estimate": int(audit_state_token_estimate),
        "total_state_token_estimate": int(core_state_token_estimate + audit_state_token_estimate),
    }
def reset(self) -> None:
    """Discard the current memory state and start from an empty graph.

    With a backing store the scope is cleared there and the graph reloaded
    from it; otherwise a fresh in-memory graph is created.  The cached
    retrieval token count and the last writeback summary are reset as well.
    """
    store = self._store
    if store is None:
        self.graph = self._empty_graph()
    else:
        store.clear_scope(self.scope_id)
        self.graph = store.load_graph(self.scope_id)
    self._last_retrieval_context_tokens = 0
    self._last_writeback_summary = {}
def _temporal_layer_enabled(self) -> bool:
    """Return True unless the normalized temporal layer mode is a disabled mode."""
    layer_mode = _normalize(self.temporal_layer_mode)
    return layer_mode not in _TEMPORAL_LAYER_DISABLED_MODES
def _temporal_router_enabled(self) -> bool:
    """Return True when both the temporal layer and the temporal router are on.

    The router shares the layer's set of disabled modes.
    """
    if not self._temporal_layer_enabled():
        return False
    router_mode = _normalize(self.temporal_router_mode)
    return router_mode not in _TEMPORAL_LAYER_DISABLED_MODES
def _resolve_temporal_router_dir(self) -> str:
    """Locate the directory holding the temporal router weights.

    Resolution order:

    1. ``self.temporal_router_dir``, if it contains both weight files.
    2. The first line of the pointer file ``self.temporal_router_latest_path``,
       if that line names a directory containing both weight files.

    A directory qualifies only when it contains both
    ``writer_temporal_router.pt`` and ``query_temporal_router.pt``.

    Returns:
        The resolved directory as a string, or ``""`` when nothing usable is
        found.  A pointer that is not a regular file, cannot be read, or is
        not valid UTF-8 is treated as "not found" rather than raising, so a
        broken pointer cannot abort turn ingestion.
    """

    def _has_router_weights(root: Path) -> bool:
        # Both the writer and the query checkpoints must be present.
        return (root / "writer_temporal_router.pt").exists() and (root / "query_temporal_router.pt").exists()

    direct_dir = _clean_text(self.temporal_router_dir)
    if direct_dir:
        root = Path(direct_dir)
        if _has_router_weights(root):
            return str(root)
    latest_path = _clean_text(self.temporal_router_latest_path)
    if not latest_path:
        return ""
    pointer = Path(latest_path)
    # is_file() (not exists()) so a directory at the pointer path is skipped
    # instead of making read_text() raise.
    if not pointer.is_file():
        return ""
    try:
        lines = pointer.read_text(encoding="utf-8").splitlines()
    except (OSError, UnicodeDecodeError):
        return ""
    candidate = _clean_text(lines[0] if lines else "")
    if candidate:
        root = Path(candidate)
        if _has_router_weights(root):
            return str(root)
    return ""
def _load_temporal_router(self) -> LoadedTemporalRouter | None:
    """Return the temporal router, loading it on first use.

    Records why the router is unavailable in ``self._temporal_router_error``
    (``"disabled"``, ``"model_dir_missing"`` or the load exception) and the
    resolved directory in ``self._temporal_router_resolved_dir``.  The model
    directory is re-resolved on every call, before the cached router is
    consulted.

    Returns:
        The loaded router, or ``None`` when it is disabled, no model
        directory resolves, or loading fails.
    """
    if not self._temporal_router_enabled():
        self._temporal_router_error = "disabled"
        return None
    resolved_dir = self._resolve_temporal_router_dir()
    self._temporal_router_resolved_dir = resolved_dir
    if not resolved_dir:
        self._temporal_router_error = "model_dir_missing"
        return None
    cached_router = self._loaded_temporal_router
    if cached_router is not None:
        return cached_router
    try:
        router = LoadedTemporalRouter.from_dir(
            resolved_dir,
            device=self.temporal_router_device or "cpu",
        )
    except Exception as exc:  # pragma: no cover - defensive runtime path
        self._loaded_temporal_router = None
        self._temporal_router_error = f"{type(exc).__name__}: {exc}"
    else:
        self._loaded_temporal_router = router
        self._temporal_router_error = ""
    return self._loaded_temporal_router
def _temporal_router_status_metadata(self) -> Dict[str, Any]:
    """Describe the temporal router's current availability as metadata.

    Triggers a load attempt first, so the reported directory and error
    reflect the latest resolution.
    """
    router_available = self._load_temporal_router() is not None
    status: Dict[str, Any] = {}
    status["temporal_router_enabled"] = router_available
    status["temporal_router_mode"] = _normalize(self.temporal_router_mode) or "observe"
    status["temporal_router_model_dir"] = self._temporal_router_resolved_dir
    status["temporal_router_error"] = self._temporal_router_error
    return status
def _session_timestamp_from_payloads(self, *payloads: Mapping[str, Any] | None) -> str:
    """Return the first non-empty timestamp found in the given payloads.

    Payloads are scanned in argument order.  For each one the top-level keys
    are checked first, then the same keys inside its ``"metadata"`` mapping
    (when that value is a mapping).  Returns ``""`` if none is found.
    """
    timestamp_keys = ("session_timestamp", "timestamp", "created_at", "turn_timestamp")
    for payload in payloads:
        data = dict(payload or {})
        nested = data.get("metadata")
        sources = [data, nested] if isinstance(nested, Mapping) else [data]
        for source in sources:
            for key in timestamp_keys:
                found = _clean_text(source.get(key, ""))
                if found:
                    return found
    return ""
def _temporal_frame_for_turn(
    self,
    *,
    user_text: str,
    previous_turn: str = "",
    answer_payload: Dict[str, Any] | None = None,
    extraction_result: Dict[str, Any] | None = None,
) -> TemporalFrame | None:
    """Build the temporal frame for a user turn.

    Sources are tried in this order:

    1. A ``temporal_frame`` mapping supplied in ``answer_payload`` or
       ``extraction_result`` (top level, or under ``"metadata"``).
    2. A temporal-router writer prediction whose confidence meets the
       ``TMCRA_TEMPORAL_ROUTER_WRITER_MIN_CONFIDENCE`` threshold, layered
       over the organizer's own frame.
    3. The organizer's frame computed without any model input.

    Returns:
        The resulting frame, or ``None`` when the temporal layer is disabled.
    """
    if not self._temporal_layer_enabled():
        return None
    session_timestamp = self._session_timestamp_from_payloads(answer_payload, extraction_result)

    # 1. Prefer a frame supplied by the caller.
    model_frame = None
    for payload in (answer_payload, extraction_result):
        data = dict(payload or {})
        supplied = data.get("temporal_frame") or dict(data.get("metadata", {}) or {}).get("temporal_frame")
        if isinstance(supplied, Mapping):
            model_frame = supplied
            break

    organizer_kwargs = {
        "current_turn": user_text,
        "previous_turn": previous_turn,
        "session_timestamp": session_timestamp,
        "speaker": "user",
    }
    fallback_frame = self.temporal_organizer.organize_turn(**organizer_kwargs)

    # 2. Otherwise ask the router's writer head, if one is available.
    if model_frame is None:
        router = self._load_temporal_router()
        predicted = None
        if router is not None and router.writer_available():
            predicted = router.predict_writer_frame(
                current_turn=user_text,
                previous_turn=previous_turn,
                session_timestamp=session_timestamp,
            )
        writer_confidence = float(predicted.get("confidence", 0.0) or 0.0) if predicted else 0.0
        if predicted and writer_confidence >= _float_env(
            "TMCRA_TEMPORAL_ROUTER_WRITER_MIN_CONFIDENCE",
            _TEMPORAL_ROUTER_DEFAULT_WRITER_MIN_CONFIDENCE,
        ):
            overridable_fields = {"temporal_intent", "anchor_type", "granularity", "state_operation"}
            blended = fallback_frame.to_dict()
            # Only non-empty predictions override the organizer's values.
            for field_name, field_value in predicted.items():
                if field_name in overridable_fields and _clean_text(field_value):
                    blended[field_name] = field_value
            if "should_create_timeline_edge" in predicted:
                blended["should_create_timeline_edge"] = bool(predicted.get("should_create_timeline_edge", False))
            if writer_confidence > 0.0:
                blended["confidence"] = writer_confidence
            blended_metadata = dict(blended.get("metadata", {}) or {})
            blended_metadata.update(dict(predicted.get("metadata", {}) or {}))
            blended_metadata.update(self._temporal_router_status_metadata())
            blended["metadata"] = blended_metadata
            model_frame = blended

    # 3. No model input at all: the plain organizer frame stands.
    if model_frame is None:
        return fallback_frame
    return self.temporal_organizer.organize_turn(model_frame=model_frame, **organizer_kwargs)
def _temporal_turn_metadata(self, frame: TemporalFrame | None) -> Dict[str, Any]:
    """Build the temporal portion of a turn's metadata.

    Without a frame only the layer flag/mode and router status are reported.
    With a frame its serialized form, intent, subject key and state operation
    are included as well.
    """
    layer_mode = _normalize(self.temporal_layer_mode) or "observe"
    if frame is None:
        turn_metadata: Dict[str, Any] = {
            "temporal_layer_enabled": False,
            "temporal_layer_mode": layer_mode,
        }
    else:
        turn_metadata = {
            "temporal_layer_enabled": True,
            "temporal_layer_mode": layer_mode,
            "temporal_frame": frame.to_dict(),
            "temporal_intent": frame.temporal_intent,
            "temporal_subject_key": frame.subject_key,
            "temporal_state_operation": frame.state_operation,
        }
    # Router status is appended last, as in both original branches.
    turn_metadata.update(self._temporal_router_status_metadata())
    return turn_metadata
def _apply_temporal_frame_to_records(self, records: Sequence[SessionMemoryRecordV2], frame: TemporalFrame | None) -> None:
    """Stamp each record's metadata with the turn's temporal frame.

    Nothing is written when there is no frame, or when the frame is
    ``"non_temporal"`` and has no subject key.  Each record receives a fresh
    metadata dict with its own serialized copy of the frame.
    """
    if frame is None or (frame.temporal_intent == "non_temporal" and not frame.subject_key):
        return
    for record in records:
        record.metadata = {
            **dict(record.metadata or {}),
            "temporal_layer": True,
            "temporal_frame": frame.to_dict(),
            "temporal_intent": frame.temporal_intent,
            "temporal_subject": frame.subject,
            "temporal_subject_key": frame.subject_key,
            "temporal_state_operation": frame.state_operation,
            "temporal_event_time": frame.event_time,
            "temporal_state_valid_from": frame.state_valid_from,
            "temporal_state_valid_to": frame.state_valid_to,
        }
def _build_timeline_state_layer(self) -> TimelineStateLayer:
    """Replay the graph's turn log into a fresh ``TimelineStateLayer``.

    Turns are processed in ascending ``turn_index`` order.  A turn is skipped
    when its metadata has no ``temporal_frame`` mapping, or when the frame is
    ``"non_temporal"`` and carries no new state.
    """
    layer = TimelineStateLayer()
    turns = list(getattr(self.graph, "turn_log", []) or [])
    turns.sort(key=lambda entry: int(entry.get("turn_index", 0) or 0))
    for turn in turns:
        turn_metadata = dict(turn.get("metadata", {}) or {})
        raw_frame = turn_metadata.get("temporal_frame")
        if not isinstance(raw_frame, Mapping):
            continue
        frame = TemporalFrame.from_mapping(raw_frame)
        if frame.temporal_intent == "non_temporal" and not frame.new_state:
            continue
        usable_record_ids = [rid for rid in list(turn.get("record_ids", []) or []) if _clean_text(rid)]
        # Explicit source record id wins; otherwise the turn's first record.
        source_event_id = _clean_text(turn_metadata.get("temporal_source_record_id", ""))
        if not source_event_id:
            source_event_id = usable_record_ids[0] if usable_record_ids else ""
        frame_metadata = dict(frame.metadata or {})
        frame_metadata["source_text"] = _clean_text(turn.get("text", "")) or frame.evidence_span
        frame.metadata = frame_metadata
        layer.apply_frame(
            frame,
            source_event_id=source_event_id,
            source_turn_id=_clean_text(turn.get("turn_id", "")),
            state_type=_clean_text(turn_metadata.get("temporal_state_type", "")) or "profile",
        )
    return layer
def _temporal_runtime_pack(self, query: str) -> Dict[str, Any]:
    """Plan a temporal query and, when applicable, build its evidence pack.

    Returns:
        A dict that always has ``"metadata"``.  ``"plan"`` is added once a
        query plan exists, and ``"pack"`` only when the evidence pack was
        built.  ``metadata["temporal_runtime_reason"]`` is one of
        ``"disabled"``, ``"non_temporal_query"``, ``"missing_target_subject"``
        or ``"ok"``.
    """
    metadata: Dict[str, Any] = {
        "temporal_runtime_enabled": False,
        "temporal_layer_mode": _normalize(self.temporal_layer_mode) or "observe",
        **self._temporal_router_status_metadata(),
    }
    if not self._temporal_layer_enabled():
        metadata["temporal_runtime_reason"] = "disabled"
        return {"metadata": metadata}

    plan = self.temporal_query_planner.plan(query)
    router = self._load_temporal_router()
    if router is not None and router.query_available():
        predicted_plan = router.predict_query_plan(query=query)
        if predicted_plan:
            query_confidence = float(predicted_plan.get("confidence", 0.0) or 0.0)
            router_confidences = dict(
                dict(predicted_plan.get("metadata", {}) or {}).get("temporal_router_confidences", {}) or {}
            )
        else:
            query_confidence = 0.0
            router_confidences = {}
        intent_confidence = float(router_confidences.get("query_temporal_intent", 0.0) or 0.0)
        # The router overrides the rule-based plan only when both the overall
        # and the intent-specific confidences clear their thresholds.
        if (
            predicted_plan
            and query_confidence >= _float_env("TMCRA_TEMPORAL_ROUTER_QUERY_MIN_CONFIDENCE", _TEMPORAL_ROUTER_DEFAULT_QUERY_MIN_CONFIDENCE)
            and intent_confidence >= _float_env("TMCRA_TEMPORAL_ROUTER_QUERY_INTENT_MIN_CONFIDENCE", _TEMPORAL_ROUTER_DEFAULT_QUERY_INTENT_MIN_CONFIDENCE)
        ):
            merged_plan = plan.to_dict()
            for field_name, field_value in predicted_plan.items():
                if field_name in {"query_temporal_intent", "timeline_operation"} and _clean_text(field_value):
                    merged_plan[field_name] = field_value
            for flag in ("prefer_current_state", "prefer_previous_state", "requires_ordered_chain", "requires_comparison"):
                if flag in predicted_plan:
                    merged_plan[flag] = bool(predicted_plan.get(flag, False))
            if query_confidence > 0.0:
                merged_plan["confidence"] = query_confidence
            merged_metadata = dict(merged_plan.get("metadata", {}) or {})
            merged_metadata.update(dict(predicted_plan.get("metadata", {}) or {}))
            merged_metadata.update(self._temporal_router_status_metadata())
            merged_plan["metadata"] = merged_metadata
            # Drop any keys the plan dataclass does not declare.
            known_fields = TemporalQueryPlan.__dataclass_fields__
            plan = TemporalQueryPlan(**{name: value for name, value in merged_plan.items() if name in known_fields})

    metadata["temporal_query_plan"] = plan.to_dict()
    if plan.query_temporal_intent == "non_temporal" or plan.timeline_operation == "none":
        metadata["temporal_runtime_reason"] = "non_temporal_query"
        return {"plan": plan, "metadata": metadata}
    needs_subject = plan.timeline_operation in {"query_current", "query_previous"}
    if needs_subject and not _clean_text(plan.target_subject_key):
        metadata["temporal_runtime_reason"] = "missing_target_subject"
        return {"plan": plan, "metadata": metadata}

    timeline_layer = self._build_timeline_state_layer()
    pack = self.timeline_evidence_builder.build(plan=plan, timeline_layer=timeline_layer)
    metadata["temporal_runtime_enabled"] = True
    metadata["temporal_runtime_reason"] = "ok"
    metadata["temporal_evidence_pack"] = pack.to_dict()
    metadata["temporal_selected_answer_value"] = _clean_text(pack.selected_evidence.get("answer_value", ""))
    metadata["temporal_selected_state_id"] = _clean_text(pack.selected_evidence.get("state_id", ""))
    metadata["temporal_timeline_state_count"] = len(pack.timeline)
    return {"plan": plan, "pack": pack, "metadata": metadata}
def _temporal_state_hit(
    self,
    *,
    state_payload: Mapping[str, Any],
    plan_payload: Mapping[str, Any],
    selected: bool,
    rank: int,
) -> MemoryHit | None:
    """Turn one timeline state into a synthetic ``MemoryHit``.

    Args:
        state_payload: Serialized timeline state; ``"state"`` and
            ``"state_id"`` are required.
        plan_payload: Serialized temporal query plan.
        selected: Whether this state is the pack's selected evidence.
        rank: Position among the non-selected states, used to decay the score.

    Returns:
        The hit, or ``None`` when the state has no value or no id.  A selected
        state scores 0.98; others score ``0.92 - 0.03 * rank`` with a floor
        of 0.75.  The source record's metadata (if the record is in the graph)
        is carried over and extended with ``temporal_*`` fields.
    """
    state_value = _clean_text(state_payload.get("state", ""))
    state_id = _clean_text(state_payload.get("state_id", ""))
    if not state_value or not state_id:
        return None
    source_event_id = _clean_text(state_payload.get("source_event_id", ""))
    source_record = self.graph.records_by_id.get(source_event_id)
    source_metadata = dict(source_record.metadata or {}) if source_record is not None else {}
    score = 0.98 if selected else max(0.75, 0.92 - (rank * 0.03))
    is_current = bool(state_payload.get("is_current", False))
    # Fall back to "general" when the subject key is missing OR empty; a
    # .get() default alone would leave the slot key as a bare "temporal.".
    subject_key = _clean_text(plan_payload.get("target_subject_key", "")) or "general"
    return MemoryHit(
        memory_id=f"temporal_state:{state_id}",
        category="time",
        value=state_value,
        relation="temporal_state",
        anchors=_dedupe([state_payload.get("time", ""), state_payload.get("source_text", ""), plan_payload.get("target_subject", "")], max_items=6),
        score=round(score, 6),
        source_kind="temporal_state_layer",
        slot_key=f"temporal.{subject_key}",
        state="active" if is_current else "history",
        turn_index=int(source_record.turn_index) if source_record is not None else 0,
        metadata={
            **source_metadata,
            "temporal_runtime_hit": True,
            "temporal_runtime_selected": bool(selected),
            "temporal_state_id": state_id,
            "temporal_state_value": state_value,
            "temporal_state_time": _clean_text(state_payload.get("time", "")),
            "temporal_state_valid_to": _clean_text(state_payload.get("valid_to", "")),
            "temporal_state_is_current": is_current,
            "temporal_source_event_id": source_event_id,
            "temporal_source_turn_id": _clean_text(state_payload.get("source_turn_id", "")),
            "temporal_query_plan": dict(plan_payload),
        },
    )
def _apply_temporal_evidence_pack_to_hits(
    self,
    hits: Sequence[MemoryHit],
    temporal_payload: Mapping[str, Any],
    *,
    top_k: int,
) -> Dict[str, Any]:
    """Inject synthetic temporal-state hits ahead of the retrieved hits.

    When the payload has no pack/plan or the temporal runtime is not enabled,
    the hits are returned unchanged.  Otherwise:

    * the pack's selected state becomes a synthetic hit;
    * for ordered-chain or comparison plans, every other timeline state
      becomes one too;
    * retrieved hits whose event id backs a synthetic hit are re-created with
      ``temporal_runtime_support`` metadata;
    * synthetic hits are placed first and duplicates by ``memory_id`` dropped.

    Returns:
        ``{"hits": [...], "metadata": {...}}`` where the metadata gains the
        injected hit count and ids.
    """
    metadata = dict(temporal_payload.get("metadata", {}) or {})
    pack = temporal_payload.get("pack")
    plan = temporal_payload.get("plan")
    runtime_enabled = bool(metadata.get("temporal_runtime_enabled", False))
    if pack is None or plan is None or not runtime_enabled:
        return {"hits": list(hits), "metadata": metadata}

    pack_payload = pack.to_dict() if hasattr(pack, "to_dict") else dict(pack)
    plan_payload = plan.to_dict() if hasattr(plan, "to_dict") else dict(plan)
    selected_evidence = dict(pack_payload.get("selected_evidence", {}) or {})
    selected_state_id = _clean_text(selected_evidence.get("state_id", ""))
    selected_answer_value = _clean_text(selected_evidence.get("answer_value", ""))
    timeline = list(pack_payload.get("timeline", []) or [])

    synthetic_hits: List[MemoryHit] = []
    if selected_state_id:
        selected_state = None
        for entry in timeline:
            if _clean_text(dict(entry).get("state_id", "")) == selected_state_id:
                selected_state = dict(entry)
                break
        if selected_state is None and selected_answer_value:
            # The selected state is absent from the timeline: rebuild a
            # minimal state from the selected evidence itself.
            selected_state = {
                "state_id": selected_state_id,
                "state": selected_answer_value,
                "source_event_id": selected_evidence.get("source_event_id", ""),
                "source_turn_id": selected_evidence.get("source_turn_id", ""),
                "is_current": bool(plan_payload.get("prefer_current_state", False)),
            }
        if selected_state:
            primary_hit = self._temporal_state_hit(
                state_payload=selected_state,
                plan_payload=plan_payload,
                selected=True,
                rank=0,
            )
            if primary_hit is not None:
                synthetic_hits.append(primary_hit)

    wants_full_timeline = bool(plan_payload.get("requires_ordered_chain", False)) or bool(
        plan_payload.get("requires_comparison", False)
    )
    if wants_full_timeline:
        for position, entry in enumerate(timeline):
            if _clean_text(dict(entry).get("state_id", "")) == selected_state_id:
                continue
            chain_hit = self._temporal_state_hit(
                state_payload=dict(entry),
                plan_payload=plan_payload,
                selected=False,
                rank=position + 1,
            )
            if chain_hit is not None:
                synthetic_hits.append(chain_hit)

    # Event ids that back a synthetic hit; used to tag the matching real hits.
    synthetic_source_ids = set()
    for synthetic in synthetic_hits:
        backing_id = _clean_text((synthetic.metadata or {}).get("temporal_source_event_id", ""))
        if backing_id:
            synthetic_source_ids.add(backing_id)

    existing: List[MemoryHit] = []
    for hit in hits:
        hit_metadata = dict(hit.metadata or {})
        source_event_id = _clean_text(hit_metadata.get("event_id", hit.memory_id))
        if source_event_id not in synthetic_source_ids:
            existing.append(hit)
            continue
        hit_metadata["temporal_runtime_support"] = True
        hit_metadata["temporal_query_plan"] = dict(plan_payload)
        hit_metadata["temporal_selected_answer_value"] = selected_answer_value
        hit_metadata["temporal_selected_state_id"] = selected_state_id
        existing.append(
            MemoryHit(
                memory_id=hit.memory_id,
                category=hit.category,
                value=hit.value,
                relation=hit.relation,
                anchors=list(hit.anchors),
                score=hit.score,
                source_kind=hit.source_kind,
                slot_key=hit.slot_key,
                state=hit.state,
                turn_index=int(hit.turn_index),
                metadata=hit_metadata,
            )
        )

    merged: List[MemoryHit] = []
    seen_ids = set()
    for hit in [*synthetic_hits, *existing]:
        hit_id = _clean_text(hit.memory_id)
        if hit_id:
            if hit_id in seen_ids:
                continue
            seen_ids.add(hit_id)
        # Hits without a usable id are never treated as duplicates.
        merged.append(hit)

    metadata["temporal_runtime_injected_hit_count"] = len(synthetic_hits)
    metadata["temporal_runtime_injected_hit_ids"] = [hit.memory_id for hit in synthetic_hits]
    # NOTE(review): max(len(merged), top_k) is never below len(merged), so this
    # slice does not truncate and top_k has no effect here - confirm intent.
    return {"hits": merged[: max(len(merged), top_k)], "metadata": metadata}
def ingest_turn(
    self,
    user_text: str,
    assistant_text: str = "",
    *,
    answer_payload: Dict[str, Any] | None = None,
    extraction_result: Dict[str, Any] | None = None,
) -> None:
    """Ingest one user turn into the memory graph and persist the result.

    Steps, in order: reload the graph, allocate a turn index, compute the
    temporal frame, assign a topic bucket, build and store the turn's records
    (stamped with temporal and topic data), update the write-embedder index,
    add topic-bridge and dialogue-tunnel edges, log the turn, and persist.

    The turn is logged as ``"memory_write"`` when at least one record was
    stored and as ``"noise"`` otherwise.
    """
    self._reload_graph()
    turn_index = self.graph.next_turn()
    prior_topic = _last_topic_turn(self.graph)
    # turn_log does not yet contain this turn, so [-1] is the previous one.
    turn_log = self.graph.turn_log
    prior_turn_text = _clean_text(turn_log[-1].get("text", "")) if turn_log else ""

    temporal_frame = self._temporal_frame_for_turn(
        user_text=user_text,
        previous_turn=prior_turn_text,
        answer_payload=answer_payload,
        extraction_result=extraction_result,
    )
    topic_bucket = _assign_topic_bucket_for_text(
        self.graph,
        user_text,
        answer_payload=answer_payload,
        turn_index=turn_index,
        create=True,
    )

    new_records = _build_turn_records(
        self.extractor,
        user_text=user_text,
        answer_payload=answer_payload,
        extraction_result=extraction_result,
        turn_index=turn_index,
        allow_auto_extract=self.auto_extract,
        profile=self.profile,
    )
    self._apply_temporal_frame_to_records(new_records, temporal_frame)
    _apply_topic_bucket_to_records(new_records, topic_bucket)
    stored_ids = self.graph.add_records(new_records)

    write_embedder_metadata = _apply_write_embedder_index_to_graph(
        self.graph,
        stored_ids=stored_ids,
        turn_text=user_text,
        turn_index=turn_index,
        mode=self.write_embedder_index_mode,
        max_terms=self.write_embedder_index_max_terms,
    )
    topic_bridge_metadata = _add_topic_bridge_edges(
        self.graph,
        previous_topic=prior_topic,
        current_topic=topic_bucket,
        current_record_ids=stored_ids,
        turn_index=turn_index,
        evidence=user_text,
    )
    dialogue_tunnel_metadata = _add_dialogue_tunnel_edges(
        self.graph,
        current_topic=topic_bucket,
        current_record_ids=stored_ids,
        turn_index=turn_index,
        evidence=user_text,
    )

    turn_metadata: Dict[str, Any] = {
        "source": "user_turn",
        "auto_extract": bool(self.auto_extract),
        "topic_bucket_id": _clean_text(topic_bucket.get("topic_bucket_id", "")),
        "topic_label": _clean_text(topic_bucket.get("topic_label", "")),
        "topic_keywords": list(topic_bucket.get("topic_keywords", []) or []),
        "temporal_source_record_id": stored_ids[0] if stored_ids else "",
    }
    # Later updates win on key collisions, matching the original spread order.
    turn_metadata.update(self._temporal_turn_metadata(temporal_frame))
    turn_metadata.update(write_embedder_metadata)
    turn_metadata.update(topic_bridge_metadata)
    turn_metadata.update(dialogue_tunnel_metadata)

    self.graph.record_turn(
        turn_kind="memory_write" if stored_ids else "noise",
        text=user_text,
        turn_index=turn_index,
        record_ids=stored_ids,
        speaker="user",
        assistant_text=assistant_text,
        metadata=turn_metadata,
    )
    self._persist_graph()
def ingest_answer_writeback(
    self,
    *,
    query_text: str,
    answer_text: str,
    answer_id: str,
    writeback_records: List[Dict[str, Any]],
    trace: Dict[str, Any] | None = None,
) -> List[str]:
    """Store assistant-derived memories produced while answering a query.

    Each raw entry in *writeback_records* is turned into a
    ``SessionMemoryRecordV2``.  Entries that are not dicts, or that end up
    without a value or slot key, are skipped.  The stored records then go
    through the write-embedder index and writeback promotion, the turn is
    logged with ``speaker="assistant"``, and the graph is persisted.

    Returns:
        The ids of the stored records.  When *writeback_records* is empty the
        method returns ``[]`` without allocating a turn, and resets the last
        writeback summary.
    """
    self._reload_graph()
    if not writeback_records:
        self._last_writeback_summary = {"stored_record_ids": [], "promotion_events": []}
        return []
    turn_index = self.graph.next_turn()
    records: List[SessionMemoryRecordV2] = []
    writeback_classes: List[str] = []
    for index, raw in enumerate(writeback_records):
        if not isinstance(raw, dict):
            continue
        category = _clean_text(raw.get("category", "fact")) or "fact"
        value = _clean_text(raw.get("value", ""))
        raw_slot_key = _clean_text(raw.get("slot_key", "")) or _clean_text(raw.get("slot", ""))
        anchors = _dedupe(raw.get("anchors", []) or [], max_items=8)
        slot_key = self.profile.stable_slot_key(
            category=category,
            value=value,
            anchors=anchors,
            slot_key=raw_slot_key,
            relation=_clean_text(raw.get("relation", "")),
            metadata=dict(raw.get("metadata", {}) or {}),
        )
        if not (value and slot_key):
            continue

        raw_metadata = dict(raw.get("metadata", {}) or {})
        entry_class = _clean_text(raw_metadata.get("writeback_class", "")) or "fact"
        writeback_classes.append(entry_class)

        canonical_slot_key = _clean_text(raw_metadata.get("canonical_slot_key", ""))
        if not canonical_slot_key:
            # Derive it from the slot key: drop the "assistant." prefix and
            # anything from the first writeback-class suffix onwards.
            canonical_slot_key = slot_key.removeprefix("assistant.")
            for class_suffix in (".fact", ".state_change", ".high_conf_conclusion"):
                canonical_slot_key = canonical_slot_key.split(class_suffix, 1)[0]
        origin_answer_id = _clean_text(raw_metadata.get("origin_answer_id", "")) or answer_id

        metadata = dict(raw_metadata)
        metadata.update(
            {
                "memory_role": _clean_text(raw_metadata.get("memory_role", "")) or "assistant",
                "authority": _clean_text(raw_metadata.get("authority", "")) or "derived",
                "canonical_slot_key": canonical_slot_key,
                "writeback_class": entry_class,
                "origin_query": _clean_text(raw_metadata.get("origin_query", "")) or _clean_text(query_text),
                "origin_answer_id": origin_answer_id,
                "origin_answer_ids": _dedupe([*(raw_metadata.get("origin_answer_ids", []) or []), origin_answer_id]),
                "support_memory_ids": _dedupe(raw_metadata.get("support_memory_ids", []) or []),
                "support_fact_refs": _dedupe(raw_metadata.get("support_fact_refs", []) or []),
                "support_path_refs": _dedupe(raw_metadata.get("support_path_refs", []) or []),
                "promotion_state": _clean_text(raw_metadata.get("promotion_state", "")) or "candidate",
                "answer_id": answer_id,
            }
        )
        records.append(
            SessionMemoryRecordV2(
                memory_id=f"{slot_key}:{turn_index}:assistant:{index}",
                category=category,
                slot_key=slot_key,
                value=value,
                relation=_clean_text(raw.get("relation", "")) or "assistant_memory",
                anchor_concepts=anchors,
                evidence_anchors=anchors,
                salience=float(raw.get("salience", 0.62) or 0.62),
                confidence=float(raw.get("confidence", 0.88) or 0.88),
                source_kind=_clean_text(raw.get("source_kind", "")) or "assistant_writeback",
                turn_index=turn_index,
                state=_clean_text(raw.get("state", "")) or "active",
                metadata=metadata,
            )
        )

    stored_ids = self.graph.add_records(records)
    write_embedder_metadata = _apply_write_embedder_index_to_graph(
        self.graph,
        stored_ids=stored_ids,
        turn_text=" ".join([_clean_text(query_text), _clean_text(answer_text)]),
        turn_index=turn_index,
        mode=self.write_embedder_index_mode,
        max_terms=self.write_embedder_index_max_terms,
    )
    promotion_events = self._apply_writeback_promotions(stored_ids)

    # One shared class -> that class; several -> "mixed"; none -> "".
    if len(set(writeback_classes)) == 1:
        turn_writeback_class = writeback_classes[0]
    elif writeback_classes:
        turn_writeback_class = "mixed"
    else:
        turn_writeback_class = ""

    turn_metadata: Dict[str, Any] = {
        "source": "assistant_writeback",
        "answer_id": answer_id,
        "trace": dict(trace or {}),
    }
    turn_metadata.update(write_embedder_metadata)
    self.graph.record_turn(
        turn_kind="assistant_writeback" if stored_ids else "assistant_writeback_empty",
        text=query_text,
        turn_index=turn_index,
        record_ids=stored_ids,
        speaker="assistant",
        assistant_text=answer_text,
        writeback_class=turn_writeback_class,
        metadata=turn_metadata,
    )

    summary: Dict[str, Any] = {
        "stored_record_ids": list(stored_ids),
        "promotion_events": list(promotion_events),
    }
    summary.update(write_embedder_metadata)
    self._last_writeback_summary = summary
    self._persist_graph()
    return stored_ids
def last_writeback_summary(self) -> Dict[str, Any]:
    """Return a shallow copy of the most recent writeback summary."""
    summary = self._last_writeback_summary
    return dict(summary)
def _resolve_injection_planner_model_path(self) -> str:
    """Resolve the injection planner checkpoint path.

    Resolution order:

    1. ``self.injection_planner_model_path`` - used as-is, or with
       ``injection_planner.pt`` appended when it is a directory.
    2. ``self.injection_planner_latest_path`` - a directory (checkpoint
       inside it), a ``.pt`` file (used directly), or a text pointer file
       whose stripped content names the checkpoint or its directory.

    Returns:
        The resolved path as a string (its existence is not checked here),
        or ``""`` when nothing is configured.  When the pointer file cannot
        be read or is empty, ``""`` is returned and the reason is stored in
        ``self._injection_planner_error``.
    """
    weights_name = "injection_planner.pt"
    explicit_path = _clean_text(self.injection_planner_model_path)
    if explicit_path:
        path = Path(explicit_path).expanduser()
        if path.is_dir():
            path = path / weights_name
        return str(path)
    latest_path = _clean_text(self.injection_planner_latest_path)
    if not latest_path:
        return ""
    latest = Path(latest_path).expanduser()
    if latest.is_dir():
        return str(latest / weights_name)
    if latest.suffix == ".pt":
        return str(latest)
    try:
        pointer_text = latest.read_text(encoding="utf-8").strip()
        target = Path(pointer_text).expanduser() if pointer_text else None
    except Exception as exc:
        self._injection_planner_error = f"latest_pointer_read_failed: {exc}"
        return ""
    if target is None:
        # Path("") is "." (a directory), which would otherwise resolve to a
        # checkpoint in the current working directory by accident.
        self._injection_planner_error = "latest_pointer_empty"
        return ""
    if target.is_dir():
        target = target / weights_name
    return str(target)
def _load_injection_planner(self) -> Any | None:
    """Return the injection planner model, loading it on first use.

    On success the model, its config, the raw checkpoint payload, the
    resolved path and the decision thresholds are cached on the adapter.
    Thresholds start at 0.5, are overridden by ``train_summary.json`` next to
    the checkpoint when it can be parsed, and finally by any non-negative
    ``*_threshold_override`` attribute.

    Returns:
        The loaded model in eval mode, or ``None`` when the planner is
        disabled, torch is unavailable, no checkpoint path resolves, the
        checkpoint file is missing, or loading raises.  The reason is stored
        in ``self._injection_planner_error``.
    """
    normalized_mode = _normalize(self.injection_planner_mode)
    if normalized_mode in _INJECTION_PLANNER_DISABLED_MODES:
        self._injection_planner_error = "disabled"
        return None
    if self._loaded_injection_planner is not None:
        return self._loaded_injection_planner
    torch_module = getattr(injection_planner_runtime, "torch", None)
    if torch_module is None:
        self._injection_planner_error = "torch_unavailable"
        return None
    # Clear any error left by an earlier attempt (e.g. "disabled") so that it
    # cannot mask the outcome of this one; the resolver sets its own error
    # when the pointer file is the problem.
    self._injection_planner_error = ""
    model_path_text = self._resolve_injection_planner_model_path()
    if not model_path_text:
        if not self._injection_planner_error:
            self._injection_planner_error = "model_path_missing"
        return None
    model_path = Path(model_path_text)
    if not model_path.exists():
        self._injection_planner_error = f"model_path_not_found: {model_path}"
        return None
    try:
        device = torch_module.device(self.injection_planner_device or "cpu")
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects from
        # the checkpoint; only load checkpoints from a trusted source.
        payload = torch_module.load(model_path, map_location=device, weights_only=False)
        config = injection_planner_runtime.InjectionPlannerConfig.from_dict(dict(payload.get("config", {}) or {}))
        model = injection_planner_runtime.InjectionPlannerModel(config).to(device)
        state_dict = dict(payload.get("state_dict", {}) or {})
        model.load_state_dict(state_dict, strict=False)
        model.eval()
        self._loaded_injection_planner = model
        self._injection_planner_config = config
        self._injection_planner_payload = dict(payload)
        self._injection_planner_resolved_path = str(model_path)
        self._injection_planner_error = ""
        # The evidence-role head counts as supported only if the checkpoint
        # actually carries weights for it.
        self._injection_planner_evidence_role_supported = any(
            str(key).startswith("evidence_role_head.") for key in state_dict
        )
        thresholds = {"selection_threshold": 0.5, "row_threshold": 0.5, "logic_threshold": 0.5}
        summary_path = model_path.parent / "train_summary.json"
        if summary_path.exists():
            try:
                summary = json.loads(summary_path.read_text(encoding="utf-8"))
                calibration = dict(summary.get("calibration", {}) or {})
                logic_calibration = dict(summary.get("logic_calibration", {}) or {})
                if calibration.get("selection_threshold") is not None:
                    thresholds["selection_threshold"] = float(calibration.get("selection_threshold"))
                if calibration.get("row_threshold") is not None:
                    thresholds["row_threshold"] = float(calibration.get("row_threshold"))
                if logic_calibration.get("logic_threshold") is not None:
                    thresholds["logic_threshold"] = float(logic_calibration.get("logic_threshold"))
            except Exception:
                # Best effort: an unreadable or malformed summary leaves the
                # thresholds gathered so far (defaults included) in place.
                pass
        # Explicit overrides win; negative values mean "not set".
        if self.injection_planner_selection_threshold_override >= 0.0:
            thresholds["selection_threshold"] = float(self.injection_planner_selection_threshold_override)
        if self.injection_planner_row_threshold_override >= 0.0:
            thresholds["row_threshold"] = float(self.injection_planner_row_threshold_override)
        if self.injection_planner_logic_threshold_override >= 0.0:
            thresholds["logic_threshold"] = float(self.injection_planner_logic_threshold_override)
        self._injection_planner_thresholds = thresholds
    except Exception as exc:
        self._loaded_injection_planner = None
        self._injection_planner_config = None
        self._injection_planner_payload = {}
        self._injection_planner_resolved_path = str(model_path)
        self._injection_planner_error = str(exc)
        self._injection_planner_evidence_role_supported = False
    return self._loaded_injection_planner
def _injection_planner_candidate_from_hit(
    self,
    query: str,
    hit: MemoryHit,
    *,
    index: int,
    current_turn_index: int,
) -> Dict[str, Any]:
    """Convert a retrieval hit into an injection-planner candidate row.

    Derives the candidate's logic roles, layer, temporal state and a set of
    scores clamped to ``[0, 1]`` from the hit and its metadata.

    Args:
        query: The user query; used for token overlap with the hit.
        hit: The retrieval hit to convert.
        index: Zero-based rank of the hit; drives ``rank_score`` and the
            fallback candidate id.
        current_turn_index: Turn index treated as "now" for recency/age.

    Returns:
        A dict describing the candidate for the planner.
    """

    def _clamp_unit(number: float) -> float:
        return max(0.0, min(1.0, number))

    def _first_present(keys: Sequence[str], fallback: Any) -> Any:
        # Mirrors nested ``metadata.get(k1, metadata.get(k2, ...))``: the
        # first key that EXISTS wins, even if its value is falsy.
        for key in keys:
            if key in metadata:
                return metadata[key]
        return fallback

    metadata = dict(hit.metadata or {})
    category = _normalize(hit.category)
    source_kind = _normalize(hit.source_kind)
    known_roles = set(injection_planner_runtime.LOGIC_ROLES)
    stale_states = {"stale", "superseded", "false"}

    # --- logic roles -------------------------------------------------------
    declared_roles = metadata.get("logic_roles", [])
    if not isinstance(declared_roles, list):
        declared_roles = [declared_roles]
    logic_roles = [role for role in (_normalize(item) for item in declared_roles) if role in known_roles]
    if category in known_roles:
        logic_roles.append(category)
    writeback_role = _normalize(metadata.get("writeback_class", ""))
    if writeback_role in known_roles:
        logic_roles.append(writeback_role)
    normalized_state = _normalize(hit.state)
    if not logic_roles:
        logic_roles = ["negative"] if normalized_state in stale_states else ["evidence"]

    # --- layer (first matching rule wins) ----------------------------------
    evidence_snippet_role = _normalize(metadata.get("evidence_snippet_role", ""))
    if bool(metadata.get("profile_layer")) or category == "profile" or source_kind.endswith("profile"):
        layer = "profile"
    elif category == "time" or float(metadata.get("temporal_score", 0.0) or 0.0) > 0.0:
        layer = "temporal"
    elif "resource" in logic_roles or _clean_text(metadata.get("resource_key", "")):
        layer = "resource"
    elif (
        source_kind in {"path_tunnel", "path_support", "public_dialog_path"}
        or evidence_snippet_role in {"selected_path_support", "path_tunnel_support"}
        or bool(metadata.get("path_tunnel_node"))
    ):
        layer = "path_tunnel"
    elif bool(metadata.get("topic_bucket_rerank")) or bool(metadata.get("topic_bucket_dialogue_tunnel_allowed")):
        layer = "topic_tunnel"
    else:
        layer = "event"

    # --- temporal state ----------------------------------------------------
    if normalized_state in stale_states:
        temporal_state = "superseded"
    elif bool(metadata.get("topic_bucket_current_subject_preserved")) or bool(metadata.get("current_subject_protected")):
        temporal_state = "current"
    elif current_turn_index and int(hit.turn_index or 0) and current_turn_index - int(hit.turn_index) <= 1:
        # Written at most one turn ago.
        temporal_state = "current"
    elif normalized_state == "active":
        temporal_state = "stable"
    else:
        temporal_state = "historical"

    # --- lexical overlap (Jaccard over tokens) -----------------------------
    query_tokens = set(_tokenize(query))
    hit_text = " ".join([hit.value, " ".join(hit.anchors), _clean_text(metadata.get("topic_label", ""))])
    hit_tokens = set(_tokenize(hit_text))
    overlap = len(query_tokens & hit_tokens) / max(1, len(query_tokens | hit_tokens))

    candidate_id = _clean_text(hit.memory_id) or f"hit:{index}"
    topic_label = _clean_text(metadata.get("topic_label", "")) or _clean_text(metadata.get("topic_bucket_query_label", ""))
    profile_key = _clean_text(metadata.get("profile_type", "")) or _clean_text(metadata.get("semantic_slot", ""))
    if not profile_key and layer in {"profile", "temporal", "topic_tunnel"}:
        profile_key = topic_label

    semantic_raw = _first_present(
        (
            "semantic_similarity",
            "answer_window_semantic_similarity",
            "embedder_similarity",
            "dense_similarity",
            "bge_m3_similarity",
        ),
        hit.score,
    )
    semantic_similarity = _clamp_unit(float(semantic_raw or 0.0))
    graph_raw = _first_present(("event_score", "recall_score", "hybrid_score"), hit.score)
    tunnel_raw = max(
        float(metadata.get("path_tunnel_support_score", 0.0) or 0.0),
        float(metadata.get("event_tunnel_support_score", 0.0) or 0.0),
        float(metadata.get("path_chain_extension_delta_score", 0.0) or 0.0),
    )
    if layer == "resource":
        resource_fallback = _clean_text(hit.slot_key)
    else:
        resource_fallback = ""

    return {
        "id": candidate_id,
        "text": hit.value,
        "summary": _clean_text(metadata.get("source_turn_text", "")) or _clean_text(metadata.get("event_summary", "")),
        "topic": topic_label,
        "profile_key": profile_key,
        "resource_key": _clean_text(metadata.get("resource_key", "")) or resource_fallback,
        "layer": layer,
        "temporal_state": temporal_state,
        "logic_roles": _dedupe(logic_roles),
        "query_overlap": round(float(overlap), 6),
        "retrieval_score": _clamp_unit(float(hit.score or 0.0)),
        "graph_score": _clamp_unit(float(graph_raw or 0.0)),
        "tunnel_score": _clamp_unit(tunnel_raw),
        "topic_similarity": _clamp_unit(float(metadata.get("topic_bucket_overlap", 0.0) or 0.0)),
        "semantic_similarity": semantic_similarity,
        "confidence": _clamp_unit(float(metadata.get("confidence", hit.score or 0.0) or 0.0)),
        "rank_score": round(1.0 / float(index + 1), 6),
        "age_turns": max(0, int(current_turn_index) - int(hit.turn_index or 0)) if current_turn_index else 0,
        "branch_depth": max(0, len(list(metadata.get("selected_path_ids", []) or []))),
        "contradicts_current": normalized_state in stale_states or bool(metadata.get("contradicts_current")),
        "is_current": temporal_state == "current",
    }
def _apply_injection_planner_to_hits(
    self,
    query: str,
    hits: Sequence[MemoryHit],
    *,
    top_k: int,
) -> Dict[str, Any]:
    """Score ``hits`` with the injection planner and annotate or re-plan them.

    Every hit is converted to a planner candidate, the loaded planner model
    is run once over the whole candidate row, and each hit is rebuilt as a
    new ``MemoryHit`` whose metadata carries the ``injection_planner_*``
    predictions. Depending on ``self.injection_planner_mode`` the annotated
    hits are returned in their original order, re-sorted with selected hits
    first (guided modes), or reduced to the selected hits (force modes).

    The planner is best-effort: when it is disabled, there are no hits, the
    model/config or torch is unavailable, or inference fails, the input hits
    are returned untouched and ``injection_planner_reason`` records why.

    Args:
        query: Query text passed to the candidate builder and the planner row.
        hits: Retrieved hits to annotate.
        top_k: Requested result size (see the note on the final ``return``).

    Returns:
        ``{"hits": [...], "metadata": {...}}`` where ``metadata`` holds the
        row-level ``injection_planner_*`` summary.
    """
    normalized_mode = _normalize(self.injection_planner_mode)
    base_metadata: Dict[str, Any] = {
        "injection_planner_enabled": False,
        "injection_planner_mode": normalized_mode or "observe",
    }

    def _passthrough(reason: str, *, with_model_path: bool = False) -> Dict[str, Any]:
        # Shared fallback: hand the hits back unchanged and record the reason.
        base_metadata["injection_planner_reason"] = reason
        if with_model_path:
            base_metadata["injection_planner_model_path"] = self._injection_planner_resolved_path
        return {"hits": list(hits), "metadata": base_metadata}

    def _is_selected(hit: MemoryHit) -> bool:
        return bool((hit.metadata or {}).get("injection_planner_selected"))

    if normalized_mode in _INJECTION_PLANNER_DISABLED_MODES:
        return _passthrough("disabled")
    if not hits:
        return _passthrough("no_hits")
    model = self._load_injection_planner()
    if model is None or self._injection_planner_config is None:
        return _passthrough(self._injection_planner_error or "load_failed", with_model_path=True)
    torch_module = getattr(injection_planner_runtime, "torch", None)
    if torch_module is None:
        return _passthrough("torch_unavailable")

    # The "current" turn is the latest turn known to the graph or to any hit.
    current_turn_index = max(
        [int(getattr(self.graph, "turn_index", 0) or 0), *[int(hit.turn_index or 0) for hit in hits]]
    )
    candidates = [
        self._injection_planner_candidate_from_hit(query, hit, index=index, current_turn_index=current_turn_index)
        for index, hit in enumerate(hits)
    ]
    row = {"id": "runtime_injection_plan", "query": query, "candidates": candidates, "gold": {}}
    try:
        dataset = injection_planner_runtime.InjectionPlannerDataset([row], self._injection_planner_config)
        batch = injection_planner_runtime.collate_injection_batch([dataset[0]])
        device = next(model.parameters()).device
        model_batch = {
            key: value.to(device) if hasattr(value, "to") else value
            for key, value in dict(batch).items()
        }
        with torch_module.no_grad():
            outputs = model(model_batch["features"], model_batch["valid_mask"])
            selection_scores = torch_module.sigmoid(outputs["selection_logits"])[0].detach().cpu().tolist()
            should_inject_score = float(
                torch_module.sigmoid(outputs["should_inject_logits"])[0].detach().cpu().item()
            )
            mode_probs = torch_module.softmax(outputs["injection_mode_logits"], dim=-1)[0].detach().cpu()
            mode_index = int(torch_module.argmax(mode_probs).item())
            temporal_indices = torch_module.argmax(outputs["temporal_logits"], dim=-1)[0].detach().cpu().tolist()
            logic_scores = torch_module.sigmoid(outputs["logic_logits"])[0].detach().cpu().tolist()
            if bool(self._injection_planner_evidence_role_supported) and "evidence_role_logits" in outputs:
                evidence_role_indices = (
                    torch_module.argmax(outputs["evidence_role_logits"], dim=-1)[0].detach().cpu().tolist()
                )
            else:
                # Without an evidence-role head every candidate is treated as a direct answer.
                direct_answer_index = injection_planner_runtime.EVIDENCE_ROLES.index("direct_answer")
                evidence_role_indices = [direct_answer_index for _ in candidates]
        # The per-candidate outputs are indexed by candidate position below; if the
        # model returned fewer rows than candidates, fall back here instead of
        # raising IndexError outside the guarded section.
        per_candidate_outputs = (selection_scores, temporal_indices, logic_scores, evidence_role_indices)
        if any(len(values) < len(candidates) for values in per_candidate_outputs):
            raise ValueError("planner outputs do not cover every candidate")
    except Exception as exc:
        return _passthrough(f"inference_failed: {exc}", with_model_path=True)

    thresholds = {
        "selection_threshold": float(self._injection_planner_thresholds.get("selection_threshold", 0.5)),
        "row_threshold": float(self._injection_planner_thresholds.get("row_threshold", 0.5)),
        "logic_threshold": float(self._injection_planner_thresholds.get("logic_threshold", 0.5)),
    }
    injection_mode = injection_planner_runtime.INJECTION_MODES[mode_index]
    # Row-level gate: nothing is selected unless the whole row should be injected.
    row_allows_injection = should_inject_score >= thresholds["row_threshold"] and injection_mode != "none"

    predictions: Dict[str, Dict[str, Any]] = {}
    for index, candidate in enumerate(candidates):
        logic_roles = [
            role
            for role, score in zip(injection_planner_runtime.LOGIC_ROLES, logic_scores[index])
            if float(score) >= thresholds["logic_threshold"]
        ]
        evidence_role = injection_planner_runtime.EVIDENCE_ROLES[int(evidence_role_indices[index])]
        role_allows_selection = evidence_role not in {"noise", "negative_evidence"}
        selection_score = float(selection_scores[index])
        predictions[candidate["id"]] = {
            "selection_score": selection_score,
            "selected": bool(
                row_allows_injection
                and role_allows_selection
                and selection_score >= thresholds["selection_threshold"]
            ),
            "temporal_state": injection_planner_runtime.TEMPORAL_STATES[int(temporal_indices[index])],
            "evidence_role": evidence_role,
            "logic_roles": logic_roles or ["evidence"],
            "candidate_layer": candidate.get("layer", ""),
            "role_allows_selection": bool(role_allows_selection),
        }

    annotated_hits: List[MemoryHit] = []
    for index, hit in enumerate(hits):
        candidate_id = _clean_text(hit.memory_id) or f"hit:{index}"
        prediction = predictions.get(candidate_id, {})
        planner_score = float(prediction.get("selection_score", 0.0))
        is_selected = bool(prediction.get("selected", False))
        metadata = dict(hit.metadata or {})
        metadata.update(
            {
                "injection_planner_enabled": True,
                "injection_planner_mode": normalized_mode or "observe",
                "injection_planner_model_path": self._injection_planner_resolved_path,
                "injection_planner_candidate_id": candidate_id,
                "injection_planner_score": round(planner_score, 6),
                "injection_planner_selected": is_selected,
                "injection_planner_temporal_state": _clean_text(prediction.get("temporal_state", "")),
                "injection_planner_evidence_role": _clean_text(prediction.get("evidence_role", "")),
                "injection_planner_role_allows_selection": bool(prediction.get("role_allows_selection", False)),
                "injection_planner_logic_roles": list(prediction.get("logic_roles", []) or []),
                "injection_planner_candidate_layer": _clean_text(prediction.get("candidate_layer", "")),
                "injection_planner_injection_mode": injection_mode,
                "injection_planner_should_inject_score": round(float(should_inject_score), 6),
                "injection_planner_thresholds": dict(thresholds),
                "injection_planner_evidence_role_supported": bool(self._injection_planner_evidence_role_supported),
            }
        )
        # A missing score/turn index is treated as 0, matching the guarded reads
        # of the same attributes elsewhere in this method.
        hit_score = float(hit.score or 0.0)
        # Selected hits can only gain score from the planner, never lose it.
        next_score = max(hit_score, planner_score) if is_selected else hit_score
        annotated_hits.append(
            MemoryHit(
                memory_id=hit.memory_id,
                category=hit.category,
                value=hit.value,
                relation=hit.relation,
                anchors=list(hit.anchors),
                score=round(next_score, 6),
                source_kind=hit.source_kind,
                slot_key=hit.slot_key,
                state=hit.state,
                turn_index=int(hit.turn_index or 0),
                metadata=metadata,
            )
        )

    selected_count = sum(1 for item in annotated_hits if _is_selected(item))
    if normalized_mode in _INJECTION_PLANNER_FORCE_MODES and selected_count:
        # Filter on each hit's own flag rather than on memory_id: hits without a
        # memory_id all share the same empty id, so an id-based filter would keep
        # unselected ones whenever any id-less hit was selected.
        planned_hits = [hit for hit in annotated_hits if _is_selected(hit)]
    elif normalized_mode in _INJECTION_PLANNER_GUIDED_MODES and selected_count:
        # Selected hits first, then by planner score, then by retrieval score.
        planned_hits = sorted(
            annotated_hits,
            key=lambda hit: (
                not _is_selected(hit),
                -float((hit.metadata or {}).get("injection_planner_score", 0.0) or 0.0),
                -float(hit.score or 0.0),
            ),
        )
    else:
        planned_hits = annotated_hits

    base_metadata.update(
        {
            "injection_planner_enabled": True,
            "injection_planner_reason": "ok",
            "injection_planner_model_path": self._injection_planner_resolved_path,
            "injection_planner_candidate_count": len(candidates),
            "injection_planner_selected_count": int(selected_count),
            "injection_planner_should_inject_score": round(float(should_inject_score), 6),
            "injection_planner_injection_mode": injection_mode,
            "injection_planner_thresholds": dict(thresholds),
            "injection_planner_evidence_role_supported": bool(self._injection_planner_evidence_role_supported),
            "injection_planner_guided": bool(
                normalized_mode in _INJECTION_PLANNER_GUIDED_MODES | _INJECTION_PLANNER_FORCE_MODES
            ),
            "injection_planner_prediction_ids": [
                candidate_id
                for candidate_id, prediction in predictions.items()
                if bool(prediction.get("selected", False))
            ],
        }
    )
    # NOTE(review): the slice bound is max(len(planned_hits), top_k), so it never
    # truncates and ``top_k`` has no effect on the result size. Kept as-is because
    # callers may rely on receiving every planned hit; confirm whether min() was intended.
    return {"hits": planned_hits[: max(len(planned_hits), top_k)], "metadata": base_metadata}
def _node_scorer(self) -> LoadedNodeMemoryScorer | None:
    """Return the lazily loaded node-memory scorer, or ``None`` when unavailable.

    A scorer is only provided in ``"hybrid_node_scored"`` retrieval mode. The
    first successful load is cached on ``self._loaded_node_scorer``. When no
    node model path is configured, or loading raises, the reason is stored in
    ``self._node_scorer_error`` and ``None`` is returned.
    """
    if self.retrieval_mode != "hybrid_node_scored":
        return None

    cached_scorer = self._loaded_node_scorer
    if cached_scorer is not None:
        return cached_scorer

    if not self.node_model_path:
        self._node_scorer_error = "node_model_path_missing"
        return None

    try:
        # The path model is optional; the device falls back to the loader's default.
        loaded_scorer = LoadedNodeMemoryScorer(
            node_model_path=Path(self.node_model_path),
            path_model_path=Path(self.path_model_path) if self.path_model_path else None,
            device=self.node_model_device or None,
        )
    except Exception as exc:
        self._node_scorer_error = str(exc)
        loaded_scorer = None

    self._loaded_node_scorer = loaded_scorer
    return loaded_scorer
def _apply_writeback_promotions(self, stored_ids: Sequence[str]) -> List[Dict[str, Any]]:
    """Promote freshly stored assistant-derived records that have enough support.

    Only records whose metadata marks them as ``memory_role == "assistant"``
    and ``authority == "derived"``, and that carry both a
    ``canonical_slot_key`` and a ``writeback_class``, are considered. A
    record is promoted when either rule holds:

    * fast: class is ``"fact"`` or ``"state_change"``, confidence >= 0.97 and
      the record itself lists at least 3 support refs;
    * standard: assistant records with the same slot, class and value provide
      at least 2 distinct origin answer ids (counting only those with
      confidence >= 0.9) and at least 2 support refs in total.

    Records that fail both rules are marked ``promotion_state = "candidate"``.
    Otherwise a ``promoted.<slot>`` record is added to the graph. If the
    slot's current user/source head holds a different value, the promotion
    state is ``"blocked_conflict"`` instead of ``"promoted"``.

    Args:
        stored_ids: Memory ids of the records to evaluate.

    Returns:
        One event dict per promoted (or conflict-blocked) record.
    """
    promotion_events: List[Dict[str, Any]] = []
    for memory_id in stored_ids:
        record = self.graph.records_by_id.get(memory_id)
        if record is None or not isinstance(record.metadata, dict):
            continue
        if (
            _normalize(record.metadata.get("memory_role", "")) != "assistant"
            or _normalize(record.metadata.get("authority", "")) != "derived"
        ):
            continue
        canonical_slot_key = _clean_text(record.metadata.get("canonical_slot_key", ""))
        writeback_class = _clean_text(record.metadata.get("writeback_class", ""))
        if not canonical_slot_key or not writeback_class:
            continue

        support_refs = self._support_ref_union(record.metadata)
        confidence = float(record.confidence or 0.0)

        # Normalise the comparison keys once instead of once per graph record.
        normalized_slot = _normalize(canonical_slot_key)
        normalized_class = _normalize(writeback_class)
        normalized_value = _normalize(record.value)
        same_records = [
            item
            for item in self.graph.records_by_id.values()
            if isinstance(item.metadata, dict)
            and _normalize(item.metadata.get("memory_role", "")) == "assistant"
            and _normalize(item.metadata.get("canonical_slot_key", "")) == normalized_slot
            and _normalize(item.metadata.get("writeback_class", "")) == normalized_class
            and _normalize(item.value) == normalized_value
        ]
        qualifying = [item for item in same_records if float(item.confidence or 0.0) >= 0.9]
        distinct_answers = {
            _clean_text(answer_id)
            for item in qualifying
            for answer_id in [
                *list(item.metadata.get("origin_answer_ids", []) or []),
                _clean_text(item.metadata.get("origin_answer_id", "")),
            ]
            if _clean_text(answer_id)
        }
        aggregated_support = set()
        for item in same_records:
            aggregated_support.update(self._support_ref_union(item.metadata))

        fast_promotion = (
            writeback_class in {"fact", "state_change"} and confidence >= 0.97 and len(support_refs) >= 3
        )
        standard_promotion = len(distinct_answers) >= 2 and len(aggregated_support) >= 2
        if not (fast_promotion or standard_promotion):
            record.metadata["promotion_state"] = "candidate"
            continue

        source_head = self._source_head_for_canonical(canonical_slot_key)
        blocked_conflict = source_head is not None and _normalize(source_head.value) != _normalize(record.value)
        promotion_state = "blocked_conflict" if blocked_conflict else "promoted"
        promoted_slot = f"promoted.{canonical_slot_key}"
        # Aggregated refs are split by prefix: "fact:" and "path:" refs go to their
        # own lists, everything else is treated as a memory id.
        promoted_metadata = {
            **dict(record.metadata or {}),
            "memory_role": "assistant",
            "authority": "promoted",
            "canonical_slot_key": canonical_slot_key,
            "writeback_class": writeback_class,
            "promotion_state": promotion_state,
            "support_memory_ids": sorted(
                {
                    *list(record.metadata.get("support_memory_ids", []) or []),
                    *[ref for ref in aggregated_support if not ref.startswith(("fact:", "path:"))],
                }
            ),
            "support_fact_refs": sorted(
                {
                    *list(record.metadata.get("support_fact_refs", []) or []),
                    *[ref for ref in aggregated_support if ref.startswith("fact:")],
                }
            ),
            "support_path_refs": sorted(
                {
                    *list(record.metadata.get("support_path_refs", []) or []),
                    *[ref for ref in aggregated_support if ref.startswith("path:")],
                }
            ),
        }
        promoted_record = SessionMemoryRecordV2(
            memory_id=f"{promoted_slot}:{record.turn_index}:promoted",
            category=record.category,
            slot_key=promoted_slot,
            value=record.value,
            relation=record.relation,
            anchor_concepts=list(record.anchor_concepts),
            evidence_anchors=list(record.evidence_anchors),
            # Promoted records get a floor on salience and confidence; missing
            # values are treated as 0 like the guarded confidence read above.
            salience=max(float(record.salience or 0.0), 0.78),
            confidence=max(confidence, 0.9),
            source_kind=f"promoted_{record.source_kind}",
            turn_index=int(record.turn_index or 0),
            state="active",
            metadata=promoted_metadata,
        )
        promoted_ids = self.graph.add_records([promoted_record])
        record.metadata["promotion_state"] = promotion_state
        promotion_events.append(
            {
                "source_memory_id": record.memory_id,
                "promoted_record_ids": list(promoted_ids),
                "canonical_slot_key": canonical_slot_key,
                "writeback_class": writeback_class,
                "promotion_state": record.metadata["promotion_state"],
                "blocked_conflict": bool(blocked_conflict),
            }
        )
    return promotion_events
def _source_head_for_canonical(self, canonical_slot_key: str) -> SessionMemoryRecordV2 | None:
    """Return the slot's head record when it is user-authored source memory.

    Looks up the head id registered for ``canonical_slot_key`` in
    ``self.graph.slot_heads`` and resolves it through
    ``self.graph.records_by_id``. ``None`` is returned when there is no head,
    the record is missing, its metadata is not a dict, or the metadata does
    not say ``memory_role == "user"`` and ``authority == "source"``.
    """
    graph = self.graph
    head_id = graph.slot_heads.get(canonical_slot_key)
    head_record = graph.records_by_id.get(head_id) if head_id else None
    if head_record is None:
        return None

    head_metadata = head_record.metadata
    if not isinstance(head_metadata, dict):
        return None

    is_user_source = (
        _normalize(head_metadata.get("memory_role", "")) == "user"
        and _normalize(head_metadata.get("authority", "")) == "source"
    )
    return head_record if is_user_source else None
def _support_ref_union(self, metadata: Dict[str, Any]) -> set[str]:
    """Collect every non-empty support reference listed in ``metadata``.

    Reads the ``support_memory_ids``, ``support_fact_refs`` and
    ``support_path_refs`` entries (missing or falsy entries count as empty),
    cleans each item with ``_clean_text`` and returns the distinct non-empty
    results as one set.
    """
    collected: set[str] = set()
    for field_name in ("support_memory_ids", "support_fact_refs", "support_path_refs"):
        for raw_ref in list(metadata.get(field_name, []) or []):
            cleaned_ref = _clean_text(raw_ref)
            if cleaned_ref:
                collected.add(cleaned_ref)
    return collected
def _hybrid_node_scored_hits(
self,
query: str,
hits: Sequence[MemoryHit],
*,
top_k: int,
public_hits: Sequence[MemoryHit] | None = None,
) -> Dict[str, Any]:
scorer = self._node_scorer()
if scorer is None:
return {
"hits": list(hits),
"metadata": {
"retrieval_mode": "heuristic",
"hybrid_enabled": False,
"hybrid_error": self._node_scorer_error,
},
}
source_hits = _learnable_graph_hits(self.graph)
if not source_hits:
return {
"hits": list(hits),
"metadata": {
"retrieval_mode": "heuristic",
"hybrid_enabled": False,
"hybrid_error": "no_learnable_hits",
},
}
has_public_hits = any(
_is_public_dialog_hit(hit)
and not bool((hit.metadata or {}).get("profile_layer"))
for hit in source_hits
)
has_generic_hits = any(not _is_public_dialog_hit(hit) for hit in source_hits)
hybrid_source = "mixed_full_graph"
if has_public_hits and not has_generic_hits:
hybrid_source = "public_full_graph"
elif has_generic_hits and not has_public_hits:
hybrid_source = "generic_full_graph"
runtime_graph = _build_runtime_graph_from_hits(query, source_hits)
grouped_hits = dict(runtime_graph.get("grouped_hits", {}) or {})
profile_first_payload = _profile_first_hybrid_rescue(
self.graph,
query,
grouped_hits=grouped_hits,
top_k=top_k,
)
profile_first_hits = list(profile_first_payload.get("hits", []) or [])
profile_first_event_ids = list(profile_first_payload.get("event_ids", []) or [])
profile_first_memory_ids = list(profile_first_payload.get("memory_ids", []) or [])
candidate_event_ids = sorted(grouped_hits.keys())
if not candidate_event_ids:
return {
"hits": list(hits),
"metadata": {
"retrieval_mode": "heuristic",
"hybrid_enabled": False,
"hybrid_error": "no_event_candidates",
},
}
question_analysis = extract_question_features(query)
hybrid_candidate_limit = min(
len(candidate_event_ids),
max(
_HYBRID_SELECTED_EVENT_FLOOR,
int(self.candidate_event_k) * 4,
int(self.support_path_k) * 4,
int(top_k) * 4,
),
)
embedder_pre_recall_mode = _normalize(getattr(self, "embedder_pre_recall_mode", ""))
embedder_pre_recall_enabled = embedder_pre_recall_mode not in _EMBEDDER_INDEX_DISABLED_MODES
embedder_pre_recall_index_mode = self.embedder_index_recall_mode
if embedder_pre_recall_mode not in {"1", "true", "yes", "on", "auto", "seed", "candidate", "candidates"}:
embedder_pre_recall_index_mode = embedder_pre_recall_mode
pre_embedder_index_payload: Dict[str, Any] = {
"event_ids": [],
"metadata": {
"embedder_pre_recall_enabled": False,
"embedder_pre_recall_mode": embedder_pre_recall_mode or "off",
"embedder_pre_recall_index_mode": embedder_pre_recall_index_mode or "off",
"embedder_pre_recall_event_ids": [],
},
}
pre_embedder_event_ids: List[str] = []
pre_candidate_event_ids: List[str] = []
if embedder_pre_recall_enabled:
pre_embedder_index_payload = _embedder_index_recall_event_ids(
query,
grouped_hits=grouped_hits,
mode=embedder_pre_recall_index_mode,
limit=self.embedder_pre_recall_k or self.embedder_index_recall_k or hybrid_candidate_limit,
max_terms=self.write_embedder_index_max_terms,
)
pre_embedder_event_ids = list(pre_embedder_index_payload.get("event_ids", []) or [])
pre_candidate_event_ids = _bounded_event_id_union(
pre_embedder_event_ids,
max_items=hybrid_candidate_limit,
)
pre_score_kwargs = {
"graph": runtime_graph,
"question": query,
"question_features": question_analysis,
"rerank_top_k": self.candidate_event_k,
"event_rerank_mode": self.event_rerank_mode,
"matrix_event_top_k": self.matrix_event_top_k,
"support_path_k": self.support_path_k,
"top_k": top_k,
}
if pre_candidate_event_ids:
pre_score_kwargs["candidate_event_ids"] = pre_candidate_event_ids
scored = _call_with_supported_kwargs(
scorer.score_runtime,
**pre_score_kwargs,
)
memory_router_decision = _memory_router_decision(
scored,
mode=self.memory_router_mode,
threshold=self.memory_router_threshold,
margin=self.memory_router_margin,
)
profile_first_router_suppressed = False
if not _memory_router_allows(memory_router_decision, "profile", "resource"):
profile_first_hits = []
profile_first_event_ids = []
profile_first_memory_ids = []
profile_first_router_suppressed = True
initial_recall_event_scores = dict(scored.get("recall_event_scores", {}) or {})
model_recall_event_ids = _bounded_event_id_union(
list(scored.get("recall_event_ids", []) or []),
[
event_id
for event_id, _ in sorted(
initial_recall_event_scores.items(),
key=lambda item: (-float(item[1]), item[0]),
)
],
list(scored.get("rerank_candidate_event_ids", []) or []),
max_items=max(1, len(candidate_event_ids)),
)
learned_recall_event_ids = list(model_recall_event_ids)
symbolic_recall_event_ids = _symbolic_recall_event_ids(
query,
runtime_graph,
grouped_hits=grouped_hits,
limit=hybrid_candidate_limit,
)
symbolic_recall_event_ids = _bounded_event_id_union(
profile_first_event_ids,
symbolic_recall_event_ids,
max_items=hybrid_candidate_limit,
)
if pre_embedder_event_ids:
embedder_index_payload = pre_embedder_index_payload
else:
embedder_index_payload = _embedder_index_recall_event_ids(
query,
grouped_hits=grouped_hits,
mode=self.embedder_index_recall_mode,
limit=self.embedder_index_recall_k or hybrid_candidate_limit,
max_terms=self.write_embedder_index_max_terms,
)
embedder_index_event_ids = list(embedder_index_payload.get("event_ids", []) or [])
embedder_index_metadata = dict(embedder_index_payload.get("metadata", {}) or {})
embedder_index_metadata.update(
{
"embedder_pre_recall_enabled": bool(pre_candidate_event_ids),
"embedder_pre_recall_mode": embedder_pre_recall_mode or "off",
"embedder_pre_recall_index_mode": embedder_pre_recall_index_mode or "off",
"embedder_pre_recall_event_ids": list(pre_embedder_event_ids),
"embedder_pre_recall_candidate_event_ids": list(pre_candidate_event_ids),
"embedder_pre_recall_candidate_count": int(len(pre_candidate_event_ids)),
}
)
hybrid_candidate_event_ids = _bounded_event_id_union(
profile_first_event_ids,
embedder_index_event_ids,
learned_recall_event_ids,
symbolic_recall_event_ids,
max_items=hybrid_candidate_limit,
)
learned_candidate_event_ids = _bounded_event_id_union(
learned_recall_event_ids,
max_items=hybrid_candidate_limit,
)
learned_recall_event_id_set = set(learned_candidate_event_ids)
hybrid_candidate_union_added_event_ids = [
event_id
for event_id in hybrid_candidate_event_ids
if event_id not in learned_recall_event_id_set
]
hybrid_candidate_union_priority_changed = list(hybrid_candidate_event_ids) != list(learned_candidate_event_ids)
hybrid_candidate_union_rescored = False
if hybrid_candidate_union_added_event_ids or (embedder_index_event_ids and hybrid_candidate_union_priority_changed):
scored = _call_with_supported_kwargs(
scorer.score_runtime,
graph=runtime_graph,
question=query,
question_features=question_analysis,
candidate_event_ids=hybrid_candidate_event_ids,
rerank_top_k=self.candidate_event_k,
event_rerank_mode=self.event_rerank_mode,
matrix_event_top_k=self.matrix_event_top_k,
support_path_k=self.support_path_k,
top_k=top_k,
)
hybrid_candidate_union_rescored = True
memory_router_decision = _memory_router_decision(
scored,
mode=self.memory_router_mode,
threshold=self.memory_router_threshold,
margin=self.memory_router_margin,
)
recall_event_scores = dict(scored.get("recall_event_scores", {}) or {})
rerank_candidate_event_ids = list(scored.get("rerank_candidate_event_ids", []) or [])
base_event_scores = dict(scored.get("base_event_scores", {}) or {})
rerank_event_scores = dict(scored.get("rerank_event_scores", {}) or {})
calibrated_event_scores = dict(scored.get("calibrated_event_scores", {}) or {})
matrix_event_scores = dict(scored.get("matrix_event_scores", {}) or {})
event_fusion_delta_scores = dict(scored.get("event_fusion_delta_scores", {}) or {})
event_tunnel_support_scores = dict(scored.get("event_tunnel_support_scores", {}) or {})
event_tunnel_delta_scores = dict(scored.get("event_tunnel_delta_scores", {}) or {})
tri_maze_event_reverse_scores = dict(scored.get("tri_maze_event_reverse_scores", {}) or {})
tri_maze_event_boundary_scores = dict(scored.get("tri_maze_event_boundary_scores", {}) or {})
tri_maze_event_reverse_relations = dict(scored.get("tri_maze_event_reverse_relations", {}) or {})
matrix_rerank_event_ids = list(scored.get("matrix_rerank_event_ids", []) or [])
matrix_enabled = bool(scored.get("matrix_enabled", False))
rerank_path_scores = dict(scored.get("rerank_path_scores", {}) or {})
matrix_path_scores = dict(scored.get("matrix_path_scores", {}) or {})
tri_maze_path_reverse_scores = dict(scored.get("tri_maze_path_reverse_scores", {}) or {})
tri_maze_path_boundary_scores = dict(scored.get("tri_maze_path_boundary_scores", {}) or {})
tri_maze_path_reverse_relations = dict(scored.get("tri_maze_path_reverse_relations", {}) or {})
matrix_path_rerank_ids = list(scored.get("matrix_path_rerank_ids", []) or [])
matrix_path_enabled = bool(scored.get("matrix_path_enabled", False))
fusion_enabled = bool(scored.get("fusion_enabled", False))
event_fusion_enabled = bool(scored.get("event_fusion_enabled", fusion_enabled))
path_fusion_enabled = bool(scored.get("path_fusion_enabled", fusion_enabled))
event_calibration_enabled = bool(scored.get("event_calibration_enabled", False))
path_calibration_enabled = bool(scored.get("path_calibration_enabled", False))
event_tunnel_enabled = bool(scored.get("event_tunnel_enabled", False))
path_tunnel_enabled = bool(scored.get("path_tunnel_enabled", False))
final_event_fusion_enabled = bool(scored.get("final_event_fusion_enabled", False))
final_path_fusion_enabled = bool(scored.get("final_path_fusion_enabled", False))
decision_fusion_enabled = bool(scored.get("decision_fusion_enabled", False))
decision_score_source = _clean_text(scored.get("decision_score_source", ""))
event_scores = dict(scored.get("event_scores", {}) or {})
base_path_scores = dict(scored.get("base_path_scores", {}) or {})
calibrated_path_scores = dict(scored.get("calibrated_path_scores", {}) or {})
path_fusion_delta_scores = dict(scored.get("path_fusion_delta_scores", {}) or {})
path_tunnel_support_scores = dict(scored.get("path_tunnel_support_scores", {}) or {})
path_tunnel_delta_scores = dict(scored.get("path_tunnel_delta_scores", {}) or {})
path_model_scores = dict(scored.get("path_model_scores", {}) or {})
path_chain_extension_delta_scores = dict(scored.get("path_chain_extension_delta_scores", {}) or {})
path_chain_extended_scores = dict(scored.get("path_chain_extended_scores", {}) or {})
path_chain_extension_enabled = bool(scored.get("path_chain_extension_enabled", False))
answer_type_scores = dict(scored.get("answer_type_scores", {}) or {})
selected_event_ids_from_model = list(scored.get("selected_event_ids", []) or [])
selected_path_ids_from_model = list(scored.get("selected_path_ids", []) or [])
focused_answer_type_from_model = _clean_text(scored.get("focused_answer_type", ""))
path_scores = dict(scored.get("path_scores", {}) or {})
temporal_scores = dict(scored.get("temporal_scores", {}) or {})
runtime_paths = {_clean_text(path.get("id", "")): dict(path) for path in list(runtime_graph.get("paths", []) or [])}
answer_plan_scores_raw = dict(scored.get("answer_plan_scores", {}) or {})
def _answer_plan_score_map(role: str) -> Dict[str, float]:
raw_scores = answer_plan_scores_raw.get(role, {})
if not isinstance(raw_scores, Mapping):
return {}
normalized: Dict[str, float] = {}
for raw_event_id, raw_score in raw_scores.items():
event_id = _clean_text(raw_event_id)
if not event_id:
continue
try:
normalized[event_id] = float(raw_score or 0.0)
except (TypeError, ValueError):
normalized[event_id] = 0.0
return normalized
answer_plan_selected_scores = _answer_plan_score_map("selected")
answer_plan_current_scores = _answer_plan_score_map("current")
answer_plan_historical_scores = _answer_plan_score_map("historical")
answer_plan_suppressed_scores = _answer_plan_score_map("suppressed")
answer_plan_scores = {
"selected": dict(answer_plan_selected_scores),
"current": dict(answer_plan_current_scores),
"historical": dict(answer_plan_historical_scores),
"suppressed": dict(answer_plan_suppressed_scores),
}
try:
answer_plan_event_selection_threshold = float(
os.getenv("TMCRA_ANSWER_PLAN_EVENT_SELECTION_THRESHOLD", "0.50") or 0.50
)
except (TypeError, ValueError):
answer_plan_event_selection_threshold = 0.50
try:
answer_plan_event_selection_top_k = int(os.getenv("TMCRA_ANSWER_PLAN_EVENT_SELECTION_TOP_K", "0") or 0)
except (TypeError, ValueError):
answer_plan_event_selection_top_k = 0
if answer_plan_event_selection_top_k <= 0:
answer_plan_event_selection_top_k = max(top_k, self.support_path_k * 2, _HYBRID_SELECTED_EVENT_FLOOR)
answer_plan_available_event_ids = {
_clean_text(event_id)
for event_id in [
*list(grouped_hits.keys()),
*[_clean_text(path.get("event_id", "")) for path in runtime_paths.values()],
]
if _clean_text(event_id)
}
answer_plan_event_rows: List[tuple[str, float, float, float, float, float]] = []
answer_plan_event_ids = set(answer_plan_selected_scores) | set(answer_plan_current_scores) | set(answer_plan_historical_scores)
for event_id in answer_plan_event_ids:
if answer_plan_available_event_ids and event_id not in answer_plan_available_event_ids:
continue
selected_score = float(answer_plan_selected_scores.get(event_id, 0.0) or 0.0)
current_score = float(answer_plan_current_scores.get(event_id, 0.0) or 0.0)
historical_score = float(answer_plan_historical_scores.get(event_id, 0.0) or 0.0)
suppressed_score = float(answer_plan_suppressed_scores.get(event_id, 0.0) or 0.0)
support_score = max(selected_score, current_score)
adjusted_score = support_score - max(0.0, suppressed_score - support_score) * 0.5
if support_score < answer_plan_event_selection_threshold or suppressed_score > support_score:
continue
answer_plan_event_rows.append(
(event_id, adjusted_score, support_score, selected_score, current_score, historical_score)
)
answer_plan_event_rows.sort(
key=lambda row: (-float(row[1]), -float(row[2]), -float(row[3]), -float(row[4]), row[0])
)
answer_plan_raw_ranked_event_ids = [
event_id for event_id, *_ in answer_plan_event_rows[: max(1, answer_plan_event_selection_top_k)]
]
try:
answer_plan_promotion_min_margin = float(
os.getenv("TMCRA_ANSWER_PLAN_PROMOTION_MIN_MARGIN", "0.02") or 0.02
)
except (TypeError, ValueError):
answer_plan_promotion_min_margin = 0.02
answer_plan_promotion_score_margin = 0.0
if len(answer_plan_event_rows) == 1:
answer_plan_promotion_score_margin = float(answer_plan_event_rows[0][1])
answer_plan_promotion_enabled = True
elif len(answer_plan_event_rows) > 1:
comparison_index = min(len(answer_plan_event_rows) - 1, max(1, min(answer_plan_event_selection_top_k, 5) - 1))
answer_plan_promotion_score_margin = float(answer_plan_event_rows[0][1]) - float(answer_plan_event_rows[comparison_index][1])
answer_plan_promotion_enabled = answer_plan_promotion_score_margin >= answer_plan_promotion_min_margin
else:
answer_plan_promotion_enabled = False
answer_plan_ranked_event_ids = list(answer_plan_raw_ranked_event_ids if answer_plan_promotion_enabled else [])
answer_plan_selected_event_ids = [
event_id
for event_id in answer_plan_ranked_event_ids
if float(answer_plan_selected_scores.get(event_id, 0.0) or 0.0)
>= answer_plan_event_selection_threshold
]
answer_plan_current_event_ids = [
event_id
for event_id in answer_plan_ranked_event_ids
if float(answer_plan_current_scores.get(event_id, 0.0) or 0.0)
>= answer_plan_event_selection_threshold
]
answer_plan_support_scores = {
event_id: round(float(support_score), 6)
for event_id, _, support_score, *_ in answer_plan_event_rows
}
answer_plan_adjusted_scores = {
event_id: round(float(adjusted_score), 6)
for event_id, adjusted_score, *_ in answer_plan_event_rows
}
answer_plan_rank_lookup = {
event_id: rank for rank, event_id in enumerate(answer_plan_ranked_event_ids, start=1)
}
answer_plan_ranked_event_id_set = set(answer_plan_ranked_event_ids)
def _answer_plan_hit_metadata(event_id: str) -> Dict[str, Any]:
    """Return the answer-plan bookkeeping fields for a single event.

    The event id is normalised with ``_clean_text`` and looked up in the
    answer-plan score tables captured from the enclosing scope. Missing or
    falsy selected/current/historical/suppressed scores are reported as
    ``0.0``, and every score is rounded to six decimal places.

    Keys of the returned dict:
        answer_plan_score: the larger of the selected and current scores.
        answer_plan_selected_score / answer_plan_current_score /
        answer_plan_historical_score / answer_plan_suppressed_score:
            the per-table scores for this event.
        answer_plan_adjusted_score: the adjusted score, ``0.0`` when absent.
        answer_plan_selected: whether the event is in the ranked event set.
        answer_plan_rank: 1-based position in the ranked event list, or ``0``
            when the event is not ranked.
    """
    key = _clean_text(event_id)
    # Same lookup rule for all four tables: absent or falsy entries become 0.0.
    selected, current, historical, suppressed = (
        float(table.get(key, 0.0) or 0.0)
        for table in (
            answer_plan_selected_scores,
            answer_plan_current_scores,
            answer_plan_historical_scores,
            answer_plan_suppressed_scores,
        )
    )
    best_support = max(selected, current)
    adjusted = float(answer_plan_adjusted_scores.get(key, 0.0))
    is_ranked = key in answer_plan_ranked_event_id_set
    rank = int(answer_plan_rank_lookup.get(key, 0) or 0)
    # Key order is kept stable because callers splat this dict into hit metadata.
    return {
        "answer_plan_score": round(float(best_support), 6),
        "answer_plan_selected_score": round(selected, 6),
        "answer_plan_current_score": round(current, 6),
        "answer_plan_historical_score": round(historical, 6),
        "answer_plan_suppressed_score": round(suppressed, 6),
        "answer_plan_adjusted_score": round(adjusted, 6),
        "answer_plan_selected": is_ranked,
        "answer_plan_rank": rank,
    }
embedder_fusion_mode = _normalize(self.embedder_fusion_mode)
embedder_fusion_enabled = (
embedder_fusion_mode not in _EMBEDDER_INDEX_DISABLED_MODES
and bool(embedder_index_event_ids)
and bool(self.embedder_fusion_top_k > 0)
and bool(self.embedder_fusion_weight > 0.0)
)
embedder_fusion_applied_event_scores: Dict[str, float] = {}
embedder_fusion_boosts: Dict[str, float] = {}
embedder_event_scores = dict(embedder_index_metadata.get("embedder_index_event_scores", {}) or {})
if embedder_fusion_enabled and embedder_event_scores:
ranked_embedder_events = [
(event_id, float(score))
for event_id, score in sorted(
embedder_event_scores.items(),
key=lambda item: (-float(item[1]), item[0]),
)
if _clean_text(event_id)
][: max(1, int(self.embedder_fusion_top_k))]
for rank, (event_id, embedder_score) in enumerate(ranked_embedder_events, start=1):
if embedder_score < float(self.embedder_fusion_score_floor):
continue
rank_bonus = max(0.0, 0.08 - (rank - 1) * 0.006)
boost = min(
float(self.embedder_fusion_max_boost),
(float(self.embedder_fusion_weight) * embedder_score) + rank_bonus,
)
current_score = max(
float(recall_event_scores.get(event_id, 0.0) or 0.0),
float(event_scores.get(event_id, 0.0) or 0.0),
float(base_event_scores.get(event_id, 0.0) or 0.0),
float(rerank_event_scores.get(event_id, 0.0) or 0.0),
float(calibrated_event_scores.get(event_id, 0.0) or 0.0),
)
fused_score = current_score + boost
recall_event_scores[event_id] = max(float(recall_event_scores.get(event_id, 0.0) or 0.0), fused_score)
event_scores[event_id] = max(float(event_scores.get(event_id, 0.0) or 0.0), fused_score)
base_event_scores[event_id] = max(float(base_event_scores.get(event_id, 0.0) or 0.0), fused_score)
rerank_event_scores[event_id] = max(float(rerank_event_scores.get(event_id, 0.0) or 0.0), fused_score)
calibrated_event_scores[event_id] = max(float(calibrated_event_scores.get(event_id, 0.0) or 0.0), fused_score)
embedder_fusion_applied_event_scores[event_id] = round(fused_score, 6)
embedder_fusion_boosts[event_id] = round(boost, 6)
if embedder_fusion_applied_event_scores:
decision_score_source = f"{decision_score_source or 'learned_decision_fusion'}+embedder_fusion"
recall_event_ids = [
event_id
for event_id, _ in sorted(recall_event_scores.items(), key=lambda item: (-float(item[1]), item[0]))
][: self.candidate_event_k]
if profile_first_event_ids:
profile_first_score_by_event = {
_clean_text((hit.metadata or {}).get("profile_first_hybrid_event_id", "")) or _runtime_event_key(hit): float(hit.score)
for hit in profile_first_hits
}
for rank, event_id in enumerate(profile_first_event_ids, start=1):
if not event_id:
continue
floor = max(0.92, float(profile_first_score_by_event.get(event_id, 0.0))) + max(0.0, 0.04 - (rank * 0.005))
recall_event_scores[event_id] = max(float(recall_event_scores.get(event_id, 0.0)), floor)
event_scores[event_id] = max(float(event_scores.get(event_id, 0.0)), floor)
base_event_scores[event_id] = max(float(base_event_scores.get(event_id, 0.0)), floor)
rerank_event_scores[event_id] = max(float(rerank_event_scores.get(event_id, 0.0)), floor)
calibrated_event_scores[event_id] = max(float(calibrated_event_scores.get(event_id, 0.0)), floor)
for path_id, path in runtime_paths.items():
event_id = _clean_text(path.get("event_id", ""))
path_type = _clean_text(path.get("type", ""))
if event_id not in set(profile_first_event_ids):
continue
if path_type not in {"speaker_event_profile", "speaker_event_source_turn", "speaker_event_status"}:
continue
floor = max(0.90, float(event_scores.get(event_id, 0.0)))
path_scores[path_id] = max(float(path_scores.get(path_id, 0.0)), floor)
base_path_scores[path_id] = max(float(base_path_scores.get(path_id, 0.0)), floor)
calibrated_path_scores[path_id] = max(float(calibrated_path_scores.get(path_id, 0.0)), floor)
path_model_scores[path_id] = max(float(path_model_scores.get(path_id, 0.0)), floor)
selection_consistency_repaired = False
selection_consistency_reason = ""
model_focused_answer_type = _clean_text(focused_answer_type_from_model)
embedder_fusion_selected_event_ids: List[str] = []
embedder_fusion_selected_path_ids: List[str] = []
if decision_fusion_enabled and (selected_path_ids_from_model or selected_event_ids_from_model):
base_selected_path_limit = max(1, min(max(1, self.support_path_k), max(1, top_k)))
selected_path_ids = [
path_id
for path_id in selected_path_ids_from_model
if _clean_text(path_id) in runtime_paths
]
if not selected_path_ids:
selected_path_ids = [
path_id
for path_id, _ in sorted(
((path_id, float(score or 0.0)) for path_id, score in path_scores.items()),
key=lambda item: (-float(item[1]), item[0]),
)
][:base_selected_path_limit]
focused_answer_type = _reconciled_focused_answer_type(question_analysis, answer_type_scores, focused_answer_type_from_model)
tunnel_rescue_path_ids: List[str] = []
tunnel_rescue_pre_filter_path_ids: List[str] = []
path_utility_direct_support_path_ids: List[str] = []
path_utility_contrast_support_path_ids: List[str] = []
path_utility_latent_context_path_ids: List[str] = []
path_utility_drift_noise_path_ids: List[str] = []
path_utility_roles: Dict[str, str] = {}
path_utility_reasons: Dict[str, str] = {}
path_utility_scores: Dict[str, float] = {}
path_utility_overlap_tokens: Dict[str, List[str]] = {}
path_utility_anchor_event_ids: List[str] = []
path_utility_anchor_subject_signatures: List[str] = []
tunnel_rescue_score_threshold = float(self.path_tunnel_rescue_score_floor)
tunnel_rescue_candidate_count = 0
tunnel_rescue_filtered_count = 0
if (
self.path_tunnel_rescue_k > 0
and path_tunnel_support_scores
and _memory_router_allows(memory_router_decision, "path_tunnel", "topic_tunnel")
):
selected_path_id_set = set(selected_path_ids)
candidate_scores = [
float(score or 0.0)
for path_id, score in path_tunnel_support_scores.items()
if _clean_text(path_id) in runtime_paths
]
tunnel_rescue_candidate_count = len(candidate_scores)
if candidate_scores and self.path_tunnel_rescue_min_score_margin > 0.0:
sorted_scores = sorted(candidate_scores)
median_score = sorted_scores[len(sorted_scores) // 2]
tunnel_rescue_score_threshold = max(
tunnel_rescue_score_threshold,
median_score + float(self.path_tunnel_rescue_min_score_margin),
)
event_turn_indices = [
_runtime_event_turn_index_from_id(_clean_text(path.get("event_id", "")))
for path in runtime_paths.values()
]
current_turn_index = max(
[int(getattr(self.graph, "turn_index", 0) or 0), *[turn for turn in event_turn_indices if turn > 0]],
default=0,
)
ranked_tunnel_path_ids = []
for path_id, score in sorted(
(
(path_id, float(score or 0.0))
for path_id, score in path_tunnel_support_scores.items()
if _clean_text(path_id) in runtime_paths
),
key=lambda item: (-float(item[1]), item[0]),
):
if path_id in selected_path_id_set or score < tunnel_rescue_score_threshold:
continue
path_event_turn = _runtime_event_turn_index_from_id(
_clean_text(runtime_paths.get(path_id, {}).get("event_id", ""))
)
path_age = max(0, current_turn_index - path_event_turn) if current_turn_index and path_event_turn else 0
if self.path_tunnel_rescue_min_age > 0 and path_age < self.path_tunnel_rescue_min_age:
continue
ranked_tunnel_path_ids.append(path_id)
tunnel_rescue_filtered_count = len(ranked_tunnel_path_ids)
tunnel_rescue_pre_filter_path_ids = ranked_tunnel_path_ids[: max(self.path_tunnel_rescue_k, self.path_tunnel_rescue_k * 4)]
utility_gate = _path_utility_gate(
tunnel_rescue_pre_filter_path_ids,
query=query,
runtime_graph=runtime_graph,
runtime_paths=runtime_paths,
grouped_hits=grouped_hits,
selected_path_ids=selected_path_ids,
selected_event_ids_from_model=selected_event_ids_from_model,
path_scores=path_scores,
path_tunnel_support_scores=path_tunnel_support_scores,
question_analysis=question_analysis,
focused_answer_type=focused_answer_type,
score_threshold=tunnel_rescue_score_threshold,
limit=self.path_tunnel_rescue_k,
)
tunnel_rescue_path_ids = list(utility_gate.get("injected_path_ids", []) or [])
path_utility_direct_support_path_ids = list(utility_gate.get("direct_support_path_ids", []) or [])
path_utility_contrast_support_path_ids = list(utility_gate.get("contrast_support_path_ids", []) or [])
path_utility_latent_context_path_ids = list(utility_gate.get("latent_context_path_ids", []) or [])
path_utility_drift_noise_path_ids = list(utility_gate.get("drift_noise_path_ids", []) or [])
path_utility_roles = dict(utility_gate.get("roles", {}) or {})
path_utility_reasons = dict(utility_gate.get("reasons", {}) or {})
path_utility_scores = dict(utility_gate.get("scores", {}) or {})
path_utility_overlap_tokens = dict(utility_gate.get("overlap_tokens", {}) or {})
path_utility_anchor_event_ids = list(utility_gate.get("anchor_event_ids", []) or [])
path_utility_anchor_subject_signatures = list(utility_gate.get("anchor_subject_signatures", []) or [])
if tunnel_rescue_path_ids:
selected_path_ids = _dedupe([*selected_path_ids, *tunnel_rescue_path_ids])
effective_selected_path_limit = base_selected_path_limit + len(tunnel_rescue_path_ids)
selected_event_ids = _dedupe(
[
*[
_clean_text(runtime_paths.get(path_id, {}).get("event_id", ""))
for path_id in selected_path_ids
if _clean_text(runtime_paths.get(path_id, {}).get("event_id", ""))
],
*[
_clean_text(event_id)
for event_id in selected_event_ids_from_model
if _clean_text(event_id)
],
]
)
if answer_plan_ranked_event_ids:
selected_event_ids = _dedupe([*selected_event_ids, *answer_plan_ranked_event_ids])
if not selected_event_ids:
selected_event_ids = [
event_id
for event_id, _ in sorted(
((event_id, float(score or 0.0)) for event_id, score in event_scores.items()),
key=lambda item: (-float(item[1]), item[0]),
)
][: max(1, min(self.candidate_event_k, max(top_k, _HYBRID_SELECTED_EVENT_FLOOR)))]
repaired_path_ids, selection_consistency_repaired, selection_consistency_reason = _repair_selected_paths_for_focus(
selected_path_ids,
runtime_paths=runtime_paths,
selected_event_ids=selected_event_ids,
path_scores=path_scores,
event_scores=event_scores,
temporal_scores=temporal_scores,
question_analysis=question_analysis,
answer_type_scores=answer_type_scores,
focused_answer_type=focused_answer_type,
limit=max(1, effective_selected_path_limit),
)
if selection_consistency_repaired:
selected_path_ids = repaired_path_ids
selected_event_ids = _dedupe(
[
*[
_clean_text(runtime_paths.get(path_id, {}).get("event_id", ""))
for path_id in selected_path_ids
if _clean_text(runtime_paths.get(path_id, {}).get("event_id", ""))
],
*[
_clean_text(event_id)
for event_id in selected_event_ids_from_model
if _clean_text(event_id)
],
]
)
if answer_plan_ranked_event_ids:
selected_event_ids = _dedupe([*selected_event_ids, *answer_plan_ranked_event_ids])
if profile_first_event_ids:
selected_event_ids = _dedupe([*profile_first_event_ids, *selected_event_ids])
if embedder_fusion_applied_event_scores and self.embedder_fusion_select_k > 0:
ranked_fusion_event_ids = [
event_id
for event_id in embedder_index_event_ids
if event_id in embedder_fusion_applied_event_scores
][: max(1, int(self.embedder_fusion_select_k))]
selected_event_ids = _dedupe([*ranked_fusion_event_ids, *selected_event_ids])
selected_path_id_set = set(selected_path_ids)
selected_embedder_path_ids: List[str] = []
for event_id in ranked_fusion_event_ids:
candidate_paths = [
(path_id, path)
for path_id, path in runtime_paths.items()
if _clean_text(path.get("event_id", "")) == event_id
]
candidate_paths.sort(
key=lambda item: (
int(_clean_text(item[1].get("type", "")) in {"speaker_event_source_turn", "speaker_event_profile", "speaker_event_status", "speaker_event_time"}),
float(path_scores.get(item[0], 0.0) or 0.0),
item[0],
),
reverse=True,
)
for path_id, _ in candidate_paths:
if path_id in selected_path_id_set:
continue
selected_embedder_path_ids.append(path_id)
selected_path_id_set.add(path_id)
embedder_fusion_selected_path_ids.append(path_id)
break
if selected_embedder_path_ids:
selected_path_ids = _dedupe([*selected_embedder_path_ids, *selected_path_ids])
embedder_fusion_selected_event_ids = list(ranked_fusion_event_ids)
if answer_plan_ranked_event_ids:
selected_event_ids = _dedupe([*selected_event_ids, *answer_plan_ranked_event_ids])
final_hits: List[MemoryHit] = []
seen_memory_ids = set()
selected_event_id_set = set(selected_event_ids)
for path_id in selected_path_ids:
path = runtime_paths.get(path_id, {})
event_id = _clean_text(path.get("event_id", ""))
if event_id not in selected_event_id_set:
continue
path_type = _clean_text(path.get("type", ""))
support_node_id = _path_support_node_id(path)
support_hit = _support_hit_for_path(path_type, grouped_hits.get(event_id, []))
event_hit = _representative_event_hit(
[*grouped_hits.get(event_id, []), *_event_record_hits_from_graph(self.graph, event_id)],
query=query,
)
decision_score = round(float(path_scores.get(path_id, 0.0)), 6)
candidate_hits: List[tuple[MemoryHit | None, float]] = [(support_hit, decision_score)]
if event_hit is not None and (support_hit is None or event_hit.memory_id != support_hit.memory_id):
candidate_hits.append((event_hit, max(0.0, decision_score - 0.0001)))
for raw_hit, hit_score in candidate_hits:
if raw_hit is None or raw_hit.memory_id in seen_memory_ids:
continue
seen_memory_ids.add(raw_hit.memory_id)
recall_score = float(recall_event_scores.get(event_id, event_scores.get(event_id, 0.0)))
event_score = float(event_scores.get(event_id, raw_hit.score))
temporal_score = float(temporal_scores.get(support_node_id, 0.0))
metadata = dict(raw_hit.metadata or {})
metadata.update(
{
"event_id": event_id,
**_answer_plan_hit_metadata(event_id),
"path_id": path_id,
"base_event_score": round(float(base_event_scores.get(event_id, event_score)), 6),
"rerank_event_score": round(float(rerank_event_scores.get(event_id, event_score)), 6),
"calibrated_event_score": round(float(calibrated_event_scores.get(event_id, event_score)), 6),
"matrix_event_score": round(float(matrix_event_scores.get(event_id, 0.0)), 6),
"event_fusion_delta_score": round(float(event_fusion_delta_scores.get(event_id, 0.0)), 6),
"event_tunnel_support_score": round(float(event_tunnel_support_scores.get(event_id, 0.0)), 6),
"event_tunnel_delta_score": round(float(event_tunnel_delta_scores.get(event_id, 0.0)), 6),
"tri_maze_event_reverse_score": round(float(tri_maze_event_reverse_scores.get(event_id, 0.0)), 6),
"tri_maze_event_boundary_score": round(float(tri_maze_event_boundary_scores.get(event_id, 0.0)), 6),
"tri_maze_event_reverse_relation": round(float(tri_maze_event_reverse_relations.get(event_id, 0.0)), 6),
"matrix_enabled": matrix_enabled,
"event_calibration_enabled": event_calibration_enabled,
"path_calibration_enabled": path_calibration_enabled,
"event_tunnel_enabled": event_tunnel_enabled,
"path_tunnel_enabled": path_tunnel_enabled,
"final_event_fusion_enabled": final_event_fusion_enabled,
"final_path_fusion_enabled": final_path_fusion_enabled,
"decision_fusion_enabled": True,
"event_fusion_enabled": event_fusion_enabled,
"path_fusion_enabled": path_fusion_enabled,
"event_score": round(event_score, 6),
"recall_score": round(recall_score, 6),
"path_score": round(float(path_scores.get(path_id, hit_score)), 6),
"base_path_score": round(float(base_path_scores.get(path_id, path_scores.get(path_id, hit_score))), 6),
"calibrated_path_score": round(float(calibrated_path_scores.get(path_id, path_scores.get(path_id, hit_score))), 6),
"path_fusion_delta_score": round(float(path_fusion_delta_scores.get(path_id, 0.0)), 6),
"path_tunnel_support_score": round(float(path_tunnel_support_scores.get(path_id, 0.0)), 6),
"path_tunnel_delta_score": round(float(path_tunnel_delta_scores.get(path_id, 0.0)), 6),
"path_model_score": round(float(path_model_scores.get(path_id, path_scores.get(path_id, hit_score))), 6),
"path_chain_extension_enabled": path_chain_extension_enabled,
"path_chain_extension_delta_score": round(float(path_chain_extension_delta_scores.get(path_id, 0.0)), 6),
"path_chain_extended_score": round(float(path_chain_extended_scores.get(path_id, path_scores.get(path_id, hit_score))), 6),
"tri_maze_path_reverse_score": round(float(tri_maze_path_reverse_scores.get(path_id, 0.0)), 6),
"tri_maze_path_boundary_score": round(float(tri_maze_path_boundary_scores.get(path_id, 0.0)), 6),
"tri_maze_path_reverse_relation": round(float(tri_maze_path_reverse_relations.get(path_id, 0.0)), 6),
"effective_path_score": round(float(path_scores.get(path_id, hit_score)), 6),
"temporal_score": round(temporal_score, 6),
"raw_public_score": round(float(raw_hit.score), 6),
"hybrid_score": round(hit_score, 6),
"hybrid_score_source": decision_score_source or "learned_decision_fusion",
"evidence_snippet_role": "selected_path_support" if raw_hit is support_hit else "selected_path_event",
"selected_event_ids": list(selected_event_ids),
"selected_path_ids": list(selected_path_ids),
"path_tunnel_rescue_enabled": bool(self.path_tunnel_rescue_k > 0),
"path_tunnel_rescue_k": int(self.path_tunnel_rescue_k),
"path_tunnel_rescue_score_floor": round(float(self.path_tunnel_rescue_score_floor), 6),
"path_tunnel_rescue_min_age": int(self.path_tunnel_rescue_min_age),
"path_tunnel_rescue_min_score_margin": round(float(self.path_tunnel_rescue_min_score_margin), 6),
"path_tunnel_rescue_score_threshold": round(float(tunnel_rescue_score_threshold), 6),
"path_tunnel_rescue_candidate_count": int(tunnel_rescue_candidate_count),
"path_tunnel_rescue_filtered_count": int(tunnel_rescue_filtered_count),
"path_tunnel_rescue_path_ids": list(tunnel_rescue_path_ids),
"path_utility_gate_enabled": bool(self.path_tunnel_rescue_k > 0),
"path_utility_pre_filter_path_ids": list(tunnel_rescue_pre_filter_path_ids),
"path_utility_injected_path_ids": list(tunnel_rescue_path_ids),
"path_utility_direct_support_path_ids": list(path_utility_direct_support_path_ids),
"path_utility_contrast_support_path_ids": list(path_utility_contrast_support_path_ids),
"path_utility_latent_context_path_ids": list(path_utility_latent_context_path_ids),
"path_utility_drift_noise_path_ids": list(path_utility_drift_noise_path_ids),
"path_utility_roles": dict(path_utility_roles),
"path_utility_reasons": dict(path_utility_reasons),
"path_utility_scores": dict(path_utility_scores),
"path_utility_overlap_tokens": dict(path_utility_overlap_tokens),
"path_utility_anchor_event_ids": list(path_utility_anchor_event_ids),
"path_utility_anchor_subject_signatures": list(path_utility_anchor_subject_signatures),
"tunnel_recall_pre_filter_count": int(len(tunnel_rescue_pre_filter_path_ids)),
"tunnel_usable_post_filter_count": int(len(tunnel_rescue_path_ids)),
"model_focused_answer_type": model_focused_answer_type,
"selection_consistency_repaired": bool(selection_consistency_repaired),
"selection_consistency_reason": selection_consistency_reason,
}
)
final_hits.append(
MemoryHit(
memory_id=raw_hit.memory_id,
category=raw_hit.category,
value=raw_hit.value,
relation=raw_hit.relation,
anchors=list(raw_hit.anchors),
score=round(hit_score, 6),
source_kind=raw_hit.source_kind,
slot_key=raw_hit.slot_key,
state=raw_hit.state,
turn_index=int(raw_hit.turn_index),
metadata=metadata,
)
)
for event_rank, event_id in enumerate(selected_event_ids, start=1):
event_hit = _representative_event_hit(
[*grouped_hits.get(event_id, []), *_event_record_hits_from_graph(self.graph, event_id)],
query=query,
)
if event_hit is None or event_hit.memory_id in seen_memory_ids:
continue
seen_memory_ids.add(event_hit.memory_id)
recall_score = float(recall_event_scores.get(event_id, event_hit.score))
event_score = float(event_scores.get(event_id, event_hit.score))
metadata = dict(event_hit.metadata or {})
metadata.update(
{
"event_id": event_id,
**_answer_plan_hit_metadata(event_id),
"path_id": "",
"base_event_score": round(float(base_event_scores.get(event_id, event_score)), 6),
"rerank_event_score": round(float(rerank_event_scores.get(event_id, event_score)), 6),
"calibrated_event_score": round(float(calibrated_event_scores.get(event_id, event_score)), 6),
"matrix_event_score": round(float(matrix_event_scores.get(event_id, 0.0)), 6),
"event_fusion_delta_score": round(float(event_fusion_delta_scores.get(event_id, 0.0)), 6),
"event_tunnel_support_score": round(float(event_tunnel_support_scores.get(event_id, 0.0)), 6),
"event_tunnel_delta_score": round(float(event_tunnel_delta_scores.get(event_id, 0.0)), 6),
"tri_maze_event_reverse_score": round(float(tri_maze_event_reverse_scores.get(event_id, 0.0)), 6),
"tri_maze_event_boundary_score": round(float(tri_maze_event_boundary_scores.get(event_id, 0.0)), 6),
"tri_maze_event_reverse_relation": round(float(tri_maze_event_reverse_relations.get(event_id, 0.0)), 6),
"matrix_enabled": matrix_enabled,
"event_calibration_enabled": event_calibration_enabled,
"path_calibration_enabled": path_calibration_enabled,
"event_tunnel_enabled": event_tunnel_enabled,
"path_tunnel_enabled": path_tunnel_enabled,
"final_event_fusion_enabled": final_event_fusion_enabled,
"final_path_fusion_enabled": final_path_fusion_enabled,
"decision_fusion_enabled": True,
"event_fusion_enabled": event_fusion_enabled,
"path_fusion_enabled": path_fusion_enabled,
"event_score": round(event_score, 6),
"recall_score": round(recall_score, 6),
"path_score": 0.0,
"base_path_score": 0.0,
"calibrated_path_score": 0.0,
"path_fusion_delta_score": 0.0,
"path_tunnel_support_score": 0.0,
"path_tunnel_delta_score": 0.0,
"path_model_score": 0.0,
"path_chain_extension_enabled": path_chain_extension_enabled,
"path_chain_extension_delta_score": 0.0,
"path_chain_extended_score": 0.0,
"effective_path_score": 0.0,
"temporal_score": 0.0,
"raw_public_score": round(float(event_hit.score), 6),
"hybrid_score": round(event_score, 6),
"hybrid_score_source": decision_score_source or "learned_final_event_fusion",
"evidence_snippet_role": "selected_event_representative",
"selected_event_rank": int(event_rank),
"selected_event_ids": list(selected_event_ids),
"selected_path_ids": list(selected_path_ids),
"path_tunnel_rescue_enabled": bool(self.path_tunnel_rescue_k > 0),
"path_tunnel_rescue_k": int(self.path_tunnel_rescue_k),
"path_tunnel_rescue_score_floor": round(float(self.path_tunnel_rescue_score_floor), 6),
"path_tunnel_rescue_min_age": int(self.path_tunnel_rescue_min_age),
"path_tunnel_rescue_min_score_margin": round(float(self.path_tunnel_rescue_min_score_margin), 6),
"path_tunnel_rescue_score_threshold": round(float(tunnel_rescue_score_threshold), 6),
"path_tunnel_rescue_candidate_count": int(tunnel_rescue_candidate_count),
"path_tunnel_rescue_filtered_count": int(tunnel_rescue_filtered_count),
"path_tunnel_rescue_path_ids": list(tunnel_rescue_path_ids),
"path_utility_gate_enabled": bool(self.path_tunnel_rescue_k > 0),
"path_utility_pre_filter_path_ids": list(tunnel_rescue_pre_filter_path_ids),
"path_utility_injected_path_ids": list(tunnel_rescue_path_ids),
"path_utility_direct_support_path_ids": list(path_utility_direct_support_path_ids),
"path_utility_contrast_support_path_ids": list(path_utility_contrast_support_path_ids),
"path_utility_latent_context_path_ids": list(path_utility_latent_context_path_ids),
"path_utility_drift_noise_path_ids": list(path_utility_drift_noise_path_ids),
"path_utility_roles": dict(path_utility_roles),
"path_utility_reasons": dict(path_utility_reasons),
"path_utility_scores": dict(path_utility_scores),
"path_utility_overlap_tokens": dict(path_utility_overlap_tokens),
"path_utility_anchor_event_ids": list(path_utility_anchor_event_ids),
"path_utility_anchor_subject_signatures": list(path_utility_anchor_subject_signatures),
"tunnel_recall_pre_filter_count": int(len(tunnel_rescue_pre_filter_path_ids)),
"tunnel_usable_post_filter_count": int(len(tunnel_rescue_path_ids)),
"model_focused_answer_type": model_focused_answer_type,
"selection_consistency_repaired": bool(selection_consistency_repaired),
"selection_consistency_reason": selection_consistency_reason,
}
)
final_hits.append(
MemoryHit(
memory_id=event_hit.memory_id,
category=event_hit.category,
value=event_hit.value,
relation=event_hit.relation,
anchors=list(event_hit.anchors),
score=round(event_score, 6),
source_kind=event_hit.source_kind,
slot_key=event_hit.slot_key,
state=event_hit.state,
turn_index=int(event_hit.turn_index),
metadata=metadata,
)
)
if not final_hits:
final_hits = list(hits)
if profile_first_hits:
final_hits = _inject_profile_first_hits(
final_hits,
profile_first_hits,
selected_event_ids=selected_event_ids,
selected_path_ids=selected_path_ids,
)
final_hits = _coverage_preserving_final_hits(final_hits, selected_event_ids=selected_event_ids, top_k=top_k)
final_hit_event_ids = _event_ids_from_hits(final_hits)
final_missing_selected_event_ids = [event_id for event_id in selected_event_ids if event_id not in set(final_hit_event_ids)]
return {
"hits": final_hits,
"metadata": {
"retrieval_mode": "hybrid_node_scored",
"hybrid_enabled": True,
"hybrid_source": hybrid_source,
**memory_router_decision,
"profile_first_router_suppressed": bool(profile_first_router_suppressed),
"recall_event_ids": list(recall_event_ids),
"learned_recall_event_ids": list(learned_recall_event_ids),
"model_recall_event_ids": list(model_recall_event_ids),
"symbolic_recall_event_ids": list(symbolic_recall_event_ids),
"embedder_index_recall_event_ids": list(embedder_index_event_ids),
**embedder_index_metadata,
"embedder_fusion_mode": embedder_fusion_mode or "off",
"embedder_fusion_enabled": bool(embedder_fusion_enabled),
"embedder_fusion_weight": round(float(self.embedder_fusion_weight), 6),
"embedder_fusion_score_floor": round(float(self.embedder_fusion_score_floor), 6),
"embedder_fusion_top_k": int(self.embedder_fusion_top_k),
"embedder_fusion_select_k": int(self.embedder_fusion_select_k),
"embedder_fusion_max_boost": round(float(self.embedder_fusion_max_boost), 6),
"embedder_fusion_event_scores": dict(embedder_fusion_applied_event_scores),
"embedder_fusion_boosts": dict(embedder_fusion_boosts),
"embedder_fusion_selected_event_ids": list(embedder_fusion_selected_event_ids),
"embedder_fusion_selected_path_ids": list(embedder_fusion_selected_path_ids),
"profile_first_hybrid_enabled": bool(profile_first_event_ids),
"profile_first_event_ids": list(profile_first_event_ids),
"profile_first_memory_ids": list(profile_first_memory_ids),
"hybrid_candidate_event_ids": list(hybrid_candidate_event_ids),
"hybrid_candidate_union_enabled": True,
"hybrid_candidate_union_rescored": bool(hybrid_candidate_union_rescored),
"hybrid_candidate_union_added_event_ids": list(hybrid_candidate_union_added_event_ids),
"hybrid_candidate_union_priority_changed": bool(hybrid_candidate_union_priority_changed),
"rerank_candidate_event_ids": list(rerank_candidate_event_ids),
"base_event_scores": dict(base_event_scores),
"rerank_event_scores": dict(rerank_event_scores),
"calibrated_event_scores": dict(calibrated_event_scores),
"matrix_event_scores": dict(matrix_event_scores),
"event_fusion_delta_scores": dict(event_fusion_delta_scores),
"event_tunnel_support_scores": dict(event_tunnel_support_scores),
"event_tunnel_delta_scores": dict(event_tunnel_delta_scores),
"tri_maze_event_reverse_scores": dict(tri_maze_event_reverse_scores),
"tri_maze_event_boundary_scores": dict(tri_maze_event_boundary_scores),
"tri_maze_event_reverse_relations": dict(tri_maze_event_reverse_relations),
"matrix_rerank_event_ids": list(matrix_rerank_event_ids),
"matrix_enabled": matrix_enabled,
"rerank_path_scores": dict(rerank_path_scores),
"matrix_path_scores": dict(matrix_path_scores),
"tri_maze_path_reverse_scores": dict(tri_maze_path_reverse_scores),
"tri_maze_path_boundary_scores": dict(tri_maze_path_boundary_scores),
"tri_maze_path_reverse_relations": dict(tri_maze_path_reverse_relations),
"matrix_path_rerank_ids": list(matrix_path_rerank_ids),
"matrix_path_enabled": matrix_path_enabled,
"fusion_enabled": fusion_enabled,
"event_calibration_enabled": event_calibration_enabled,
"path_calibration_enabled": path_calibration_enabled,
"event_tunnel_enabled": event_tunnel_enabled,
"path_tunnel_enabled": path_tunnel_enabled,
"final_event_fusion_enabled": final_event_fusion_enabled,
"final_path_fusion_enabled": final_path_fusion_enabled,
"decision_fusion_enabled": True,
"decision_score_source": decision_score_source or "learned_decision_fusion",
"event_fusion_enabled": event_fusion_enabled,
"path_fusion_enabled": path_fusion_enabled,
"selected_event_ids": list(selected_event_ids),
"path_rescue_event_ids": [],
"selected_path_ids": list(selected_path_ids),
"path_tunnel_rescue_enabled": bool(self.path_tunnel_rescue_k > 0),
"path_tunnel_rescue_k": int(self.path_tunnel_rescue_k),
"path_tunnel_rescue_score_floor": round(float(self.path_tunnel_rescue_score_floor), 6),
"path_tunnel_rescue_min_age": int(self.path_tunnel_rescue_min_age),
"path_tunnel_rescue_min_score_margin": round(float(self.path_tunnel_rescue_min_score_margin), 6),
"path_tunnel_rescue_score_threshold": round(float(tunnel_rescue_score_threshold), 6),
"path_tunnel_rescue_candidate_count": int(tunnel_rescue_candidate_count),
"path_tunnel_rescue_filtered_count": int(tunnel_rescue_filtered_count),
"path_tunnel_rescue_path_ids": list(tunnel_rescue_path_ids),
"path_utility_gate_enabled": bool(self.path_tunnel_rescue_k > 0),
"path_utility_pre_filter_path_ids": list(tunnel_rescue_pre_filter_path_ids),
"path_utility_injected_path_ids": list(tunnel_rescue_path_ids),
"path_utility_direct_support_path_ids": list(path_utility_direct_support_path_ids),
"path_utility_contrast_support_path_ids": list(path_utility_contrast_support_path_ids),
"path_utility_latent_context_path_ids": list(path_utility_latent_context_path_ids),
"path_utility_drift_noise_path_ids": list(path_utility_drift_noise_path_ids),
"path_utility_roles": dict(path_utility_roles),
"path_utility_reasons": dict(path_utility_reasons),
"path_utility_scores": dict(path_utility_scores),
"path_utility_overlap_tokens": dict(path_utility_overlap_tokens),
"path_utility_anchor_event_ids": list(path_utility_anchor_event_ids),
"path_utility_anchor_subject_signatures": list(path_utility_anchor_subject_signatures),
"tunnel_recall_pre_filter_count": int(len(tunnel_rescue_pre_filter_path_ids)),
"tunnel_usable_post_filter_count": int(len(tunnel_rescue_path_ids)),
"final_hit_event_ids": list(final_hit_event_ids),
"final_hit_dia_ids": _dia_ids_from_hits(final_hits),
"final_missing_selected_event_ids": list(final_missing_selected_event_ids),
"selected_event_count": int(len(selected_event_ids)),
"selected_path_count": int(len(selected_path_ids)),
"temporal_scores": dict(temporal_scores),
"recall_event_scores": dict(recall_event_scores),
"event_scores": dict(event_scores),
"base_path_scores": dict(base_path_scores),
"calibrated_path_scores": dict(calibrated_path_scores),
"path_fusion_delta_scores": dict(path_fusion_delta_scores),
"path_tunnel_support_scores": dict(path_tunnel_support_scores),
"path_tunnel_delta_scores": dict(path_tunnel_delta_scores),
"path_model_scores": dict(path_model_scores),
"path_chain_extension_enabled": path_chain_extension_enabled,
"path_chain_extension_delta_scores": dict(path_chain_extension_delta_scores),
"path_chain_extended_scores": dict(path_chain_extended_scores),
"effective_path_scores": dict(path_scores),
"path_scores": dict(path_scores),
"answer_type_scores": dict(answer_type_scores),
"answer_plan_scores": dict(answer_plan_scores),
"answer_plan_support_scores": dict(answer_plan_support_scores),
"answer_plan_adjusted_scores": dict(answer_plan_adjusted_scores),
"answer_plan_raw_ranked_event_ids": list(answer_plan_raw_ranked_event_ids),
"answer_plan_ranked_event_ids": list(answer_plan_ranked_event_ids),
"answer_plan_selected_event_ids": list(answer_plan_selected_event_ids),
"answer_plan_current_event_ids": list(answer_plan_current_event_ids),
"answer_plan_promotion_enabled": bool(answer_plan_promotion_enabled),
"answer_plan_promotion_score_margin": round(float(answer_plan_promotion_score_margin), 6),
"answer_plan_promotion_min_margin": round(float(answer_plan_promotion_min_margin), 6),
"answer_plan_event_selection_threshold": round(float(answer_plan_event_selection_threshold), 6),
"answer_plan_event_selection_top_k": int(answer_plan_event_selection_top_k),
"focused_answer_type": focused_answer_type,
"model_focused_answer_type": model_focused_answer_type,
"selection_consistency_repaired": bool(selection_consistency_repaired),
"selection_consistency_reason": selection_consistency_reason,
"preferred_path_types": [],
},
}
dominant_answer_type = _dominant_answer_type(question_analysis, answer_type_scores)
preferred_path_types = _answer_type_preferred_path_types(question_analysis, answer_type_scores)
if bool(memory_router_decision.get("memory_router_guided")):
temporal_focus = _memory_router_allows(memory_router_decision, "temporal")
else:
temporal_focus = dominant_answer_type == "time" or bool(question_analysis.get("is_temporal", False))
router_profile_focus = bool(memory_router_decision.get("memory_router_guided")) and _memory_router_allows(
memory_router_decision,
"profile",
"resource",
)
learned_event_available = bool(event_scores)
learned_path_available = bool(path_scores)
ranked_events = [
event_id
for event_id, _ in sorted(
(
(event_id, event_scores.get(event_id, recall_event_scores.get(event_id, 0.0)))
for event_id in recall_event_ids
),
key=lambda item: (-float(item[1]), item[0]),
)
]
base_selected_event_count = min(
len(ranked_events),
max(
self.support_path_k * 2,
min(self.candidate_event_k, max(top_k, _HYBRID_SELECTED_EVENT_FLOOR)),
),
)
effective_path_scores = dict(path_scores)
if not learned_path_available:
effective_path_scores = {
path_id: _calibrated_path_score(
path=runtime_paths.get(path_id, {}),
base_score=float(score),
temporal_scores=temporal_scores,
question_analysis=question_analysis,
answer_type_scores=answer_type_scores,
)
for path_id, score in path_scores.items()
}
ranked_path_ids_all = [
path_id
for path_id, _ in sorted(
((path_id, float(score)) for path_id, score in effective_path_scores.items()),
key=lambda item: (-float(item[1]), item[0]),
)
]
path_rescue_count = min(
len(ranked_path_ids_all),
max(self.support_path_k * 2, min(self.candidate_event_k, max(1, top_k))),
)
path_rescue_event_ids = _dedupe(
_clean_text(runtime_paths.get(path_id, {}).get("event_id", ""))
for path_id in ranked_path_ids_all[: max(1, path_rescue_count)]
)
selected_event_ids = _dedupe(
[
*path_rescue_event_ids,
*ranked_events[: max(1, base_selected_event_count)],
]
)
if profile_first_event_ids:
selected_event_ids = _dedupe([*profile_first_event_ids, *selected_event_ids])
if answer_plan_ranked_event_ids:
selected_event_ids = _dedupe([*selected_event_ids, *answer_plan_ranked_event_ids])
selected_event_id_set = set(selected_event_ids)
ranked_path_ids = [
path_id
for path_id in ranked_path_ids_all
if _clean_text(runtime_paths.get(path_id, {}).get("event_id", "")) in selected_event_id_set
]
if temporal_focus:
focused_time_path_ids = [
path_id
for path_id in ranked_path_ids
if _clean_text(runtime_paths.get(path_id, {}).get("type", "")) == "speaker_event_time"
]
if focused_time_path_ids:
ranked_path_ids = focused_time_path_ids
path_limit_cap = _HYBRID_SELECTED_PATH_CAP
if temporal_focus:
path_limit_cap = _HYBRID_TEMPORAL_PATH_CAP
elif dominant_answer_type == "profile" or router_profile_focus:
path_limit_cap = _HYBRID_PROFILE_PATH_CAP
selected_path_count = min(
len(ranked_path_ids),
max(1, min(max(1, self.support_path_k), min(max(1, top_k), path_limit_cap))),
)
selected_path_ids = ranked_path_ids[: max(1, selected_path_count)]
final_hits: List[MemoryHit] = []
seen_memory_ids = set()
temporal_event_hits_added = 0
for path_id in selected_path_ids:
path = runtime_paths.get(path_id, {})
event_id = _clean_text(path.get("event_id", ""))
support_node_id = _path_support_node_id(path)
path_type = _clean_text(path.get("type", ""))
support_hit = _support_hit_for_path(path_type, grouped_hits.get(event_id, []))
event_hit = _representative_event_hit(
[*grouped_hits.get(event_id, []), *_event_record_hits_from_graph(self.graph, event_id)],
query=query,
)
hit_pairs: List[tuple[MemoryHit | None, float]] = [(support_hit, path_scores.get(path_id, 0.0))]
if not temporal_focus:
hit_pairs.append((event_hit, path_scores.get(path_id, 0.0)))
elif support_hit is None and event_hit is not None:
hit_pairs.append((event_hit, path_scores.get(path_id, 0.0)))
elif path_type == "speaker_event_time" and event_hit is not None and temporal_event_hits_added < 1:
hit_pairs.append((event_hit, path_scores.get(path_id, 0.0)))
temporal_event_hits_added += 1
for raw_hit, path_score in hit_pairs:
if raw_hit is None or raw_hit.memory_id in seen_memory_ids:
continue
seen_memory_ids.add(raw_hit.memory_id)
recall_score = float(recall_event_scores.get(event_id, event_scores.get(event_id, 0.0)))
event_score = float(event_scores.get(event_id, raw_hit.score))
temporal_score = float(temporal_scores.get(support_node_id, 0.0))
effective_path_score = float(effective_path_scores.get(path_id, path_score))
hybrid_score = round(effective_path_score, 6) if learned_path_available else round(
(0.55 * event_score)
+ (0.20 * float(path_score))
+ (0.15 * recall_score)
+ (0.10 * temporal_score),
6,
)
metadata = dict(raw_hit.metadata or {})
metadata.update(
{
"event_id": event_id,
**_answer_plan_hit_metadata(event_id),
"path_id": path_id,
"base_event_score": round(float(base_event_scores.get(event_id, event_score)), 6),
"rerank_event_score": round(float(rerank_event_scores.get(event_id, event_score)), 6),
"calibrated_event_score": round(float(calibrated_event_scores.get(event_id, event_score)), 6),
"matrix_event_score": round(float(matrix_event_scores.get(event_id, 0.0)), 6),
"event_fusion_delta_score": round(float(event_fusion_delta_scores.get(event_id, 0.0)), 6),
"event_tunnel_support_score": round(float(event_tunnel_support_scores.get(event_id, 0.0)), 6),
"event_tunnel_delta_score": round(float(event_tunnel_delta_scores.get(event_id, 0.0)), 6),
"matrix_enabled": matrix_enabled,
"event_calibration_enabled": event_calibration_enabled,
"path_calibration_enabled": path_calibration_enabled,
"event_tunnel_enabled": event_tunnel_enabled,
"path_tunnel_enabled": path_tunnel_enabled,
"final_event_fusion_enabled": final_event_fusion_enabled,
"final_path_fusion_enabled": final_path_fusion_enabled,
"decision_fusion_enabled": False,
"event_fusion_enabled": event_fusion_enabled,
"path_fusion_enabled": path_fusion_enabled,
"event_score": round(event_score, 6),
"recall_score": round(recall_score, 6),
"path_score": round(float(path_score), 6),
"base_path_score": round(float(base_path_scores.get(path_id, path_score)), 6),
"calibrated_path_score": round(float(calibrated_path_scores.get(path_id, effective_path_score)), 6),
"path_fusion_delta_score": round(float(path_fusion_delta_scores.get(path_id, 0.0)), 6),
"path_tunnel_support_score": round(float(path_tunnel_support_scores.get(path_id, 0.0)), 6),
"path_tunnel_delta_score": round(float(path_tunnel_delta_scores.get(path_id, 0.0)), 6),
"path_model_score": round(float(path_model_scores.get(path_id, path_score)), 6),
"path_chain_extension_enabled": path_chain_extension_enabled,
"path_chain_extension_delta_score": round(float(path_chain_extension_delta_scores.get(path_id, 0.0)), 6),
"path_chain_extended_score": round(float(path_chain_extended_scores.get(path_id, effective_path_score)), 6),
"effective_path_score": round(effective_path_score, 6),
"temporal_score": round(temporal_score, 6),
"raw_public_score": round(float(raw_hit.score), 6),
"hybrid_score": hybrid_score,
"hybrid_score_source": (
"learned_path_fusion"
if path_fusion_enabled
else "learned_path_score"
if learned_path_available
else "heuristic_mix"
),
"evidence_snippet_role": "selected_path_support" if raw_hit is support_hit else "selected_path_event",
"selected_event_ids": list(selected_event_ids),
"selected_path_ids": list(selected_path_ids),
}
)
final_hits.append(
MemoryHit(
memory_id=raw_hit.memory_id,
category=raw_hit.category,
value=raw_hit.value,
relation=raw_hit.relation,
anchors=list(raw_hit.anchors),
score=hybrid_score,
source_kind=raw_hit.source_kind,
slot_key=raw_hit.slot_key,
state=raw_hit.state,
turn_index=int(raw_hit.turn_index),
metadata=metadata,
)
)
for event_rank, event_id in enumerate(selected_event_ids, start=1):
event_hit = _representative_event_hit(
[*grouped_hits.get(event_id, []), *_event_record_hits_from_graph(self.graph, event_id)],
query=query,
)
if event_hit is None or event_hit.memory_id in seen_memory_ids:
continue
seen_memory_ids.add(event_hit.memory_id)
recall_score = float(recall_event_scores.get(event_id, event_hit.score))
event_score = float(event_scores.get(event_id, event_hit.score))
hybrid_score = round(event_score, 6) if learned_event_available else round(
(0.7 * event_score) + (0.3 * recall_score),
6,
)
metadata = dict(event_hit.metadata or {})
metadata.update(
{
"event_id": event_id,
**_answer_plan_hit_metadata(event_id),
"path_id": "",
"base_event_score": round(float(base_event_scores.get(event_id, event_score)), 6),
"rerank_event_score": round(float(rerank_event_scores.get(event_id, event_score)), 6),
"calibrated_event_score": round(float(calibrated_event_scores.get(event_id, event_score)), 6),
"matrix_event_score": round(float(matrix_event_scores.get(event_id, 0.0)), 6),
"event_fusion_delta_score": round(float(event_fusion_delta_scores.get(event_id, 0.0)), 6),
"event_tunnel_support_score": round(float(event_tunnel_support_scores.get(event_id, 0.0)), 6),
"event_tunnel_delta_score": round(float(event_tunnel_delta_scores.get(event_id, 0.0)), 6),
"matrix_enabled": matrix_enabled,
"event_calibration_enabled": event_calibration_enabled,
"path_calibration_enabled": path_calibration_enabled,
"event_tunnel_enabled": event_tunnel_enabled,
"path_tunnel_enabled": path_tunnel_enabled,
"final_event_fusion_enabled": final_event_fusion_enabled,
"final_path_fusion_enabled": final_path_fusion_enabled,
"decision_fusion_enabled": False,
"event_fusion_enabled": event_fusion_enabled,
"path_fusion_enabled": path_fusion_enabled,
"event_score": round(event_score, 6),
"recall_score": round(recall_score, 6),
"path_score": 0.0,
"base_path_score": 0.0,
"calibrated_path_score": 0.0,
"path_fusion_delta_score": 0.0,
"path_tunnel_support_score": 0.0,
"path_tunnel_delta_score": 0.0,
"path_model_score": 0.0,
"path_chain_extension_enabled": path_chain_extension_enabled,
"path_chain_extension_delta_score": 0.0,
"path_chain_extended_score": 0.0,
"effective_path_score": 0.0,
"temporal_score": 0.0,
"raw_public_score": round(float(event_hit.score), 6),
"hybrid_score": hybrid_score,
"hybrid_score_source": (
"learned_event_fusion"
if event_fusion_enabled
else "learned_event_score"
if learned_event_available
else "heuristic_event_mix"
),
"evidence_snippet_role": "selected_event_representative",
"selected_event_rank": int(event_rank),
"selected_event_ids": list(selected_event_ids),
"selected_path_ids": list(selected_path_ids),
}
)
final_hits.append(
MemoryHit(
memory_id=event_hit.memory_id,
category=event_hit.category,
value=event_hit.value,
relation=event_hit.relation,
anchors=list(event_hit.anchors),
score=hybrid_score,
source_kind=event_hit.source_kind,
slot_key=event_hit.slot_key,
state=event_hit.state,
turn_index=int(event_hit.turn_index),
metadata=metadata,
)
)
if not final_hits:
final_hits = list(hits)
if profile_first_hits:
final_hits = _inject_profile_first_hits(
final_hits,
profile_first_hits,
selected_event_ids=selected_event_ids,
selected_path_ids=selected_path_ids,
)
final_hits = _coverage_preserving_final_hits(final_hits, selected_event_ids=selected_event_ids, top_k=top_k)
final_hit_event_ids = _event_ids_from_hits(final_hits)
final_missing_selected_event_ids = [event_id for event_id in selected_event_ids if event_id not in set(final_hit_event_ids)]
return {
"hits": final_hits,
"metadata": {
"retrieval_mode": "hybrid_node_scored",
"hybrid_enabled": True,
"hybrid_source": hybrid_source,
**memory_router_decision,
"profile_first_router_suppressed": bool(profile_first_router_suppressed),
"recall_event_ids": list(recall_event_ids),
"learned_recall_event_ids": list(learned_recall_event_ids),
"model_recall_event_ids": list(model_recall_event_ids),
"symbolic_recall_event_ids": list(symbolic_recall_event_ids),
"embedder_index_recall_event_ids": list(embedder_index_event_ids),
**embedder_index_metadata,
"embedder_fusion_mode": embedder_fusion_mode or "off",
"embedder_fusion_enabled": bool(embedder_fusion_enabled),
"embedder_fusion_weight": round(float(self.embedder_fusion_weight), 6),
"embedder_fusion_score_floor": round(float(self.embedder_fusion_score_floor), 6),
"embedder_fusion_top_k": int(self.embedder_fusion_top_k),
"embedder_fusion_select_k": int(self.embedder_fusion_select_k),
"embedder_fusion_max_boost": round(float(self.embedder_fusion_max_boost), 6),
"embedder_fusion_event_scores": dict(embedder_fusion_applied_event_scores),
"embedder_fusion_boosts": dict(embedder_fusion_boosts),
"embedder_fusion_selected_event_ids": list(embedder_fusion_selected_event_ids),
"embedder_fusion_selected_path_ids": list(embedder_fusion_selected_path_ids),
"profile_first_hybrid_enabled": bool(profile_first_event_ids),
"profile_first_event_ids": list(profile_first_event_ids),
"profile_first_memory_ids": list(profile_first_memory_ids),
"hybrid_candidate_event_ids": list(hybrid_candidate_event_ids),
"hybrid_candidate_union_enabled": True,
"hybrid_candidate_union_rescored": bool(hybrid_candidate_union_rescored),
"hybrid_candidate_union_added_event_ids": list(hybrid_candidate_union_added_event_ids),
"hybrid_candidate_union_priority_changed": bool(hybrid_candidate_union_priority_changed),
"rerank_candidate_event_ids": list(rerank_candidate_event_ids),
"base_event_scores": dict(base_event_scores),
"rerank_event_scores": dict(rerank_event_scores),
"calibrated_event_scores": dict(calibrated_event_scores),
"matrix_event_scores": dict(matrix_event_scores),
"event_fusion_delta_scores": dict(event_fusion_delta_scores),
"event_tunnel_support_scores": dict(event_tunnel_support_scores),
"event_tunnel_delta_scores": dict(event_tunnel_delta_scores),
"tri_maze_event_reverse_scores": dict(tri_maze_event_reverse_scores),
"tri_maze_event_boundary_scores": dict(tri_maze_event_boundary_scores),
"tri_maze_event_reverse_relations": dict(tri_maze_event_reverse_relations),
"matrix_rerank_event_ids": list(matrix_rerank_event_ids),
"matrix_enabled": matrix_enabled,
"rerank_path_scores": dict(rerank_path_scores),
"matrix_path_scores": dict(matrix_path_scores),
"tri_maze_path_reverse_scores": dict(tri_maze_path_reverse_scores),
"tri_maze_path_boundary_scores": dict(tri_maze_path_boundary_scores),
"tri_maze_path_reverse_relations": dict(tri_maze_path_reverse_relations),
"matrix_path_rerank_ids": list(matrix_path_rerank_ids),
"matrix_path_enabled": matrix_path_enabled,
"fusion_enabled": fusion_enabled,
"event_calibration_enabled": event_calibration_enabled,
"path_calibration_enabled": path_calibration_enabled,
"event_tunnel_enabled": event_tunnel_enabled,
"path_tunnel_enabled": path_tunnel_enabled,
"final_event_fusion_enabled": final_event_fusion_enabled,
"final_path_fusion_enabled": final_path_fusion_enabled,
"decision_fusion_enabled": False,
"decision_score_source": "",
"event_fusion_enabled": event_fusion_enabled,
"path_fusion_enabled": path_fusion_enabled,
"selected_event_ids": list(selected_event_ids),
"path_rescue_event_ids": list(path_rescue_event_ids),
"selected_path_ids": list(selected_path_ids),
"final_hit_event_ids": list(final_hit_event_ids),
"final_hit_dia_ids": _dia_ids_from_hits(final_hits),
"final_missing_selected_event_ids": list(final_missing_selected_event_ids),
"selected_event_count": int(len(selected_event_ids)),
"selected_path_count": int(len(selected_path_ids)),
"temporal_scores": dict(temporal_scores),
"recall_event_scores": dict(recall_event_scores),
"event_scores": dict(event_scores),
"base_path_scores": dict(base_path_scores),
"calibrated_path_scores": dict(calibrated_path_scores),
"path_fusion_delta_scores": dict(path_fusion_delta_scores),
"path_tunnel_support_scores": dict(path_tunnel_support_scores),
"path_tunnel_delta_scores": dict(path_tunnel_delta_scores),
"path_model_scores": dict(path_model_scores),
"path_chain_extension_enabled": path_chain_extension_enabled,
"path_chain_extension_delta_scores": dict(path_chain_extension_delta_scores),
"path_chain_extended_scores": dict(path_chain_extended_scores),
"effective_path_scores": dict(effective_path_scores),
"path_scores": dict(path_scores),
"answer_type_scores": dict(answer_type_scores),
"answer_plan_scores": dict(answer_plan_scores),
"answer_plan_support_scores": dict(answer_plan_support_scores),
"answer_plan_adjusted_scores": dict(answer_plan_adjusted_scores),
"answer_plan_raw_ranked_event_ids": list(answer_plan_raw_ranked_event_ids),
"answer_plan_ranked_event_ids": list(answer_plan_ranked_event_ids),
"answer_plan_selected_event_ids": list(answer_plan_selected_event_ids),
"answer_plan_current_event_ids": list(answer_plan_current_event_ids),
"answer_plan_promotion_enabled": bool(answer_plan_promotion_enabled),
"answer_plan_promotion_score_margin": round(float(answer_plan_promotion_score_margin), 6),
"answer_plan_promotion_min_margin": round(float(answer_plan_promotion_min_margin), 6),
"answer_plan_event_selection_threshold": round(float(answer_plan_event_selection_threshold), 6),
"answer_plan_event_selection_top_k": int(answer_plan_event_selection_top_k),
"focused_answer_type": dominant_answer_type,
"preferred_path_types": list(preferred_path_types),
},
}
def retrieve(self, query: str, *, top_k: int = 6) -> MemoryRetrieval:
# Retrieve memory hits for `query` and package them as a MemoryRetrieval.
#
# The raw graph retrieval is pushed through a fixed sequence of
# post-processing stages. Each stage receives the current `hits`, may return
# a replacement list, and contributes a metadata dict that is merged into the
# result. Stage order, as written below:
#   hybrid node scoring -> current-subject protection -> audit-anchor
#   protection -> identifier protection -> depth-chain protection (router
#   gated) -> profile-focused pack (router gated) -> topic-bucket rerank
#   (router gated) -> temporal evidence pack -> injection planner -> facet
#   query pack -> unit coverage pack (env gated) -> multi-unit chain slots
#   (env gated) -> profile-protected reinsertion -> embedder-fusion
#   representatives.
#
# Side effects visible in this method: the graph is reloaded first, the
# context token estimate is stored on self._last_retrieval_context_tokens,
# and self._persist_graph() is called just before returning.
self._reload_graph()
# The timer starts after the reload, so reload time is not part of
# `retrieval_seconds`.
start = time.perf_counter()
# Over-fetch candidates: never fewer than top_k, and at least
# candidate_event_k in hybrid mode or min(max(top_k, 18), 48) otherwise.
candidate_top_k = max(top_k, self.candidate_event_k if self.retrieval_mode == "hybrid_node_scored" else min(max(top_k, 18), 48))
payload = self.graph.retrieve(query, top_k=candidate_top_k)
hits = [_raw_hit_to_memory_hit(item) for item in payload.get("hits", []) or []]
# Index the scored hits by memory_id (entries with a falsy id are skipped).
# NOTE(review): _restore_hit_scores presumably copies these scores onto the
# per-state lists below -- confirm against the helper.
scored_lookup = {hit.memory_id: hit for hit in hits if hit.memory_id}
active_hits = _restore_hit_scores([_raw_hit_to_memory_hit(item) for item in payload.get("active_hits", []) or []], scored_lookup)
history_hits = _restore_hit_scores([_raw_hit_to_memory_hit(item) for item in payload.get("history_hits", []) or []], scored_lookup)
stale_hits = _restore_hit_scores([_raw_hit_to_memory_hit(item) for item in payload.get("stale_hits", []) or []], scored_lookup)
overwrite_hits = _restore_hit_scores([_raw_hit_to_memory_hit(item) for item in payload.get("overwrite_hits", []) or []], scored_lookup)
false_hits = _restore_hit_scores([_raw_hit_to_memory_hit(item) for item in payload.get("false_hits", []) or []], scored_lookup)
# Public graph hits are only collected in hybrid mode; the hybrid scoring
# call itself is made in every mode.
public_hits = _public_graph_hits(self.graph) if self.retrieval_mode == "hybrid_node_scored" else []
hybrid_payload = self._hybrid_node_scored_hits(query, hits, top_k=top_k, public_hits=public_hits)
# `... or hits` pattern (used by most stages below): a missing or empty
# "hits" entry keeps the previous list instead of emptying the result.
hits = list(hybrid_payload.get("hits", []) or hits)
hybrid_metadata = dict(hybrid_payload.get("metadata", {}) or {})
# Router decision derived from the hybrid metadata; it gates the depth-chain,
# profile-focused and topic-bucket stages further down.
memory_router_decision = _memory_router_decision(
hybrid_metadata,
mode=self.memory_router_mode,
threshold=self.memory_router_threshold,
margin=self.memory_router_margin,
)
current_subject_payload = _current_subject_protected_hits(
query=query,
graph=self.graph,
final_hits=hits,
top_k=top_k,
)
hits = list(current_subject_payload.get("hits", hits) or hits)
current_subject_metadata = dict(current_subject_payload.get("metadata", {}) or {})
# Candidate pool for the audit-anchor and identifier stages: public, active
# and history hits (stale / overwrite / false hits are not offered).
audit_anchor_payload = _audit_anchor_protected_hits(
query=query,
final_hits=hits,
candidate_hits=list(public_hits) + list(active_hits) + list(history_hits),
metadata={**dict(payload.get("metadata", {}) or {}), **hybrid_metadata, **current_subject_metadata},
top_k=top_k,
)
hits = list(audit_anchor_payload.get("hits", hits) or hits)
audit_anchor_metadata = dict(audit_anchor_payload.get("metadata", {}) or {})
identifier_payload = _identifier_protected_hits(
query=query,
final_hits=hits,
candidate_hits=list(public_hits) + list(active_hits) + list(history_hits),
top_k=top_k,
)
hits = list(identifier_payload.get("hits", hits) or hits)
identifier_metadata = dict(identifier_payload.get("metadata", {}) or {})
# Router-gated stage. When the router does not allow it, the stage is skipped
# and a stub metadata dict records the suppression.
# NOTE(review): how _memory_router_allows combines several layer names
# (any vs. all) is not visible here -- check the helper.
if _memory_router_allows(memory_router_decision, "path_tunnel", "topic_tunnel"):
depth_chain_payload = _depth_chain_protected_hits(
query=query,
graph=self.graph,
final_hits=hits,
top_k=top_k,
)
hits = list(depth_chain_payload.get("hits", hits) or hits)
depth_chain_metadata = dict(depth_chain_payload.get("metadata", {}) or {})
else:
depth_chain_metadata = {
"depth_chain_protected_enabled": False,
"depth_chain_router_suppressed": True,
}
# Router-gated profile-focused pack; same skip-and-record pattern as above.
if _memory_router_allows(memory_router_decision, "profile", "resource"):
profile_focused_payload = _profile_focused_pack_hits(
self.graph,
query,
hits,
top_k=top_k,
)
hits = list(profile_focused_payload.get("hits", hits) or hits)
profile_focused_metadata = dict(profile_focused_payload.get("metadata", {}) or {})
else:
profile_focused_metadata = {
"profile_focused_pack_enabled": False,
"profile_focused_router_suppressed": True,
}
if _memory_router_allows(memory_router_decision, "topic_tunnel"):
topic_bucket_payload = _topic_bucket_rerank_hits(self.graph, query, hits, top_k=top_k)
# Unlike the other stages, only a None result falls back to the previous
# hits here; an empty list returned by the rerank is honoured.
topic_bucket_hits = topic_bucket_payload.get("hits", hits)
hits = list(hits if topic_bucket_hits is None else topic_bucket_hits)
topic_bucket_metadata = dict(topic_bucket_payload.get("metadata", {}) or {})
else:
topic_bucket_metadata = {
"topic_bucket_rerank_enabled": False,
"topic_bucket_router_suppressed": True,
}
# The temporal, injection-planner and facet stages run unconditionally.
temporal_runtime_payload = self._temporal_runtime_pack(query)
temporal_evidence_payload = self._apply_temporal_evidence_pack_to_hits(
hits,
temporal_runtime_payload,
top_k=top_k,
)
hits = list(temporal_evidence_payload.get("hits", hits) or hits)
temporal_runtime_metadata = dict(temporal_evidence_payload.get("metadata", {}) or {})
injection_planner_payload = self._apply_injection_planner_to_hits(query, hits, top_k=top_k)
hits = list(injection_planner_payload.get("hits", hits) or hits)
injection_planner_metadata = dict(injection_planner_payload.get("metadata", {}) or {})
facet_query_pack_payload = _facet_query_pack_hits(
self.graph,
query,
hits,
top_k=top_k,
)
hits = list(facet_query_pack_payload.get("hits", hits) or hits)
facet_query_pack_metadata = dict(facet_query_pack_payload.get("metadata", {}) or {})
# Env toggle, default "on". The set of "disabled" spellings is shared with
# the multi-unit chain toggle (_MULTI_UNIT_CHAIN_DISABLED_MODES).
unit_coverage_mode = _normalize(os.getenv("TMCRA_UNIT_COVERAGE_PACK_MODE", "on"))
if unit_coverage_mode in _MULTI_UNIT_CHAIN_DISABLED_MODES:
unit_coverage_metadata = {
"unit_coverage_pack_enabled": False,
"unit_coverage_reason": "disabled",
}
else:
unit_coverage_payload = _unit_coverage_pack_hits(
self.graph,
query,
hits,
top_k=top_k,
)
hits = list(unit_coverage_payload.get("hits", hits) or hits)
unit_coverage_metadata = dict(unit_coverage_payload.get("metadata", {}) or {})
# Second env toggle (TMCRA_MULTI_UNIT_CHAIN_SLOT_MODE), also default "on".
if _normalize(os.getenv("TMCRA_MULTI_UNIT_CHAIN_SLOT_MODE", "on")) not in _MULTI_UNIT_CHAIN_DISABLED_MODES:
multi_unit_chain_slot_payload = _multi_unit_chain_slot_hits(
self.graph,
query,
hits,
top_k=top_k,
)
hits = list(multi_unit_chain_slot_payload.get("hits", hits) or hits)
multi_unit_chain_slot_metadata = dict(multi_unit_chain_slot_payload.get("metadata", {}) or {})
else:
multi_unit_chain_slot_metadata = {
"multi_unit_chain_slot_enabled": False,
"multi_unit_chain_slot_reason": "disabled",
}
# Profile-protected reinsertion: memory ids reported by the profile-focused
# stage are forced to the front of `hits`. At most max(1, min(6, top_k)) ids
# are considered.
profile_protected_reinserted_count = 0
profile_protected_ids = _dedupe(
[
*list(profile_focused_metadata.get("profile_focused_pack_memory_ids", []) or []),
*list(profile_focused_metadata.get("profile_first_memory_ids", []) or []),
],
max_items=max(1, min(6, int(top_k or 1))),
)
if profile_protected_ids:
existing_hits_by_id = {hit.memory_id: hit for hit in hits if hit.memory_id}
protected_hits: List[MemoryHit] = []
for memory_id in profile_protected_ids:
hit = existing_hits_by_id.get(memory_id)
# Ids dropped by a later stage are rebuilt from the graph record; ids with
# no record at all are skipped.
if hit is None:
record = getattr(self.graph, "records_by_id", {}).get(memory_id)
if record is None:
continue
hit = _memory_hit_from_record(record)
metadata = dict(hit.metadata or {})
metadata.update(
{
"profile_first_hybrid_rescue": True,
"profile_protected_slot": True,
"evidence_snippet_role": "profile_protected_slot",
}
)
# A fresh MemoryHit is built with the score floored at 4.8.
# NOTE(review): the rationale for 4.8 is not visible in this method.
protected_hits.append(
MemoryHit(
memory_id=hit.memory_id,
category=hit.category,
value=hit.value,
relation=hit.relation,
anchors=list(hit.anchors),
score=max(float(hit.score), 4.8),
source_kind=hit.source_kind,
slot_key=hit.slot_key,
state=hit.state,
turn_index=int(hit.turn_index),
metadata=metadata,
)
)
if protected_hits:
profile_protected_reinserted_count = len(protected_hits)
seen_profile_ids = {hit.memory_id for hit in protected_hits if hit.memory_id}
# Protected hits go first; their earlier copies are removed from the tail.
hits = [*protected_hits, *[hit for hit in hits if hit.memory_id not in seen_profile_ids]]
# Embedder-fusion output: event ids selected by the hybrid stage (blank ids
# dropped after _clean_text).
embedder_fusion_output_event_ids = [
_clean_text(event_id)
for event_id in list(hybrid_metadata.get("embedder_fusion_selected_event_ids", []) or [])
if _clean_text(event_id)
]
embedder_fusion_output_reordered = False
if embedder_fusion_output_event_ids:
before_order = _event_ids_from_hits(hits)
seen_output_memory_ids = {hit.memory_id for hit in hits}
for event_id in embedder_fusion_output_event_ids:
# Append one representative hit per selected event unless that memory is
# already present in the output.
event_hit = _representative_event_hit(_event_record_hits_from_graph(self.graph, event_id), query=query)
if event_hit is None or event_hit.memory_id in seen_output_memory_ids:
continue
metadata = dict(event_hit.metadata or {})
metadata.update(
{
"event_id": event_id,
"evidence_snippet_role": "embedder_fusion_event_representative",
"embedder_fusion_event_representative": True,
}
)
hits.append(
MemoryHit(
memory_id=event_hit.memory_id,
category=event_hit.category,
value=event_hit.value,
relation=event_hit.relation,
anchors=list(event_hit.anchors),
score=float(event_hit.score),
source_kind=event_hit.source_kind,
slot_key=event_hit.slot_key,
state=event_hit.state,
turn_index=int(event_hit.turn_index),
metadata=metadata,
)
)
seen_output_memory_ids.add(event_hit.memory_id)
# Final coverage pass: embedder-selected events are listed first, followed
# by the event order the hits had before this step.
hits = _coverage_preserving_final_hits(
hits,
selected_event_ids=_dedupe([*embedder_fusion_output_event_ids, *before_order]),
top_k=top_k,
)
embedder_fusion_output_reordered = before_order != _event_ids_from_hits(hits)
# Use the graph's own context_token_estimate when the payload has one. The
# fallback estimate over the final hits is computed eagerly either way,
# because it is passed as the default argument of .get().
retrieval_context_tokens = int(payload.get("context_token_estimate", _estimate_tokens_from_hits(hits)))
self._last_retrieval_context_tokens = retrieval_context_tokens
result = MemoryRetrieval(
concepts=list(payload.get("concepts", []) or []),
relations=list(payload.get("relations", []) or []),
hits=hits,
active_hits=active_hits,
history_hits=history_hits,
stale_hits=stale_hits,
overwrite_hits=overwrite_hits,
false_hits=false_hits,
retrieval_seconds=time.perf_counter() - start,
context_token_estimate=retrieval_context_tokens,
retrieval_context_token_estimate=retrieval_context_tokens,
# Metadata is merged in stage order; on a key collision the later entry in
# this literal wins.
metadata={
"query_id": payload.get("query_id", ""),
**dict(payload.get("metadata", {}) or {}),
**hybrid_metadata,
**memory_router_decision,
**current_subject_metadata,
**audit_anchor_metadata,
**identifier_metadata,
**depth_chain_metadata,
**profile_focused_metadata,
**topic_bucket_metadata,
**temporal_runtime_metadata,
**injection_planner_metadata,
**facet_query_pack_metadata,
**unit_coverage_metadata,
**multi_unit_chain_slot_metadata,
"profile_protected_reinserted_count": profile_protected_reinserted_count,
"embedder_fusion_output_event_ids": list(embedder_fusion_output_event_ids),
"embedder_fusion_output_reordered": bool(embedder_fusion_output_reordered),
},
)
# Persisting happens after `retrieval_seconds` was captured, so it is not
# included in the reported timing.
self._persist_graph()
return result
def export_dialog_graph(self, *, mode: str = "light") -> Dict[str, Any]:
    """Reload the graph and return its export for the requested ``mode``.

    The export is always requested with the same fixed ladder of snapshot
    points (1,000 up to 500,000); ``mode`` is forwarded unchanged.
    """
    self._reload_graph()
    checkpoints = (1000, 5000, 10000, 20000, 50000, 100000, 200000, 300000, 500000)
    exported = self.graph.export_graph(snapshot_points=checkpoints, mode=mode)
    return exported
def export_dialog_graph_mermaid(self) -> str:
    """Reload the graph, then return whatever its ``export_mermaid()`` yields."""
    self._reload_graph()
    mermaid_source = self.graph.export_mermaid()
    return mermaid_source
def register_answer_support(self, *, answer_id: str, memory_ids: List[str], query_id: str = "", answer_text: str = "") -> None:
    """Record on the graph which memories supported an answer.

    The graph is reloaded first, the four keyword arguments are forwarded
    unchanged to ``graph.register_answer_support``, and the graph is
    persisted afterwards.
    """
    self._reload_graph()
    support_fields = {
        "answer_id": answer_id,
        "memory_ids": memory_ids,
        "query_id": query_id,
        "answer_text": answer_text,
    }
    self.graph.register_answer_support(**support_fields)
    self._persist_graph()
def telemetry_snapshot(self) -> Dict[str, Any]:
    """Reload the graph and return its ``summary()`` as the telemetry snapshot."""
    self._reload_graph()
    snapshot = self.graph.summary()
    return snapshot
def stats(self) -> Dict[str, Any]:
    """Reload the graph and assemble adapter statistics via ``_state_stats``.

    Storage figures come from ``self._storage_breakdown()``, the retrieval
    token count is the value cached by the last retrieval, and every entry of
    ``self.graph.summary()`` is forwarded as an additional keyword argument.
    """
    self._reload_graph()
    breakdown = self._storage_breakdown()
    explicit_fields = {
        "storage_bytes": breakdown["storage_bytes"],
        "retrieval_context_tokens": self._last_retrieval_context_tokens,
        "total_state_tokens": breakdown["total_state_token_estimate"],
        "core_storage_bytes": breakdown["core_storage_bytes"],
        "audit_storage_bytes": breakdown["audit_storage_bytes"],
        "core_state_token_estimate": breakdown["core_state_token_estimate"],
        "audit_state_token_estimate": breakdown["audit_state_token_estimate"],
        "lightweight_stats": bool(self.lightweight_stats),
    }
    # Both mappings are unpacked in the call itself (not merged beforehand) so
    # that a summary key colliding with an explicit field still raises
    # TypeError instead of silently overriding it.
    return _state_stats(**explicit_fields, **self.graph.summary())
def storage_bytes(self) -> int:
    """Reload the graph and return the ``"storage_bytes"`` entry of the storage breakdown."""
    self._reload_graph()
    breakdown = self._storage_breakdown()
    return breakdown["storage_bytes"]
def build_prompt_context(self, query: str, *, top_k: int = 8) -> Dict[str, Any]:
    """Run a retrieval for ``query`` and bundle it with stats and size estimates.

    Args:
        query: Text to retrieve against.
        top_k: Forwarded to ``self.retrieve``.

    Returns:
        A dict with the retrieval payload (``retrieval.to_dict()``), the
        adapter stats, the prompt state summary (exposed under both
        ``"state"`` and ``"context_summary"``), and size estimates for the
        JSON form of ``{query, retrieval, context_summary}``:
        ``prompt_context_chars`` is its character length and
        ``prompt_context_tokens_est`` is ``_estimate_tokens`` applied to the
        same string. Truncation flags are read from the summary.
    """
    retrieval = self.retrieve(query, top_k=top_k)
    retrieval_payload = retrieval.to_dict()
    stats = self.stats()
    context_summary = _graph_prompt_state_summary(self.graph, retrieval)
    # Serialize once and reuse the string for both the character count and
    # the token estimate; the payload can be large and was previously dumped
    # twice with identical arguments.
    serialized_context = json.dumps(
        {
            "query": query,
            "retrieval": retrieval_payload,
            "context_summary": context_summary,
        },
        ensure_ascii=False,
    )
    prompt_context_chars = len(serialized_context)
    prompt_context_tokens_est = _estimate_tokens(serialized_context)
    return {
        "mode": "graph_session_memory_v2",
        "query": query,
        "retrieval": retrieval_payload,
        "stats": stats,
        "state": context_summary,
        "context_summary": context_summary,
        "prompt_context_chars": int(prompt_context_chars),
        "prompt_context_tokens_est": int(prompt_context_tokens_est),
        "context_truncated": bool(context_summary.get("context_truncated", False)),
        "truncation_reason": _clean_text(context_summary.get("truncation_reason", "")),
    }
class SummaryWindowMemoryAdapter(MemoryAdapter):
name = "summary_window_memory"
def __init__(self, *, window_size: int = 24, auto_extract: bool = False) -> None:
    """Set up an empty summary-window memory.

    ``window_size`` is coerced to ``int`` and raised to at least 4; it bounds
    the deque of recent turns. ``auto_extract`` is stored as a bool and later
    passed along when turn records are built.
    """
    self.extractor = SessionMemoryExtractor()
    bounded_window = max(4, int(window_size))
    self.window_size = bounded_window
    self.auto_extract = bool(auto_extract)
    # Counters start at zero; slot and turn containers start empty.
    self.turn_index = 0
    self._last_retrieval_context_tokens = 0
    self.active_slots: Dict[str, SessionMemoryRecordV2] = {}
    self.recent_turns: deque[Dict[str, Any]] = deque(maxlen=bounded_window)
def reset(self) -> None:
    """Return the adapter to its empty state.

    The slot mapping and the recent-turn deque are replaced by fresh objects
    (not cleared in place); the deque keeps the configured ``window_size`` as
    its maximum length. Both counters go back to zero.
    """
    self.turn_index = self._last_retrieval_context_tokens = 0
    self.active_slots = dict()
    self.recent_turns = deque((), maxlen=self.window_size)
def ingest_turn(
    self,
    user_text: str,
    assistant_text: str = "",
    *,
    answer_payload: Dict[str, Any] | None = None,
    extraction_result: Dict[str, Any] | None = None,
) -> None:
    """Advance the turn counter and fold one dialogue turn into memory.

    Records produced by ``_build_turn_records`` each take over their slot: a
    record already occupying the slot is marked ``"superseded"`` and its
    memory id is appended to the new record's ``supersedes`` list. The cleaned
    user and assistant texts are then appended to the recent-turn deque.
    """
    self.turn_index += 1
    current_turn = self.turn_index
    fresh_records = _build_turn_records(
        self.extractor,
        user_text=user_text,
        answer_payload=answer_payload,
        extraction_result=extraction_result,
        turn_index=current_turn,
        allow_auto_extract=self.auto_extract,
    )
    for fresh in fresh_records:
        slot = fresh.slot_key
        displaced = self.active_slots.get(slot)
        if displaced:
            # Link the new record to the one it replaces before overwriting.
            displaced.state = "superseded"
            fresh.supersedes.append(displaced.memory_id)
        fresh.state = "active"
        self.active_slots[slot] = fresh
    turn_entry = {
        "turn_index": current_turn,
        "text": _clean_text(user_text),
        "assistant": _clean_text(assistant_text),
    }
    self.recent_turns.append(turn_entry)
def retrieve(self, query: str, *, top_k: int = 6) -> MemoryRetrieval:
start = time.perf_counter()
query_tokens = set(_tokenize(query))
hints = set(infer_category_hints(query))
scored: List[tuple[float, SessionMemoryRecordV2]] = []
for record in self.active_slots.values():
token_set = record.token_set()
overlap = len(query_tokens & token_set) if query_tokens and token_set else 0
score = overlap / max(1, len(query_tokens | token_set)) if query_tokens and token_set else 0.0
if hints and record.category in hints:
score += 0.22
if record.slot_key.lower() in _normalize(query):
score += 0.12
score += min(0.08, record.turn_index * 0.0004)
score += 0.2
if score > 0:
scored.append((score, record))
scored.sort(key=lambda item: (item[0], item[1].turn_index), reverse=True)
selected_records = [record for _, record in scored[:top_k]]
hits = [
MemoryHit(
memory_id=record.memory_id,
category=record.category,
value=record.value,
relation=record.relation,
anchors=list(record.anchor_concepts),
score=float(score),
source_kind=record.source_kind,
slot_key=record.slot_key,
state=record.state,
turn_index=record.turn_index,
metadata={"window_size": self.window_size},
)
for score, record in scored[:top_k]
]
concepts = []
relations = []
for hit in hits:
concepts.append({"concept": hit.value, "type": hit.category, "source_kind": hit.source_kind})
for anchor in hit.anchors[:2]:
concepts.append({"concept": anchor, "type": "context", "source_kind": hit.source_kind})
relation = _relation_hit(hit, weight_bias=0.06)
if relation:
relations.append(relation)
retrieval_context_tokens = _estimate_tokens_from_hits(hits)
self._last_retrieval_context_tokens = retrieval_context_tokens
return MemoryRetrieval(
concepts=concepts,
relations=relations,
hits=hits,
active_hits=list(hits),
retrieval_seconds=time.perf_counter() - start,
context_token_estimate=retrieval_context_tokens,
retrieval_context_token_estimate=retrieval_context_tokens,
metadata={
"records": len(self.active_slots),
"window_size": self.window_size,
"recent_turns": len(self.recent_turns),
},
)
def stats(self) -> Dict[str, Any]:
payload = {
"active_slots": {slot: record.to_dict() for slot, record in self.active_slots.items()},
"recent_turns": list(self.recent_turns),
}
total_state_tokens = _estimate_tokens(json.dumps(payload, ensure_ascii=False))
return _state_stats(
storage_bytes=self.storage_bytes(),
retrieval_context_tokens=self._last_retrieval_context_tokens,
total_state_tokens=total_state_tokens,
records=len(self.active_slots),
active_slots=len(self.active_slots),
recent_turns=len(self.recent_turns),
)
def storage_bytes(self) -> int:
payload = {
"active_slots": {slot: record.to_dict() for slot, record in self.active_slots.items()},
"recent_turns": list(self.recent_turns),
}
return len(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
def build_prompt_context(self, query: str, *, top_k: int = 8) -> Dict[str, Any]:
return {
"mode": "summary_window_memory",
"query": query,
"retrieval": self.retrieve(query, top_k=top_k).to_dict(),
"stats": self.stats(),
"state": {
"active_slots": {slot: record.to_dict() for slot, record in self.active_slots.items()},
"recent_turns": list(self.recent_turns),
},
}
@dataclass(slots=True)
class _VectorRecord:
memory_id: str
category: str
value: str
relation: str
anchors: List[str]
tokens: List[str]
turn_index: int
slot_key: str = ""
active: bool = True
source_kind: str = "vector_memory"
metadata: Dict[str, Any] = field(default_factory=dict)
class VectorRAGMemoryAdapter(MemoryAdapter):
    """Adapter that keeps every ingested record in a flat list and ranks by token overlap.

    Records sharing a slot key with a newer record are kept but flagged
    inactive, so retrieval can return both current and superseded entries.
    """

    name = "vector_rag_memory"

    def __init__(self, *, auto_extract: bool = False) -> None:
        """Create an empty adapter.

        Args:
            auto_extract: Forwarded to ``_build_turn_records`` as
                ``allow_auto_extract`` on every ingested turn.
        """
        self.extractor = SessionMemoryExtractor()
        self.records: List[_VectorRecord] = []
        self.turn_index = 0
        self.auto_extract = bool(auto_extract)
        self._last_retrieval_context_tokens = 0

    def reset(self) -> None:
        """Forget all records and restart the turn counter at zero."""
        self.records = []
        self.turn_index = 0
        self._last_retrieval_context_tokens = 0

    def ingest_turn(
        self,
        user_text: str,
        assistant_text: str = "",
        *,
        answer_payload: Dict[str, Any] | None = None,
        extraction_result: Dict[str, Any] | None = None,
    ) -> None:
        """Convert one turn into records and append them to the store.

        Args:
            user_text: User utterance for this turn.
            assistant_text: Accepted for interface compatibility; unused.
            answer_payload: Passed through to ``_build_turn_records``.
            extraction_result: Passed through to ``_build_turn_records``.
        """
        _ = assistant_text
        self.turn_index += 1
        records = _build_turn_records(
            self.extractor,
            user_text=user_text,
            answer_payload=answer_payload,
            extraction_result=extraction_result,
            turn_index=self.turn_index,
            allow_auto_extract=self.auto_extract,
        )
        for record in records:
            if record.slot_key:
                # Deactivate every stored record that currently owns this slot.
                for previous in self.records:
                    if previous.slot_key == record.slot_key and previous.active:
                        previous.active = False
            self.records.append(
                _VectorRecord(
                    memory_id=record.memory_id,
                    category=record.category,
                    value=record.value,
                    relation=record.relation,
                    anchors=list(record.anchor_concepts),
                    tokens=list(record.token_set()),
                    turn_index=record.turn_index,
                    slot_key=record.slot_key,
                    active=record.state == "active",
                    source_kind=record.source_kind,
                    metadata=dict(record.metadata),
                )
            )

    def retrieve(self, query: str, *, top_k: int = 6) -> MemoryRetrieval:
        """Rank all stored records against ``query`` and return the best ones.

        The score of a record is the Jaccard overlap between query tokens and
        record tokens, plus 0.18 when the record category is among the hints
        inferred from the query, plus 0.1 when its non-empty lower-cased slot
        key occurs in the normalised query, plus a turn-index term capped at
        0.12, plus 0.18 for active records or minus 0.25 for inactive ones.
        Only records with a positive score are kept; when none qualifies, the
        last ``top_k`` stored records are used with small fixed scores.

        Args:
            query: Free-text query.
            top_k: Maximum number of hits to return.

        Returns:
            A ``MemoryRetrieval`` with hits split into ``active_hits`` and
            ``history_hits`` according to each record's active flag.
        """
        start = time.perf_counter()
        query_tokens = set(_tokenize(query))
        hints = set(infer_category_hints(query))
        # Loop-invariant: normalise the query once instead of once per record.
        normalized_query = _normalize(query)
        scored: List[tuple[float, _VectorRecord]] = []
        for record in self.records:
            token_set = set(record.tokens)
            if query_tokens and token_set:
                overlap = len(query_tokens & token_set)
                score = overlap / max(1, len(query_tokens | token_set))
            else:
                score = 0.0
            if hints and record.category in hints:
                score += 0.18
            if record.slot_key and record.slot_key.lower() in normalized_query:
                score += 0.1
            score += min(0.12, record.turn_index * 0.0004)
            score += 0.18 if record.active else -0.25
            if score > 0:
                scored.append((score, record))
        if not scored:
            # Nothing scored positively: fall back to the most recently stored records.
            for record in self.records[-top_k:]:
                scored.append((0.05 + (0.15 if record.active else 0.0), record))
        scored.sort(key=lambda item: (item[0], item[1].active, item[1].turn_index), reverse=True)
        hits = [
            MemoryHit(
                memory_id=record.memory_id,
                category=record.category,
                value=record.value,
                relation=record.relation,
                anchors=list(record.anchors),
                score=float(score),
                source_kind=record.source_kind,
                slot_key=record.slot_key,
                state="active" if record.active else "superseded",
                turn_index=record.turn_index,
                metadata=dict(record.metadata),
            )
            for score, record in scored[:top_k]
        ]
        concepts = []
        relations = []
        for hit in hits:
            concepts.append({"concept": hit.value, "type": hit.category, "source_kind": hit.source_kind})
            # At most two anchors per hit are surfaced as context concepts.
            for anchor in hit.anchors[:2]:
                concepts.append({"concept": anchor, "type": "context", "source_kind": hit.source_kind})
            relation = _relation_hit(hit)
            if relation:
                relations.append(relation)
        retrieval_context_tokens = _estimate_tokens_from_hits(hits)
        # Remembered so that a later stats() call can report it.
        self._last_retrieval_context_tokens = retrieval_context_tokens
        return MemoryRetrieval(
            concepts=concepts,
            relations=relations,
            hits=hits,
            active_hits=[hit for hit in hits if hit.state == "active"],
            history_hits=[hit for hit in hits if hit.state != "active"],
            retrieval_seconds=time.perf_counter() - start,
            context_token_estimate=retrieval_context_tokens,
            retrieval_context_token_estimate=retrieval_context_tokens,
            metadata={
                "records": len(self.records),
                "active_records": sum(1 for item in self.records if item.active),
            },
        )

    def _storage_payload(self) -> List[Dict[str, Any]]:
        """Return the per-record dicts used for size accounting.

        Tokens, turn index, source kind and metadata are not part of this
        payload.
        """
        return [
            {
                "memory_id": record.memory_id,
                "category": record.category,
                "value": record.value,
                "relation": record.relation,
                "anchors": record.anchors,
                "slot_key": record.slot_key,
                "active": record.active,
            }
            for record in self.records
        ]

    def stats(self) -> Dict[str, Any]:
        """Return size statistics for the current store via ``_state_stats``."""
        total_state_tokens = _estimate_tokens(json.dumps(self._storage_payload(), ensure_ascii=False))
        return _state_stats(
            storage_bytes=self.storage_bytes(),
            retrieval_context_tokens=self._last_retrieval_context_tokens,
            total_state_tokens=total_state_tokens,
            records=len(self.records),
            active_records=sum(1 for item in self.records if item.active),
        )

    def storage_bytes(self) -> int:
        """Return the UTF-8 byte length of the JSON-serialised storage payload."""
        return len(json.dumps(self._storage_payload(), ensure_ascii=False).encode("utf-8"))

    def build_prompt_context(self, query: str, *, top_k: int = 8) -> Dict[str, Any]:
        """Bundle retrieval results, stats and a snapshot of every record.

        The state snapshot is taken first; retrieval then runs before
        ``stats()`` so the reported retrieval context token count belongs to
        this query.
        """
        state = [
            {
                "memory_id": record.memory_id,
                "category": record.category,
                "value": record.value,
                "relation": record.relation,
                "anchors": list(record.anchors),
                "slot_key": record.slot_key,
                "active": bool(record.active),
                "turn_index": int(record.turn_index),
            }
            for record in self.records
        ]
        retrieval = self.retrieve(query, top_k=top_k).to_dict()
        stats = self.stats()
        return {
            "mode": "vector_rag_memory",
            "query": query,
            "retrieval": retrieval,
            "stats": stats,
            "state": state,
        }
class FullHistoryMemoryAdapter(MemoryAdapter):
    """Adapter that stores every turn verbatim and ranks turns by token overlap.

    The reported retrieval context size always covers the whole stored
    history, independent of how many hits a query returns.
    """

    name = "full_history_memory"

    def __init__(self) -> None:
        """Create an adapter with an empty history."""
        self.turns: List[Dict[str, str]] = []
        self._last_retrieval_context_tokens = 0

    def reset(self) -> None:
        """Forget all stored turns."""
        self.turns = []
        self._last_retrieval_context_tokens = 0

    def ingest_turn(
        self,
        user_text: str,
        assistant_text: str = "",
        *,
        answer_payload: Dict[str, Any] | None = None,
        extraction_result: Dict[str, Any] | None = None,
    ) -> None:
        """Append the cleaned user and assistant texts as one turn.

        Args:
            user_text: User utterance for this turn.
            assistant_text: Assistant reply for this turn.
            answer_payload: Accepted for interface compatibility; unused.
            extraction_result: Accepted for interface compatibility; unused.
        """
        _ = answer_payload, extraction_result
        self.turns.append({"user": _clean_text(user_text), "assistant": _clean_text(assistant_text)})

    def retrieve(self, query: str, *, top_k: int = 6) -> MemoryRetrieval:
        """Return up to ``top_k`` stored turns ranked against ``query``.

        A turn's score is the Jaccard overlap between the query tokens and the
        tokens of its combined user and assistant text (0.0 for a query with
        no tokens). Ties are broken in favour of later turns. Turns without
        user text, or without any tokens, never become hits.

        Args:
            query: Free-text query.
            top_k: Maximum number of hits to return.

        Returns:
            A ``MemoryRetrieval`` whose ``hits`` and ``active_hits`` hold the
            same turns; ``memory_id`` uses the 0-based turn position while
            ``turn_index`` is 1-based.
        """
        start = time.perf_counter()
        query_tokens = set(_tokenize(query))
        scored: List[tuple[float, Dict[str, str], int]] = []
        for index, turn in enumerate(self.turns):
            user_text = turn.get("user", "")
            # Exclude turns that cannot become a hit *before* ranking, so they
            # do not consume top_k slots and shrink the result.
            if not user_text:
                continue
            combined = f"{user_text} {turn.get('assistant', '')}"
            token_set = set(_tokenize(combined))
            if not token_set:
                continue
            if query_tokens:
                overlap = len(query_tokens & token_set)
                score = overlap / max(1, len(query_tokens | token_set))
            else:
                score = 0.0
            scored.append((score, turn, index))
        scored.sort(key=lambda item: (item[0], item[2]), reverse=True)
        hits = [
            MemoryHit(
                memory_id=f"turn:{index}",
                category="history_turn",
                value=turn.get("user", ""),
                relation="conversation_context",
                anchors=[turn.get("assistant", "")] if turn.get("assistant") else [],
                score=float(score),
                source_kind="full_history",
                slot_key=f"turn.{index}",
                state="active",
                turn_index=index + 1,
            )
            for score, turn, index in scored[:top_k]
        ]
        # Context cost is estimated from the entire history, not just the hits.
        retrieval_context_tokens = _estimate_tokens(self._serialized_turns())
        # Remembered so that a later stats() call can report it.
        self._last_retrieval_context_tokens = retrieval_context_tokens
        return MemoryRetrieval(
            hits=hits,
            active_hits=list(hits),
            retrieval_seconds=time.perf_counter() - start,
            context_token_estimate=retrieval_context_tokens,
            retrieval_context_token_estimate=retrieval_context_tokens,
            metadata={"records": len(self.turns)},
        )

    def _serialized_turns(self) -> str:
        """Return the stored turns as a JSON string (non-ASCII kept as-is)."""
        return json.dumps(self.turns, ensure_ascii=False)

    def stats(self) -> Dict[str, Any]:
        """Return size statistics for the stored history via ``_state_stats``."""
        total_state_tokens = _estimate_tokens(self._serialized_turns())
        return _state_stats(
            storage_bytes=self.storage_bytes(),
            retrieval_context_tokens=self._last_retrieval_context_tokens,
            total_state_tokens=total_state_tokens,
            records=len(self.turns),
        )

    def storage_bytes(self) -> int:
        """Return the UTF-8 byte length of the JSON-serialised history."""
        return len(self._serialized_turns().encode("utf-8"))

    def build_prompt_context(self, query: str, *, top_k: int = 8) -> Dict[str, Any]:
        """Bundle retrieval results, stats and a copy of the turn list.

        Retrieval runs before ``stats()`` so that the reported retrieval
        context token count belongs to this query.
        """
        retrieval = self.retrieve(query, top_k=top_k).to_dict()
        stats = self.stats()
        return {
            "mode": "full_history_memory",
            "query": query,
            "retrieval": retrieval,
            "stats": stats,
            "state": {
                "turns": list(self.turns),
            },
        }