Rifqi Hafizuddin commited on
Commit
ac6b78d
Β·
1 Parent(s): e9f2a26

[KM-507] add multiple retrieval method to compare (dense, mmr, bm25, hybrid)

Browse files
src/db/postgres/init_db.py CHANGED
@@ -28,3 +28,18 @@ async def init_db():
28
  await conn.execute(text(
29
  "ALTER TABLE rooms ADD COLUMN IF NOT EXISTS status VARCHAR NOT NULL DEFAULT 'active'"
30
  ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  await conn.execute(text(
29
  "ALTER TABLE rooms ADD COLUMN IF NOT EXISTS status VARCHAR NOT NULL DEFAULT 'active'"
30
  ))
31
+
32
+ # GIN index for FTS on schema chunks β€” only created if table exists
33
+ # (langchain_pg_embedding is created by PGVector on first use, not by create_all)
34
+ await conn.execute(text("""
35
+ DO $$
36
+ BEGIN
37
+ IF EXISTS (
38
+ SELECT FROM information_schema.tables
39
+ WHERE table_name = 'langchain_pg_embedding'
40
+ ) THEN
41
+ CREATE INDEX IF NOT EXISTS idx_langchain_pg_embedding_fts
42
+ ON langchain_pg_embedding USING GIN (to_tsvector('english', document));
43
+ END IF;
44
+ END $$
45
+ """))
src/rag/retrievers/schema.py CHANGED
@@ -1,86 +1,317 @@
1
  """Schema retriever β€” handles DB schemas (source_type="database") and tabular file
2
  columns stored as source_type="document" with file_type in ("csv","xlsx").
3
 
4
- Strategy: similarity search with score threshold on two metadata shapes,
5
- run in parallel, merged and re-ranked by score.
 
6
  """
7
 
8
  import asyncio
 
 
9
 
 
 
 
10
  from src.db.postgres.vector_store import get_vector_store
11
  from src.middlewares.logging import get_logger
12
  from src.rag.base import BaseRetriever, RetrievalResult
13
 
14
  logger = get_logger("schema_retriever")
15
 
16
- _SCORE_THRESHOLD = 0.45 # cosine distance β€” discard above this value
17
  _TABULAR_FILE_TYPES = ("csv", "xlsx")
18
 
 
 
 
19
 
20
  class SchemaRetriever(BaseRetriever):
21
  def __init__(self):
22
  self.vector_store = get_vector_store()
23
 
24
- async def _search_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
25
- """Retrieve DB schema chunks (source_type="database")."""
 
 
 
 
 
26
  docs_with_scores = await self.vector_store.asimilarity_search_with_score(
27
  query=query,
28
- k=k,
29
  filter={"user_id": user_id, "source_type": "database"},
30
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  results = []
32
  for doc, distance in docs_with_scores:
33
- if distance <= _SCORE_THRESHOLD:
34
- results.append(
35
- RetrievalResult(
36
- content=doc.page_content,
37
- metadata=doc.metadata,
38
- score=1.0 - distance,
39
- source_type="database",
40
- )
 
 
41
  )
 
 
 
42
  return results
43
 
44
- async def _search_tabular(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
45
- """Retrieve CSV/XLSX column chunks (source_type="document", file_type=csv|xlsx)."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  results = []
47
- for file_type in _TABULAR_FILE_TYPES:
48
- docs_with_scores = await self.vector_store.asimilarity_search_with_score(
49
- query=query,
50
- k=k,
51
- filter={
52
- "user_id": user_id,
53
- "source_type": "document",
54
- "data": {"file_type": file_type},
55
- },
 
56
  )
57
- for doc, distance in docs_with_scores:
58
- if distance <= _SCORE_THRESHOLD:
59
- results.append(
60
- RetrievalResult(
61
- content=doc.page_content,
62
- metadata=doc.metadata,
63
- score=1.0 - distance,
64
- source_type="document",
65
- )
66
- )
67
  return results
68
 
69
- async def retrieve(
70
- self, query: str, user_id: str, k: int = 5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  ) -> list[RetrievalResult]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  db_results, tabular_results = await asyncio.gather(
73
  self._search_db(query, user_id, k),
74
  self._search_tabular(query, user_id, k),
75
  )
76
- combined = db_results + tabular_results
77
- combined.sort(key=lambda r: r.score, reverse=True)
78
- logger.info(
79
- "schema retrieval",
80
- db_chunks=len(db_results),
81
- tabular_chunks=len(tabular_results),
 
 
 
 
 
 
82
  )
 
83
  return combined[:k]
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  schema_retriever = SchemaRetriever()
 
1
  """Schema retriever β€” handles DB schemas (source_type="database") and tabular file
2
  columns stored as source_type="document" with file_type in ("csv","xlsx").
3
 
4
+ Multiple retrieval strategies are exposed for benchmarking. The active strategy
5
+ used by the router is `retrieve()`, which dispatches to ACTIVE_STRATEGY.
6
+ Change ACTIVE_STRATEGY at module level to switch without touching the router.
7
  """
8
 
9
  import asyncio
10
+ import time
11
+ from typing import Literal
12
 
13
+ from sqlalchemy import text
14
+
15
+ from src.db.postgres.connection import _pgvector_engine
16
  from src.db.postgres.vector_store import get_vector_store
17
  from src.middlewares.logging import get_logger
18
  from src.rag.base import BaseRetriever, RetrievalResult
19
 
20
  logger = get_logger("schema_retriever")
21
 
22
+ _SCORE_THRESHOLD = 0.60 # cosine distance β€” discard above this value (score < 0.40)
23
  _TABULAR_FILE_TYPES = ("csv", "xlsx")
24
 
25
+ Strategy = Literal["dense", "dense_no_threshold", "mmr", "hybrid", "hybrid_bm25"]
26
+ ACTIVE_STRATEGY: Strategy = "hybrid_bm25"
27
+
28
 
29
  class SchemaRetriever(BaseRetriever):
30
  def __init__(self):
31
  self.vector_store = get_vector_store()
32
 
33
+ # ------------------------------------------------------------------
34
+ # Internal search helpers
35
+ # ------------------------------------------------------------------
36
+
37
+ async def _search_db(
38
+ self, query: str, user_id: str, k: int, threshold: float | None = _SCORE_THRESHOLD
39
+ ) -> list[RetrievalResult]:
40
  docs_with_scores = await self.vector_store.asimilarity_search_with_score(
41
  query=query,
42
+ k=k * 4, # fetch extra to survive dedup attrition from multiple ingestion runs
43
  filter={"user_id": user_id, "source_type": "database"},
44
  )
45
+ return [
46
+ RetrievalResult(
47
+ content=doc.page_content,
48
+ metadata=doc.metadata,
49
+ score=1.0 - distance,
50
+ source_type="database",
51
+ )
52
+ for doc, distance in docs_with_scores
53
+ if threshold is None or distance <= threshold
54
+ ]
55
+
56
+ async def _search_tabular(
57
+ self, query: str, user_id: str, k: int, threshold: float | None = _SCORE_THRESHOLD
58
+ ) -> list[RetrievalResult]:
59
+ # Fetch extra to account for post-filter attrition (non-tabular docs filtered out)
60
+ docs_with_scores = await self.vector_store.asimilarity_search_with_score(
61
+ query=query,
62
+ k=k * 4,
63
+ filter={"user_id": user_id, "source_type": "document"},
64
+ )
65
  results = []
66
  for doc, distance in docs_with_scores:
67
+ if doc.metadata.get("data", {}).get("file_type") not in _TABULAR_FILE_TYPES:
68
+ continue
69
+ if threshold is not None and distance > threshold:
70
+ continue
71
+ results.append(
72
+ RetrievalResult(
73
+ content=doc.page_content,
74
+ metadata=doc.metadata,
75
+ score=1.0 - distance,
76
+ source_type="document",
77
  )
78
+ )
79
+ if len(results) >= k:
80
+ break
81
  return results
82
 
83
+ async def _search_db_mmr(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
84
+ docs = await self.vector_store.amax_marginal_relevance_search(
85
+ query=query,
86
+ k=k * 4, # fetch extra to survive dedup attrition
87
+ fetch_k=k * 12,
88
+ filter={"user_id": user_id, "source_type": "database"},
89
+ )
90
+ return [
91
+ RetrievalResult(
92
+ content=doc.page_content,
93
+ metadata=doc.metadata,
94
+ score=0.0, # MMR does not return scores
95
+ source_type="database",
96
+ )
97
+ for doc in docs
98
+ ]
99
+
100
+ async def _search_tabular_mmr(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
101
+ docs = await self.vector_store.amax_marginal_relevance_search(
102
+ query=query,
103
+ k=k * 4,
104
+ fetch_k=k * 12,
105
+ filter={"user_id": user_id, "source_type": "document"},
106
+ )
107
  results = []
108
+ for doc in docs:
109
+ if doc.metadata.get("data", {}).get("file_type") not in _TABULAR_FILE_TYPES:
110
+ continue
111
+ results.append(
112
+ RetrievalResult(
113
+ content=doc.page_content,
114
+ metadata=doc.metadata,
115
+ score=0.0,
116
+ source_type="document",
117
+ )
118
  )
119
+ if len(results) >= k:
120
+ break
 
 
 
 
 
 
 
 
121
  return results
122
 
123
+ async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
124
+ """Full-text search over DB schema chunks using PostgreSQL tsvector.
125
+
126
+ Uses plainto_tsquery (natural language, no operator syntax required).
127
+ Requires the GIN index created by init_db.py on first startup after table exists.
128
+ ts_rank score is only used for ordering here; RRF ignores it.
129
+ """
130
+ sql = text("""
131
+ SELECT lpe.document, lpe.cmetadata,
132
+ ts_rank(to_tsvector('english', lpe.document),
133
+ plainto_tsquery('english', :query)) AS rank
134
+ FROM langchain_pg_embedding lpe
135
+ JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
136
+ WHERE lpc.name = 'document_embeddings'
137
+ AND lpe.cmetadata->>'user_id' = :user_id
138
+ AND lpe.cmetadata->>'source_type' = 'database'
139
+ AND to_tsvector('english', lpe.document) @@ plainto_tsquery('english', :query)
140
+ ORDER BY rank DESC
141
+ LIMIT :k
142
+ """)
143
+
144
+ async with _pgvector_engine.connect() as conn:
145
+ result = await conn.execute(sql, {"query": query, "user_id": user_id, "k": k})
146
+ rows = result.fetchall()
147
+
148
+ return [
149
+ RetrievalResult(
150
+ content=row.document,
151
+ metadata=row.cmetadata,
152
+ score=float(row.rank),
153
+ source_type="database",
154
+ )
155
+ for row in rows
156
+ ]
157
+
158
+ def _rrf_merge(
159
+ self,
160
+ *ranked_lists: list[RetrievalResult],
161
+ k_rrf: int = 60,
162
+ top_k: int = 5,
163
  ) -> list[RetrievalResult]:
164
+ """Reciprocal Rank Fusion β€” combines ranked lists using rank positions only.
165
+
166
+ Uses content prefix as dedup key so the same column appearing in multiple
167
+ lists is counted once with accumulated RRF score.
168
+ """
169
+ scores: dict[str, float] = {}
170
+ index: dict[str, RetrievalResult] = {}
171
+
172
+ for ranked in ranked_lists:
173
+ for rank, result in enumerate(ranked):
174
+ key = result.content[:120]
175
+ scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
176
+ index[key] = result
177
+
178
+ merged = sorted(index.values(), key=lambda r: scores[r.content[:120]], reverse=True)
179
+ return merged[:top_k]
180
+
181
+ def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
182
+ """Deduplicate by (table_name, column_name), keeping highest score per unique column.
183
+
184
+ Multiple ingestion runs of the same DB produce identical chunks β€” this collapses
185
+ them so the LLM context only sees one chunk per column.
186
+ """
187
+ seen: dict[tuple, RetrievalResult] = {}
188
+ for r in results:
189
+ data = r.metadata.get("data", {})
190
+ key = (data.get("table_name"), data.get("column_name") or data.get("filename"))
191
+ if key not in seen or r.score > seen[key].score:
192
+ seen[key] = r
193
+ return sorted(seen.values(), key=lambda r: r.score, reverse=True)
194
+
195
+ # ------------------------------------------------------------------
196
+ # Named strategies β€” call directly from benchmark / test scripts
197
+ # ------------------------------------------------------------------
198
+
199
+ async def dense(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
200
+ """Dense similarity with score threshold. Current production default."""
201
  db_results, tabular_results = await asyncio.gather(
202
  self._search_db(query, user_id, k),
203
  self._search_tabular(query, user_id, k),
204
  )
205
+ combined = self._dedup(db_results + tabular_results)
206
+ return combined[:k]
207
+
208
+ async def dense_no_threshold(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
209
+ """Dense similarity without score cutoff.
210
+
211
+ Use to calibrate whether the threshold is too strict/loose β€”
212
+ compare returned chunks against `dense` to see what gets filtered out.
213
+ """
214
+ db_results, tabular_results = await asyncio.gather(
215
+ self._search_db(query, user_id, k, threshold=None),
216
+ self._search_tabular(query, user_id, k, threshold=None),
217
  )
218
+ combined = self._dedup(db_results + tabular_results)
219
  return combined[:k]
220
 
221
+ async def mmr(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
222
+ """MMR (Maximal Marginal Relevance) for diversity.
223
+
224
+ Note: scores are 0.0 β€” MMR does not expose similarity scores.
225
+ Dedup still applied since multiple ingestion runs produce duplicate chunks.
226
+ """
227
+ db_results, tabular_results = await asyncio.gather(
228
+ self._search_db_mmr(query, user_id, k),
229
+ self._search_tabular_mmr(query, user_id, k),
230
+ )
231
+ return self._dedup(db_results + tabular_results)[:k]
232
+
233
+ async def hybrid(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
234
+ """RRF merge of dense + MMR results.
235
+
236
+ Acts as a proxy for a true dense+FTS hybrid until a PostgreSQL tsvector
237
+ GIN index is added. Dense covers semantic queries; the second ranking
238
+ signal from MMR helps surface exact-name matches that dense ranks lower.
239
+
240
+ To upgrade to true FTS hybrid: replace mmr() leg with _search_fts()
241
+ (raw SQL using to_tsquery) and add the GIN index in init_db.py.
242
+ """
243
+ dense_results, mmr_results = await asyncio.gather(
244
+ self.dense(query, user_id, k),
245
+ self.mmr(query, user_id, k),
246
+ )
247
+ return self._rrf_merge(dense_results, mmr_results, top_k=k)
248
+
249
+ async def hybrid_bm25(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
250
+ """RRF merge of dense + PostgreSQL FTS (true hybrid).
251
+
252
+ Dense handles semantic queries ("customer information", "revenue columns").
253
+ FTS handles structural/exact terms that appear literally in chunks:
254
+ [PRIMARY KEY], [FK ->], column type strings, exact column/table names.
255
+
256
+ FTS results are deduped by (table_name, column_name) before merge to prevent
257
+ multiple ingestion runs from accumulating RRF score unfairly.
258
+ Requires GIN index on langchain_pg_embedding.document (created by init_db.py).
259
+ """
260
+ dense_results, fts_results = await asyncio.gather(
261
+ self.dense(query, user_id, k),
262
+ self._search_fts_db(query, user_id, k * 4),
263
+ )
264
+ return self._rrf_merge(dense_results, self._dedup(fts_results), top_k=k)
265
+
266
+ # ------------------------------------------------------------------
267
+ # Public interface β€” called by the router
268
+ # ------------------------------------------------------------------
269
+
270
+ async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
271
+ strategy_fn = getattr(self, ACTIVE_STRATEGY)
272
+ results = await strategy_fn(query, user_id, k)
273
+ logger.info("schema retrieval", strategy=ACTIVE_STRATEGY, count=len(results))
274
+ return results
275
+
276
+
277
+ # ------------------------------------------------------------------
278
+ # Benchmark helper β€” import in test scripts
279
+ # ------------------------------------------------------------------
280
+
281
+ async def benchmark(
282
+ query: str,
283
+ user_id: str,
284
+ k: int = 5,
285
+ strategies: list[Strategy] | None = None,
286
+ ) -> dict[str, dict]:
287
+ """Run multiple strategies against the same query and return timing + results.
288
+
289
+ Strategies run sequentially so timings are isolated (not competing for the
290
+ same DB connections). Scores and chunk content are included for manual review.
291
+
292
+ Usage:
293
+ from src.rag.retrievers.schema import benchmark
294
+ report = await benchmark("what is the primary key of orders?", user_id="xxx")
295
+ """
296
+ retriever = SchemaRetriever()
297
+ targets: list[Strategy] = strategies or ["dense", "dense_no_threshold", "mmr", "hybrid", "hybrid_bm25"]
298
+ report: dict[str, dict] = {}
299
+
300
+ for name in targets:
301
+ fn = getattr(retriever, name)
302
+ t0 = time.perf_counter()
303
+ chunks = await fn(query, user_id, k)
304
+ elapsed_ms = round((time.perf_counter() - t0) * 1000)
305
+
306
+ total_chars = sum(len(r.content) for r in chunks)
307
+ report[name] = {
308
+ "chunks": len(chunks),
309
+ "estimated_tokens": total_chars // 4,
310
+ "elapsed_ms": elapsed_ms,
311
+ "results": chunks,
312
+ }
313
+
314
+ return report
315
+
316
 
317
  schema_retriever = SchemaRetriever()