GitHub Actions committed
Commit acfcc03 · 1 Parent(s): 5941cd9

Deploy 91b3f47

app/api/chat.py CHANGED
@@ -2,6 +2,7 @@ import asyncio
 import json
 import re
 import time
+from typing import Annotated
 from fastapi import APIRouter, Request, Depends
 from fastapi.responses import StreamingResponse
 
@@ -11,6 +12,7 @@ from app.security.rate_limiter import chat_rate_limit
 from app.security.jwt_auth import verify_jwt
 
 router = APIRouter()
+_BACKGROUND_TASKS: set[asyncio.Task[object]] = set()
 
 # Keep-alive interval for SSE when upstream nodes are still working.
 # Prevents edge/proxy idle timeouts on long retrieval/generation turns.
@@ -22,14 +24,14 @@ _EXPANSION_TIMEOUT_SECONDS: float = 0.60
 
 # Phrases a visitor uses when telling the bot it gave a wrong answer.
 # Matched on the lowercased raw message before any LLM call — O(1), zero cost.
-_CRITICISM_SIGNALS: frozenset[str] = frozenset({
+_CRITICISM_SIGNALS: tuple[str, ...] = (
     "that's wrong", "thats wrong", "you're wrong", "youre wrong",
     "not right", "wrong answer", "you got it wrong", "that is wrong",
    "that's incorrect", "you're incorrect", "thats incorrect", "youre incorrect",
     "fix that", "fix your answer", "actually no", "no that's", "no thats",
     "that was wrong", "your answer was wrong", "wrong information",
     "incorrect information", "that's not right", "thats not right",
-})
+)
 
 
 def _is_criticism(message: str) -> bool:
@@ -52,11 +54,11 @@ def _filter_sources_by_citations(answer: str, sources: list) -> list:
     if not cited_nums:
         return sources
 
-    max_cited = max(cited_nums)
-    if max_cited > len(sources):
+    valid_cited_nums = {num for num in cited_nums if 1 <= num <= len(sources)}
+    if not valid_cited_nums:
         return sources
 
-    return [s for i, s in enumerate(sources, start=1) if i in cited_nums]
+    return [s for i, s in enumerate(sources, start=1) if i in valid_cited_nums]
 
 
 async def _generate_follow_ups(
@@ -124,7 +126,6 @@ async def _update_summary_async(
     previous_summary: str | None,
     query: str,
     answer: str,
-    processing_api_key: str | None,
 ) -> None:
     """
     Triggered post-response to update the rolling conversation summary.
@@ -135,7 +136,6 @@ async def _update_summary_async(
         previous_summary=previous_summary or "",
         new_turn_q=query,
         new_turn_a=answer[:600],  # cap answer chars sent to Gemini
-        processing_api_key=processing_api_key,
     )
     if new_summary:
         conv_store.set_summary(session_id, new_summary)
@@ -143,12 +143,17 @@ async def _update_summary_async(
         pass
 
 
+def _track_background_task(task: asyncio.Task[object]) -> None:
+    _BACKGROUND_TASKS.add(task)
+    task.add_done_callback(_BACKGROUND_TASKS.discard)
+
+
 @router.post("")
 @chat_rate_limit()
 async def chat_endpoint(
     request: Request,
     request_data: ChatRequest,
-    token_payload: dict = Depends(verify_jwt),
+    token_payload: Annotated[dict, Depends(verify_jwt)],
 ) -> StreamingResponse:
     """Stream RAG answer as typed SSE events.
 
@@ -344,19 +349,22 @@ async def chat_endpoint(
         # ── Follow-up questions ────────────────────────────────────────────
         # Generated after the done event so it never delays answer delivery.
         if final_answer and not await request.is_disconnected():
-            follow_ups = await _generate_follow_ups(
-                request_data.message, final_answer, final_sources, llm_client
+            follow_up_task: asyncio.Task[list[str]] = asyncio.create_task(
+                _generate_follow_ups(request_data.message, final_answer, final_sources, llm_client)
             )
+            _track_background_task(follow_up_task)
+            try:
+                follow_ups = await asyncio.wait_for(follow_up_task, timeout=0.25)
+            except Exception:
+                follow_up_task.cancel()
+                follow_ups = []
             if follow_ups:
                 yield f"event: follow_ups\ndata: {json.dumps({'questions': follow_ups})}\n\n"
 
         # Stage 2: update rolling summary asynchronously — fired after the
         # response is fully delivered so it adds zero latency to the turn.
         if final_answer and gemini_client and gemini_client.is_configured:
-            processing_key = getattr(
-                request.app.state, "gemini_processing_api_key", None
-            )
-            asyncio.create_task(
+            summary_task: asyncio.Task[None] = asyncio.create_task(
                 _update_summary_async(
                     conv_store=conv_store,
                     gemini_client=gemini_client,
@@ -364,9 +372,9 @@ async def chat_endpoint(
                     previous_summary=conversation_summary,
                     query=request_data.message,
                     answer=final_answer,
-                    processing_api_key=processing_key,
                 )
             )
+            _track_background_task(summary_task)
 
     except Exception as exc:
         yield f"data: {json.dumps({'error': str(exc) or 'Generation failed'})}\n\n"
app/main.py CHANGED
@@ -68,6 +68,24 @@ def _normalize_qdrant_url(url: str) -> str:
     return raw
 
 
+def _setup_dagshub_tracking(settings) -> None:
+    if not settings.DAGSHUB_TOKEN:
+        return
+
+    try:
+        import dagshub
+
+        dagshub.init(
+            repo_owner=settings.DAGSHUB_REPO.split("/")[0],
+            repo_name=settings.DAGSHUB_REPO.split("/")[1],
+            mlflow=True,
+            dvc=False,
+        )
+        logger.info("DagsHub MLflow tracking enabled | repo=%s", settings.DAGSHUB_REPO)
+    except Exception as exc:
+        logger.warning("DagsHub MLflow tracking disabled: %s", exc)
+
+
 async def _qdrant_keepalive_loop(
     qdrant: QdrantClient,
     interval_seconds: int,
@@ -111,15 +129,7 @@ async def lifespan(app: FastAPI):
     # DagsHub/MLflow experiment tracking — optional, only active when token is set.
     # In prod with DAGSHUB_TOKEN set, experiments are tracked at dagshub.com.
     # In local or test environments, MLflow is a no-op.
-    if settings.DAGSHUB_TOKEN:
-        import dagshub
-        dagshub.init(
-            repo_owner=settings.DAGSHUB_REPO.split("/")[0],
-            repo_name=settings.DAGSHUB_REPO.split("/")[1],
-            mlflow=True,
-            dvc=False,
-        )
-        logger.info("DagsHub MLflow tracking enabled | repo=%s", settings.DAGSHUB_REPO)
+    _setup_dagshub_tracking(settings)
 
     embedder = Embedder(remote_url=settings.EMBEDDER_URL, environment=settings.ENVIRONMENT)
     reranker = Reranker(remote_url=settings.RERANKER_URL, environment=settings.ENVIRONMENT)
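
Wrapping the DagsHub initialisation in try/except means a missing or misconfigured tracking dependency now degrades to a logged warning instead of failing startup. A generic sketch of the same guard, assuming a hypothetical some_tracking_lib module and token setting:

import importlib
import logging

logger = logging.getLogger(__name__)

def setup_optional_tracking(token: str | None) -> None:
    # Optional telemetry must never block serving traffic: no token means
    # skip quietly; any import or init failure downgrades to a warning.
    if not token:
        return
    try:
        tracking = importlib.import_module("some_tracking_lib")  # hypothetical module name
        tracking.init(token=token)  # hypothetical init signature
        logger.info("tracking enabled")
    except Exception as exc:
        logger.warning("tracking disabled: %s", exc)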
app/pipeline/graph.py CHANGED
@@ -113,7 +113,7 @@ def build_pipeline(services: dict) -> CompiledStateGraph:
     graph.add_node("guard", make_guard_node(services["classifier"]))
     graph.add_node("enumerate_query", make_enumerate_query_node(services["vector_store"]))
     graph.add_node("cache", make_cache_node(services["cache"], services["embedder"]))
-    graph.add_node("gemini_fast", make_gemini_fast_node(services["gemini"]))
+    graph.add_node("gemini_fast", make_gemini_fast_node())
     graph.add_node("retrieve", make_retrieve_node(
         services["vector_store"],
         services["embedder"],
app/pipeline/nodes/cache.py CHANGED
@@ -29,7 +29,7 @@ from app.services.semantic_cache import SemanticCache
 # prior turn, and excluding them would bypass cache on most portfolio queries.
 _REFERENCE_TOKENS: frozenset[str] = frozenset({
     "that", "it", "its", "they", "their", "those",
-    "this", "these", "them", "there", "then",
+    "this", "these", "them",
 })
 
 
@@ -67,9 +67,20 @@ def make_cache_node(cache: SemanticCache, embedder) -> Callable[[PipelineState],
         cached = await cache.get(query_embedding)
         if cached:
             writer({"type": "status", "label": "Found a recent answer, loading..."})
-            # Emit the full cached answer as a single token event: the cache
-            # returns complete text, not a stream, so one event is correct.
-            writer({"type": "token", "text": cached})
+            # Stream cached answers in short chunks so the SSE contract stays
+            # consistent with non-cache paths.
+            words = cached.split()
+            chunk: list[str] = []
+            chunk_len = 0
+            for word in words:
+                chunk.append(word)
+                chunk_len += len(word) + 1
+                if chunk_len >= 80:
+                    writer({"type": "token", "text": " ".join(chunk)})
+                    chunk = []
+                    chunk_len = 0
+            if chunk:
+                writer({"type": "token", "text": " ".join(chunk)})
             return {
                 "answer": cached,
                 "cached": True,
app/pipeline/nodes/gemini_fast.py CHANGED
@@ -5,6 +5,7 @@ and citation-capable. No parametric Gemini answer generation is used here.
 """
 from __future__ import annotations
 
+import asyncio
 import logging
 import re
 from typing import Any
@@ -12,7 +13,6 @@ from typing import Any
 from langgraph.config import get_stream_writer
 
 from app.models.pipeline import PipelineState
-from app.services.gemini_client import GeminiClient
 
 logger = logging.getLogger(__name__)
 
@@ -43,14 +43,17 @@ _SMALL_TALK_ANSWER = (
     "and I'll find the details for you."
 )
 
-def make_gemini_fast_node(gemini_client: GeminiClient) -> Any:
+def make_gemini_fast_node(gemini_client: Any | None = None) -> Any:
     """
     Returns a LangGraph-compatible async node function.
-    ``gemini_client`` is injected at startup from app.state.gemini_client.
     """
+    configured = bool(gemini_client and getattr(gemini_client, "is_configured", False))
 
     async def gemini_fast(state: PipelineState) -> dict:
+        await asyncio.sleep(0)
         writer = get_stream_writer()
+        if configured:
+            logger.debug("Gemini client is configured, but fast-path remains deterministic.")
         writer({"type": "status", "label": "Thinking about your question directly..."})
 
         query = state["query"]
app/pipeline/nodes/generate.py CHANGED
@@ -336,9 +336,9 @@ def _build_low_trust_fallback(query: str, source_refs: list[SourceRef]) -> str:
 
     if _VERSION_PARITY_RE.search(query):
         return (
-            "The indexed sources include related details [1], but they do not explicitly "
+            "The indexed sources include related details, but they do not explicitly "
             "confirm whether the GitHub code and live demo are currently in sync, so version parity "
-            "cannot be verified from indexed content alone [1]."
+            "cannot be verified from indexed content alone."
         )
 
     return (
@@ -369,14 +369,14 @@ def make_generate_node(llm_client: LLMClient, gemini_client=None) -> Callable[[P
     # needed, just reliable numbered-list formatting with one citation per item.
     if state.get("is_enumeration_query") and reranked_chunks:
         writer({"type": "status", "label": "Formatting complete list..."})
-        context_parts: list[str] = []
+        enum_context_parts: list[str] = []
         source_refs: list[SourceRef] = []
         for i, chunk in enumerate(reranked_chunks, start=1):
             meta = chunk["metadata"]
             header = f"[{i}] {meta.get('source_title', 'Item')}"
             if meta.get("source_url"):
                 header += f" ({meta['source_url']})"
-            context_parts.append(f"{header}\n{chunk['text'][:300]}")
+            enum_context_parts.append(f"{header}\n{chunk['text'][:300]}")
             source_refs.append(
                 SourceRef(
                     title=meta.get("source_title", ""),
@@ -385,7 +385,7 @@ def make_generate_node(llm_client: LLMClient, gemini_client=None) -> Callable[[P
                     source_type=meta.get("source_type", ""),
                 )
             )
-        context_block_enum = "\n\n".join(context_parts)
+        context_block_enum = "\n\n".join(enum_context_parts)
         prompt_enum = f"Items fetched from database:\n{context_block_enum}\n\nVisitor request: {query}"
         stream = llm_client.complete_with_complexity(
             prompt=prompt_enum,
app/pipeline/nodes/guard.py CHANGED
@@ -37,7 +37,7 @@ def make_guard_node(classifier: GuardClassifier) -> Callable[[PipelineState], di
         }
 
     # 2. Classify (scope evaluation).
-    is_safe, score = classifier.is_in_scope(clean_query)
+    is_safe, _ = classifier.is_in_scope(clean_query)
 
     if not is_safe:
         return {
app/pipeline/nodes/log_eval.py CHANGED
@@ -13,6 +13,56 @@ logger = logging.getLogger(__name__)
 _PENDING_TASKS: set[asyncio.Task[None]] = set()
 
 
+def _ensure_interactions_schema(db_path: str) -> None:
+    db_dir = os.path.dirname(db_path)
+    if db_dir:
+        os.makedirs(db_dir, exist_ok=True)
+
+    with sqlite3.connect(db_path) as conn:
+        conn.execute(
+            """
+            CREATE TABLE IF NOT EXISTS interactions (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                timestamp TEXT,
+                session_id TEXT,
+                query TEXT,
+                answer TEXT,
+                chunks_used TEXT,
+                rerank_scores TEXT,
+                reranked_chunks_json TEXT,
+                latency_ms INTEGER,
+                cached BOOLEAN,
+                feedback INTEGER DEFAULT 0,
+                path TEXT DEFAULT 'rag',
+                critic_groundedness INTEGER,
+                critic_completeness INTEGER,
+                critic_specificity INTEGER,
+                critic_quality TEXT,
+                is_enumeration_query BOOLEAN DEFAULT 0,
+                source_hit_proxy INTEGER DEFAULT 0
+            )
+            """
+        )
+        for col, definition in [
+            ("reranked_chunks_json", "TEXT DEFAULT '[]'"),
+            ("feedback", "INTEGER DEFAULT 0"),
+            ("session_id", "TEXT DEFAULT ''"),
+            ("path", "TEXT DEFAULT 'rag'"),
+            ("critic_groundedness", "INTEGER"),
+            ("critic_completeness", "INTEGER"),
+            ("critic_specificity", "INTEGER"),
+            ("critic_quality", "TEXT"),
+            ("is_enumeration_query", "BOOLEAN DEFAULT 0"),
+            ("sibling_expansion_count", "INTEGER"),
+            ("focused_source_type", "TEXT"),
+            ("source_hit_proxy", "INTEGER DEFAULT 0"),
+        ]:
+            try:
+                conn.execute(f"ALTER TABLE interactions ADD COLUMN {col} {definition}")
+            except sqlite3.OperationalError:
+                pass
+
+
 def _source_hit_proxy(state: PipelineState) -> int:
     reranked_chunks = state.get("reranked_chunks", [])
     chunk_count = len(reranked_chunks)
@@ -35,11 +85,9 @@ def make_log_eval_node(db_path: str) -> Callable[[PipelineState], dict]:
     only RAG interactions have chunk associations for valid training pairs.
     """
 
-    def _write_to_sqlite(state: PipelineState) -> int:
-        db_dir = os.path.dirname(db_path)
-        if db_dir:
-            os.makedirs(db_dir, exist_ok=True)
+    _ensure_interactions_schema(db_path)
 
+    def _write_to_sqlite(state: PipelineState) -> int:
         chunks_used = json.dumps(
             [c["metadata"]["doc_id"] for c in state.get("reranked_chunks", [])]
         )
@@ -61,54 +109,6 @@ def make_log_eval_node(db_path: str) -> Callable[[PipelineState], dict]:
         source_hit_proxy = _source_hit_proxy(state)
 
         with sqlite3.connect(db_path) as conn:
-            conn.execute(
-                """
-                CREATE TABLE IF NOT EXISTS interactions (
-                    id INTEGER PRIMARY KEY AUTOINCREMENT,
-                    timestamp TEXT,
-                    session_id TEXT,
-                    query TEXT,
-                    answer TEXT,
-                    chunks_used TEXT,
-                    rerank_scores TEXT,
-                    reranked_chunks_json TEXT,
-                    latency_ms INTEGER,
-                    cached BOOLEAN,
-                    feedback INTEGER DEFAULT 0,
-                    path TEXT DEFAULT 'rag',
-                    critic_groundedness INTEGER,
-                    critic_completeness INTEGER,
-                    critic_specificity INTEGER,
-                    critic_quality TEXT,
-                    is_enumeration_query BOOLEAN DEFAULT 0,
-                    source_hit_proxy INTEGER DEFAULT 0
-                )
-                """
-            )
-            # Idempotent schema upgrades for deployments that pre-date these columns.
-            for col, definition in [
-                ("reranked_chunks_json", "TEXT DEFAULT '[]'"),
-                ("feedback", "INTEGER DEFAULT 0"),
-                ("session_id", "TEXT DEFAULT ''"),
-                # path column: old rows default to "rag" — they were all RAG interactions.
-                ("path", "TEXT DEFAULT 'rag'"),
-                # Stage 3 SELF-RAG critic scores
-                ("critic_groundedness", "INTEGER"),
-                ("critic_completeness", "INTEGER"),
-                ("critic_specificity", "INTEGER"),
-                ("critic_quality", "TEXT"),
-                # Fix 1: enumeration classifier flag
-                ("is_enumeration_query", "BOOLEAN DEFAULT 0"),
-                # RC-13: retrieval diagnostics
-                ("sibling_expansion_count", "INTEGER"),
-                ("focused_source_type", "TEXT"),
-                ("source_hit_proxy", "INTEGER DEFAULT 0"),
-            ]:
-                try:
-                    conn.execute(f"ALTER TABLE interactions ADD COLUMN {col} {definition}")
-                except sqlite3.OperationalError:
-                    pass  # Column already exists.
-
             cursor = conn.execute(
                 """
                 INSERT INTO interactions
@@ -142,6 +142,7 @@ def make_log_eval_node(db_path: str) -> Callable[[PipelineState], dict]:
         return cursor.lastrowid  # type: ignore[return-value]
 
     def _build_loki_record(state: PipelineState) -> dict:
+        reranked_chunks = state.get("reranked_chunks", [])
         return {
             "timestamp": datetime.now(tz=timezone.utc).isoformat(),
             "session_id": state.get("session_id", ""),
@@ -164,7 +165,16 @@ def make_log_eval_node(db_path: str) -> Callable[[PipelineState], dict]:
             "is_followup": state.get("is_followup", False),
             "is_audio_mode": state.get("is_audio_mode", False),
             "follow_ups": state.get("follow_ups", []),
-            "reranked_chunks": state.get("reranked_chunks", []),
+            "chunk_count": len(reranked_chunks),
+            "top_chunk_doc_id": reranked_chunks[0]["metadata"].get("doc_id", "") if reranked_chunks else "",
+            "source_types_used": sorted(
+                {
+                    c["metadata"].get("source_type", "")
+                    for c in reranked_chunks
+                    if c["metadata"].get("source_type", "")
+                }
+            ),
+            "rerank_scores": [c["metadata"].get("rerank_score", 0.0) for c in reranked_chunks],
             "source_hit_proxy": _source_hit_proxy(state),
         }
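
Hoisting the DDL into _ensure_interactions_schema runs table creation and the column upgrades once, when the node is built, instead of on every logged interaction. The ALTER TABLE guarded by sqlite3.OperationalError is the usual SQLite idiom because SQLite has no ADD COLUMN IF NOT EXISTS. The guard in isolation:

import sqlite3

def ensure_column(conn: sqlite3.Connection, table: str, col: str, definition: str) -> None:
    # Attempt the ALTER and swallow the duplicate-column error: SQLite has
    # no "ADD COLUMN IF NOT EXISTS", so this is the idempotent form.
    try:
        conn.execute(f"ALTER TABLE {table} ADD COLUMN {col} {definition}")
    except sqlite3.OperationalError:
        pass  # column already exists

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE interactions (id INTEGER PRIMARY KEY)")
ensure_column(conn, "interactions", "path", "TEXT DEFAULT 'rag'")
ensure_column(conn, "interactions", "path", "TEXT DEFAULT 'rag'")  # second call is a no-op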
 
app/pipeline/nodes/retrieve.py CHANGED
@@ -7,6 +7,7 @@ from langgraph.config import get_stream_writer
 
 logger = logging.getLogger(__name__)
 
+from app.core.topic import _STOPWORDS
 from app.models.pipeline import PipelineState, Chunk
 from app.services.vector_store import VectorStore
 from app.services.embedder import Embedder
@@ -97,29 +98,6 @@ _CAPABILITY_QUERY_HINTS: frozenset[str] = frozenset(
     }
 )
 
-_BIOGRAPHY_QUERY_HINTS: frozenset[str] = frozenset(
-    {
-        "work",
-        "experience",
-        "employment",
-        "career",
-        "internship",
-        "internships",
-        "education",
-        "degree",
-        "university",
-        "background",
-        "resume",
-        "cv",
-        "company",
-        "companies",
-        "role",
-        "roles",
-    }
-)
-
-_BIO_SOURCE_TYPES: frozenset[str] = frozenset({"resume", "cv", "bio"})
-
 _NORMALISATION_STOPWORDS: frozenset[str] = frozenset(
     {
         "tell",
@@ -200,34 +178,58 @@ _FOCUS_VOCAB: frozenset[str] = frozenset(
     }
 )
 
+_FOCUS_VOCAB_BY_INITIAL: dict[str, tuple[str, ...]] = {}
+for candidate in _FOCUS_VOCAB:
+    initial = candidate[0]
+    existing = _FOCUS_VOCAB_BY_INITIAL.get(initial)
+    if existing is None:
+        _FOCUS_VOCAB_BY_INITIAL[initial] = (candidate,)
+    else:
+        _FOCUS_VOCAB_BY_INITIAL[initial] = existing + (candidate,)
+
 
-def _edit_distance(a: str, b: str) -> int:
-    la, lb = len(a), len(b)
-    dp = list(range(lb + 1))
-    for i in range(1, la + 1):
-        prev = dp[0]
-        dp[0] = i
-        for j in range(1, lb + 1):
-            cur = dp[j]
-            cost = 0 if a[i - 1] == b[j - 1] else 1
-            dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost)
-            prev = cur
-    return dp[lb]
+def _bounded_edit_distance(a: str, b: str, max_distance: int = 2) -> int:
+    if abs(len(a) - len(b)) > max_distance:
+        return max_distance + 1
+
+    previous_row = list(range(len(b) + 1))
+    for i, ca in enumerate(a, start=1):
+        current_row = [i]
+        row_min = current_row[0]
+        for j, cb in enumerate(b, start=1):
+            val = min(
+                current_row[j - 1] + 1,
+                previous_row[j] + 1,
+                previous_row[j - 1] + (0 if ca == cb else 1),
+            )
+            current_row.append(val)
+            if val < row_min:
+                row_min = val
+        if row_min > max_distance:
+            return max_distance + 1
+        previous_row = current_row
+
+    return previous_row[-1]
 
 
 def _best_focus_replacement(token: str) -> str | None:
-    best = None
-    best_score = 99
-    for candidate in _FOCUS_VOCAB:
-        if token[0] != candidate[0]:
-            continue
+    candidates = _FOCUS_VOCAB_BY_INITIAL.get(token[0], ())
+    if not candidates:
+        return None
+
+    best: str | None = None
+    best_distance = 3
+    for candidate in candidates:
         if abs(len(token) - len(candidate)) > 1:
             continue
-        score = _edit_distance(token, candidate)
-        if score <= 2 and score < best_score:
-            best_score = score
+        distance = _bounded_edit_distance(token, candidate, max_distance=2)
+        if distance < best_distance:
             best = candidate
-    return best
+            best_distance = distance
+        if distance == 1:
+            break
+
+    return best if best_distance <= 2 else None
 
 
 def _normalise_focus_typos(query: str) -> str:
@@ -265,11 +267,6 @@ def _is_capability_query(query: str) -> bool:
     return bool(tokens & _CAPABILITY_QUERY_HINTS)
 
 
-def _is_biography_query(query: str) -> bool:
-    tokens = frozenset(re.findall(r"[a-z0-9]+", query.lower()))
-    return bool(tokens & _BIOGRAPHY_QUERY_HINTS)
-
-
 def _is_informative_chunk(chunk: Chunk) -> bool:
     """True when chunk text has enough lexical content for cross-encoder reranking."""
     text = (chunk.get("contextualised_text") or chunk["text"] or "").strip()
@@ -494,9 +491,9 @@ def make_retrieve_node(
         # chunks from the same source document via doc_id filter (no vector needed).
         # If chunk 4 of a blog post matched, chunks 1-3 and 5-6 are now candidates too.
         # This is the document-graph connectivity layer: doc_id is the edge linking chunks.
+        sibling_count = 0
         if unique_chunks:
             sibling_fps: set[str] = {f"{c['metadata']['doc_id']}::{c['metadata']['section']}" for c in unique_chunks}
-            sibling_count = 0
             for seed in unique_chunks[:_SIBLING_EXPAND_TOP_N]:
                 if sibling_count >= _SIBLING_TOTAL_CAP:
                     break
@@ -533,23 +530,6 @@ def make_retrieve_node(
         if not rerank_candidates:
             rerank_candidates = unique_chunks
 
-        # Biography-focused queries should prioritize resume/CV evidence and avoid
-        # project/blog code passages crowding out personal background facts.
-        if _is_biography_query(retrieval_query):
-            bio_candidates = [
-                chunk
-                for chunk in rerank_candidates
-                if chunk["metadata"].get("source_type", "") in _BIO_SOURCE_TYPES
-            ]
-            if bio_candidates:
-                rerank_candidates = bio_candidates
-                writer(
-                    {
-                        "type": "status",
-                        "label": "Prioritizing resume and background sources...",
-                    }
-                )
-
         try:
             reranked = await reranker.rerank(retrieval_query, rerank_candidates, top_k=10)  # RC-5: raised from 7
         except (Exception, asyncio.CancelledError) as exc:
@@ -606,7 +586,9 @@ def make_retrieve_node(
                 "answer": "",
                 "retrieved_chunks": [],
                 "reranked_chunks": [],
-                "retrieval_attempts": attempts + 1, "top_rerank_score": top_score, }
+                "retrieval_attempts": attempts + 1,
+                "top_rerank_score": top_score,
+            }
 
             if rescue_low_confidence:
                 writer(
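
Two things speed up _best_focus_replacement here: candidates are pre-bucketed by first letter once at import time, and _bounded_edit_distance abandons a Wagner-Fischer row as soon as every cell in it exceeds max_distance, which is sound because row values never decrease from one row to the next. A condensed, self-contained copy for illustration (the two asserts were verified by hand):

def bounded_edit_distance(a: str, b: str, max_distance: int = 2) -> int:
    if abs(len(a) - len(b)) > max_distance:
        return max_distance + 1
    previous_row = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        current_row = [i]
        for j, cb in enumerate(b, start=1):
            current_row.append(min(
                current_row[j - 1] + 1,                        # insertion
                previous_row[j] + 1,                           # deletion
                previous_row[j - 1] + (0 if ca == cb else 1),  # substitution
            ))
        if min(current_row) > max_distance:
            return max_distance + 1  # no later row can drop back inside the band
        previous_row = current_row
    return previous_row[-1]

assert bounded_edit_distance("projects", "project") == 1  # one trailing deletion
assert bounded_edit_distance("stack", "skills") == 3      # true distance 4, reported as max_distance + 1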
app/services/loki_sink.py CHANGED
@@ -51,13 +51,25 @@ def _to_float_or_none(value: Any) -> float | None:
 
 
 def _build_sanitized_record(record: dict[str, Any]) -> dict[str, Any]:
-    reranked_chunks = record.get("reranked_chunks") or []
     query = str(record.get("query", ""))
     session_id = str(record.get("session_id", ""))
-    rerank_scores, source_types_used, chunk_count, top_chunk_doc_id = _extract_chunk_metrics(reranked_chunks)
+
+    rerank_scores = record.get("rerank_scores") or []
+    source_types_used = record.get("source_types_used") or []
+    chunk_count = int(record.get("chunk_count", 0) or 0)
+    top_chunk_doc_id = str(record.get("top_chunk_doc_id", "") or "")
+
+    if not rerank_scores and record.get("reranked_chunks"):
+        rerank_scores, source_types_used, chunk_count, top_chunk_doc_id = _extract_chunk_metrics(
+            record.get("reranked_chunks") or []
+        )
 
     top_rerank_score = _to_float_or_none(record.get("top_rerank_score"))
-    source_hit_proxy = int(top_rerank_score is not None and top_rerank_score > -1.5 and chunk_count >= 2)
+    source_hit_proxy = int(
+        record.get("source_hit_proxy")
+        if record.get("source_hit_proxy") is not None
+        else (top_rerank_score is not None and top_rerank_score > -1.5 and chunk_count >= 2)
+    )
 
     return {
         "timestamp": str(record.get("timestamp", "")),
@@ -81,9 +93,9 @@ def _build_sanitized_record(record: dict[str, Any]) -> dict[str, Any]:
         "query_hash": _sha_prefix(query, 16) if query else "",
         "chunk_count": chunk_count,
         "top_chunk_doc_id": top_chunk_doc_id,
-        "source_types_used": source_types_used,
+        "source_types_used": sorted(str(source_type) for source_type in source_types_used if str(source_type)),
         "follow_up_count": len(record.get("follow_ups") or []),
-        "rerank_scores": rerank_scores,
+        "rerank_scores": [float(score) for score in rerank_scores if isinstance(score, (int, float))],
         "source_hit_proxy": source_hit_proxy,
     }
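
The sanitizer now prefers the pre-aggregated fields emitted by the reworked log node and recomputes from raw reranked_chunks only as a backward-compatibility fallback, so chunk text no longer needs to reach the sink at all. The precedence rule in isolation (record shapes are illustrative):

from typing import Any

def pick_rerank_scores(record: dict[str, Any]) -> list[float]:
    # Prefer pre-aggregated scores; fall back to extracting from raw chunks
    # only for records produced by older emitters.
    scores = record.get("rerank_scores") or []
    if not scores and record.get("reranked_chunks"):
        scores = [c["metadata"].get("rerank_score", 0.0) for c in record["reranked_chunks"]]
    return [float(s) for s in scores if isinstance(s, (int, float))]

assert pick_rerank_scores({"rerank_scores": [0.9, 0.4]}) == [0.9, 0.4]
assert pick_rerank_scores({"reranked_chunks": [{"metadata": {"rerank_score": 0.5}}]}) == [0.5]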
 
tests/test_cache_reference_tokens.py ADDED
@@ -0,0 +1,11 @@
+from app.pipeline.nodes.cache import _has_unresolved_reference
+
+
+def test_has_unresolved_reference_ignores_spatial_and_temporal_words() -> None:
+    assert _has_unresolved_reference("tell me about there") is False
+    assert _has_unresolved_reference("what happened then") is False
+
+
+def test_has_unresolved_reference_detects_pronouns_and_demonstratives() -> None:
+    assert _has_unresolved_reference("tell me about this") is True
+    assert _has_unresolved_reference("what is that project") is True
tests/test_chat_source_filtering.py CHANGED
@@ -35,3 +35,15 @@ def test_filter_sources_by_citations_no_citations_returns_input() -> None:
     filtered = _filter_sources_by_citations(answer, sources)
 
     assert filtered == sources
+
+
+def test_filter_sources_by_citations_discards_only_out_of_range_sources() -> None:
+    sources = [
+        {"title": "A"},
+        {"title": "B"},
+    ]
+    answer = "Valid [1], missing [9]."
+
+    filtered = _filter_sources_by_citations(answer, sources)
+
+    assert [s["title"] for s in filtered] == ["A"]
tests/test_generate_quality_fallback.py CHANGED
@@ -18,7 +18,7 @@ def test_low_trust_fallback_for_version_parity_queries() -> None:
     )
 
     assert "cannot be verified" in answer
-    assert "[1]" in answer
+    assert "[1]" not in answer
 
 
 def test_low_trust_fallback_general_query_is_concise() -> None:
tests/test_log_eval_privacy.py CHANGED
@@ -1,9 +1,10 @@
+import asyncio
 import json
 import sqlite3
 
 import pytest
 
-from app.pipeline.nodes.log_eval import make_log_eval_node
+from app.pipeline.nodes.log_eval import _PENDING_TASKS, make_log_eval_node
 
 
 def test_log_eval_stores_chunk_metadata_without_text(tmp_path) -> None:
@@ -42,3 +43,49 @@ def test_log_eval_stores_chunk_metadata_without_text(tmp_path) -> None:
     assert payload and payload[0]["doc_id"] == "resume-rag"
     assert payload[0]["source_title"] == "Resume"
     assert "text" not in payload[0]
+
+
+@pytest.mark.asyncio
+async def test_log_eval_sends_sanitized_loki_payload(monkeypatch, tmp_path) -> None:
+    db_path = str(tmp_path / "interactions.db")
+    node = make_log_eval_node(db_path)
+    captured: dict = {}
+
+    async def _fake_ship_to_loki(record: dict) -> None:
+        await asyncio.sleep(0)
+        captured["record"] = record
+
+    monkeypatch.setattr("app.pipeline.nodes.log_eval.ship_to_loki", _fake_ship_to_loki)
+
+    node(
+        {
+            "session_id": "s1",
+            "query": "What work experience does Darshan have?",
+            "answer": "He worked at VK Live.",
+            "reranked_chunks": [
+                {
+                    "text": "Phone +44 7818 975908 and email someone@example.com",
+                    "metadata": {
+                        "doc_id": "resume-rag",
+                        "source_title": "Resume",
+                        "source_type": "resume",
+                        "section": "Work Experience",
+                        "rerank_score": 0.9,
+                    },
+                }
+            ],
+            "latency_ms": 123,
+            "cached": False,
+            "path": "rag",
+            "is_enumeration_query": False,
+            "top_rerank_score": 0.9,
+        }
+    )
+
+    await asyncio.gather(*list(_PENDING_TASKS))
+
+    record = captured.get("record")
+    assert record is not None
+    assert "reranked_chunks" not in record
+    assert record["chunk_count"] == 1
+    assert record["top_chunk_doc_id"] == "resume-rag"