Rifqi Hafizuddin commited on
Commit
145bca3
·
1 Parent(s): 83ed744

[KM-507] add different methods, now using dense cosine

Browse files
Files changed (1) hide show
  1. src/rag/retrievers/schema.py +161 -136
src/rag/retrievers/schema.py CHANGED
@@ -4,6 +4,15 @@ columns stored as source_type="document" with file_type in ("csv","xlsx").
4
  Multiple retrieval strategies are exposed for benchmarking. The active strategy
5
  used by the router is `retrieve()`, which dispatches to ACTIVE_STRATEGY.
6
  Change ACTIVE_STRATEGY at module level to switch without touching the router.
 
 
 
 
 
 
 
 
 
7
  """
8
 
9
  import asyncio
@@ -19,10 +28,9 @@ from src.rag.base import BaseRetriever, RetrievalResult
19
 
20
  logger = get_logger("schema_retriever")
21
 
22
- _SCORE_THRESHOLD = 0.75 # cosine distance — discard above this value (score < 0.25)
23
  _TABULAR_FILE_TYPES = ("csv", "xlsx")
24
 
25
- Strategy = Literal["dense", "dense_no_threshold", "mmr", "hybrid", "hybrid_bm25"]
26
  ACTIVE_STRATEGY: Strategy = "dense_no_threshold"
27
 
28
 
@@ -31,88 +39,87 @@ class SchemaRetriever(BaseRetriever):
31
  self.vector_store = get_vector_store()
32
 
33
  # ------------------------------------------------------------------
34
- # Internal search helpers
35
  # ------------------------------------------------------------------
36
 
 
 
 
37
  async def _search_db(
38
- self, query: str, user_id: str, k: int, threshold: float | None = _SCORE_THRESHOLD
39
  ) -> list[RetrievalResult]:
40
- docs_with_scores = await self.vector_store.asimilarity_search_with_score(
41
- query=query,
42
- k=k * 4, # fetch extra to survive dedup attrition from multiple ingestion runs
43
- filter={"user_id": user_id, "source_type": "database"},
44
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  return [
46
  RetrievalResult(
47
- content=doc.page_content,
48
- metadata=doc.metadata,
49
- score=1.0 - distance,
50
  source_type="database",
51
  )
52
- for doc, distance in docs_with_scores
53
- if threshold is None or distance <= threshold
54
  ]
55
 
56
  async def _search_tabular(
57
- self, query: str, user_id: str, k: int, threshold: float | None = _SCORE_THRESHOLD
58
  ) -> list[RetrievalResult]:
59
- # Fetch extra to account for post-filter attrition (non-tabular docs filtered out)
60
- docs_with_scores = await self.vector_store.asimilarity_search_with_score(
61
- query=query,
62
- k=k * 4,
63
- filter={"user_id": user_id, "source_type": "document"},
64
- )
65
- results = []
66
- for doc, distance in docs_with_scores:
67
- if doc.metadata.get("data", {}).get("file_type") not in _TABULAR_FILE_TYPES:
68
- continue
69
- if threshold is not None and distance > threshold:
70
- continue
71
- results.append(
72
- RetrievalResult(
73
- content=doc.page_content,
74
- metadata=doc.metadata,
75
- score=1.0 - distance,
76
- source_type="document",
77
- )
78
- )
79
- if len(results) >= k:
80
- break
81
- return results
82
 
83
- async def _search_db_mmr(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
84
- docs = await self.vector_store.amax_marginal_relevance_search(
85
- query=query,
86
- k=k * 4, # fetch extra to survive dedup attrition
87
- fetch_k=k * 12,
88
- filter={"user_id": user_id, "source_type": "database"},
89
- )
90
- return [
91
- RetrievalResult(
92
- content=doc.page_content,
93
- metadata=doc.metadata,
94
- score=0.0, # MMR does not return scores
95
- source_type="database",
96
- )
97
- for doc in docs
98
- ]
99
 
100
- async def _search_tabular_mmr(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
101
- docs = await self.vector_store.amax_marginal_relevance_search(
102
- query=query,
103
- k=k * 4,
104
- fetch_k=k * 12,
105
- filter={"user_id": user_id, "source_type": "document"},
106
- )
107
  results = []
108
- for doc in docs:
109
- if doc.metadata.get("data", {}).get("file_type") not in _TABULAR_FILE_TYPES:
110
- continue
111
  results.append(
112
  RetrievalResult(
113
- content=doc.page_content,
114
- metadata=doc.metadata,
115
- score=0.0,
116
  source_type="document",
117
  )
118
  )
@@ -123,9 +130,7 @@ class SchemaRetriever(BaseRetriever):
123
  async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
124
  """Full-text search over DB schema chunks using PostgreSQL tsvector.
125
 
126
- Uses plainto_tsquery (natural language, no operator syntax required).
127
- Requires the GIN index created by init_db.py on first startup after table exists.
128
- ts_rank score is only used for ordering here; RRF ignores it.
129
  """
130
  sql = text("""
131
  SELECT lpe.document, lpe.cmetadata,
@@ -155,26 +160,53 @@ class SchemaRetriever(BaseRetriever):
155
  for row in rows
156
  ]
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  def _rrf_merge(
159
  self,
160
  *ranked_lists: list[RetrievalResult],
161
  k_rrf: int = 60,
162
  top_k: int = 5,
163
  ) -> list[RetrievalResult]:
164
- """Reciprocal Rank Fusion — combines ranked lists using rank positions only.
165
-
166
- Uses content prefix as dedup key so the same column appearing in multiple
167
- lists is counted once with accumulated RRF score.
168
- """
169
- scores: dict[str, float] = {}
170
- index: dict[str, RetrievalResult] = {}
171
 
172
  for ranked in ranked_lists:
173
  for rank, result in enumerate(ranked):
174
  data = result.metadata.get("data", {})
175
  key = (data.get("table_name"), data.get("column_name") or data.get("filename"))
176
  scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
177
- # prefer the result with a real cosine score (dense leg) over ts_rank (FTS leg)
178
  if key not in index or result.score > index[key].score:
179
  index[key] = result
180
 
@@ -186,11 +218,7 @@ class SchemaRetriever(BaseRetriever):
186
  return merged[:top_k]
187
 
188
  def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
189
- """Deduplicate by (table_name, column_name), keeping highest score per unique column.
190
-
191
- Multiple ingestion runs of the same DB produce identical chunks — this collapses
192
- them so the LLM context only sees one chunk per column.
193
- """
194
  seen: dict[tuple, RetrievalResult] = {}
195
  for r in results:
196
  data = r.metadata.get("data", {})
@@ -200,75 +228,74 @@ class SchemaRetriever(BaseRetriever):
200
  return sorted(seen.values(), key=lambda r: r.score, reverse=True)
201
 
202
  # ------------------------------------------------------------------
203
- # Named strategies — call directly from benchmark / test scripts
204
  # ------------------------------------------------------------------
205
 
206
- async def dense(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
207
- """Dense similarity with score threshold. Current production default."""
 
208
  db_results, tabular_results = await asyncio.gather(
209
- self._search_db(query, user_id, k),
210
- self._search_tabular(query, user_id, k),
211
  )
212
- combined = self._dedup(db_results + tabular_results)
213
- return combined[:k]
214
 
215
- async def dense_no_threshold(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
216
- """Dense similarity without score cutoff.
217
 
218
- Use to calibrate whether the threshold is too strict/loose
219
- compare returned chunks against `dense` to see what gets filtered out.
220
  """
 
221
  db_results, tabular_results = await asyncio.gather(
222
- self._search_db(query, user_id, k, threshold=None),
223
- self._search_tabular(query, user_id, k, threshold=None),
224
  )
225
- combined = self._dedup(db_results + tabular_results)
226
- return combined[:k]
227
 
228
- async def mmr(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
229
- """MMR (Maximal Marginal Relevance) for diversity.
230
 
231
- Note: scores are 0.0 MMR does not expose similarity scores.
232
- Dedup still applied since multiple ingestion runs produce duplicate chunks.
233
  """
 
234
  db_results, tabular_results = await asyncio.gather(
235
- self._search_db_mmr(query, user_id, k),
236
- self._search_tabular_mmr(query, user_id, k),
237
  )
238
  return self._dedup(db_results + tabular_results)[:k]
239
 
240
  async def hybrid(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
241
- """RRF merge of dense + MMR results.
242
-
243
- Acts as a proxy for a true dense+FTS hybrid until a PostgreSQL tsvector
244
- GIN index is added. Dense covers semantic queries; the second ranking
245
- signal from MMR helps surface exact-name matches that dense ranks lower.
246
 
247
- To upgrade to true FTS hybrid: replace mmr() leg with _search_fts()
248
- (raw SQL using to_tsquery) and add the GIN index in init_db.py.
249
  """
250
- dense_results, mmr_results = await asyncio.gather(
251
- self.dense(query, user_id, k),
252
- self.mmr(query, user_id, k),
 
 
 
253
  )
254
- return self._rrf_merge(dense_results, mmr_results, top_k=k)
 
 
255
 
256
  async def hybrid_bm25(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
257
- """RRF merge of dense + PostgreSQL FTS (true hybrid).
258
-
259
- Dense handles semantic queries ("customer information", "revenue columns").
260
- FTS handles structural/exact terms that appear literally in chunks:
261
- [PRIMARY KEY], [FK ->], column type strings, exact column/table names.
262
 
263
- FTS results are deduped by (table_name, column_name) before merge to prevent
264
- multiple ingestion runs from accumulating RRF score unfairly.
265
- Requires GIN index on langchain_pg_embedding.document (created by init_db.py).
266
  """
267
- dense_results, fts_results = await asyncio.gather(
268
- self.dense(query, user_id, k),
 
 
269
  self._search_fts_db(query, user_id, k * 4),
270
  )
271
- return self._rrf_merge(dense_results, self._dedup(fts_results), top_k=k)
 
272
 
273
  # ------------------------------------------------------------------
274
  # Public interface — called by the router
@@ -291,17 +318,15 @@ async def benchmark(
291
  k: int = 5,
292
  strategies: list[Strategy] | None = None,
293
  ) -> dict[str, dict]:
294
- """Run multiple strategies against the same query and return timing + results.
295
-
296
- Strategies run sequentially so timings are isolated (not competing for the
297
- same DB connections). Scores and chunk content are included for manual review.
298
-
299
- Usage:
300
- from src.rag.retrievers.schema import benchmark
301
- report = await benchmark("what is the primary key of orders?", user_id="xxx")
302
- """
303
  retriever = SchemaRetriever()
304
- targets: list[Strategy] = strategies or ["dense", "dense_no_threshold", "mmr", "hybrid", "hybrid_bm25"]
 
 
 
 
 
 
305
  report: dict[str, dict] = {}
306
 
307
  for name in targets:
 
4
  Multiple retrieval strategies are exposed for benchmarking. The active strategy
5
  used by the router is `retrieve()`, which dispatches to ACTIVE_STRATEGY.
6
  Change ACTIVE_STRATEGY at module level to switch without touching the router.
7
+
8
+ All strategies embed the query exactly once, then fan out to parallel SQL legs.
9
+
10
+ Vector distance strategies:
11
+ dense_no_threshold — cosine (<=>), no score floor, always returns k chunks
12
+ dense_dot — inner product (<#>), equivalent to cosine for normalized embeddings
13
+ dense_l2 — L2/euclidean (<->), monotonic with cosine on unit-sphere vectors
14
+ hybrid — RRF merge of dense + FTS (database + tabular)
15
+ hybrid_bm25 — RRF merge of dense + FTS (database only)
16
  """
17
 
18
  import asyncio
 
28
 
29
  logger = get_logger("schema_retriever")
30
 
 
31
  _TABULAR_FILE_TYPES = ("csv", "xlsx")
32
 
33
+ Strategy = Literal["dense_no_threshold", "dense_dot", "dense_l2", "hybrid", "hybrid_bm25"]
34
  ACTIVE_STRATEGY: Strategy = "dense_no_threshold"
35
 
36
 
 
39
  self.vector_store = get_vector_store()
40
 
41
  # ------------------------------------------------------------------
42
+ # Internal helpers
43
  # ------------------------------------------------------------------
44
 
45
+ async def _embed_query(self, query: str) -> list[float]:
46
+ return await asyncio.to_thread(self.vector_store.embeddings.embed_query, query)
47
+
48
  async def _search_db(
49
+ self, embedding: list[float], user_id: str, k: int, operator: str = "<=>"
50
  ) -> list[RetrievalResult]:
51
+ """Vector search over database chunks. Accepts a pre-computed embedding."""
52
+ emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
53
+
54
+ if operator == "<#>":
55
+ score_sql = f"(lpe.embedding <#> '{emb_str}'::vector) * -1"
56
+ elif operator == "<->":
57
+ score_sql = f"1.0 / (1.0 + (lpe.embedding <-> '{emb_str}'::vector))"
58
+ else:
59
+ score_sql = f"1.0 - (lpe.embedding <=> '{emb_str}'::vector)"
60
+
61
+ sql = text(f"""
62
+ SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
63
+ FROM langchain_pg_embedding lpe
64
+ JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
65
+ WHERE lpc.name = 'document_embeddings'
66
+ AND lpe.cmetadata->>'user_id' = :user_id
67
+ AND lpe.cmetadata->>'source_type' = 'database'
68
+ ORDER BY lpe.embedding {operator} '{emb_str}'::vector ASC
69
+ LIMIT :k
70
+ """)
71
+
72
+ async with _pgvector_engine.connect() as conn:
73
+ result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
74
+ rows = result.fetchall()
75
+
76
  return [
77
  RetrievalResult(
78
+ content=row.document,
79
+ metadata=row.cmetadata,
80
+ score=float(row.score),
81
  source_type="database",
82
  )
83
+ for row in rows
 
84
  ]
85
 
86
  async def _search_tabular(
87
+ self, embedding: list[float], user_id: str, k: int, operator: str = "<=>"
88
  ) -> list[RetrievalResult]:
89
+ """Vector search over tabular document chunks. Accepts a pre-computed embedding."""
90
+ emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
91
+
92
+ if operator == "<#>":
93
+ score_sql = f"(lpe.embedding <#> '{emb_str}'::vector) * -1"
94
+ elif operator == "<->":
95
+ score_sql = f"1.0 / (1.0 + (lpe.embedding <-> '{emb_str}'::vector))"
96
+ else:
97
+ score_sql = f"1.0 - (lpe.embedding <=> '{emb_str}'::vector)"
98
+
99
+ sql = text(f"""
100
+ SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
101
+ FROM langchain_pg_embedding lpe
102
+ JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
103
+ WHERE lpc.name = 'document_embeddings'
104
+ AND lpe.cmetadata->>'user_id' = :user_id
105
+ AND lpe.cmetadata->>'source_type' = 'document'
106
+ AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
107
+ OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
108
+ ORDER BY lpe.embedding {operator} '{emb_str}'::vector ASC
109
+ LIMIT :k
110
+ """)
 
111
 
112
+ async with _pgvector_engine.connect() as conn:
113
+ result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
114
+ rows = result.fetchall()
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
 
 
 
 
 
 
 
116
  results = []
117
+ for row in rows:
 
 
118
  results.append(
119
  RetrievalResult(
120
+ content=row.document,
121
+ metadata=row.cmetadata,
122
+ score=float(row.score),
123
  source_type="document",
124
  )
125
  )
 
130
  async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
131
  """Full-text search over DB schema chunks using PostgreSQL tsvector.
132
 
133
+ Requires GIN index on langchain_pg_embedding.document (created by init_db.py).
 
 
134
  """
135
  sql = text("""
136
  SELECT lpe.document, lpe.cmetadata,
 
160
  for row in rows
161
  ]
162
 
163
+ async def _search_fts_tabular(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
164
+ """Full-text search over tabular document chunks using PostgreSQL tsvector."""
165
+ sql = text("""
166
+ SELECT lpe.document, lpe.cmetadata,
167
+ ts_rank(to_tsvector('english', lpe.document),
168
+ plainto_tsquery('english', :query)) AS rank
169
+ FROM langchain_pg_embedding lpe
170
+ JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
171
+ WHERE lpc.name = 'document_embeddings'
172
+ AND lpe.cmetadata->>'user_id' = :user_id
173
+ AND lpe.cmetadata->>'source_type' = 'document'
174
+ AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
175
+ OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
176
+ AND to_tsvector('english', lpe.document) @@ plainto_tsquery('english', :query)
177
+ ORDER BY rank DESC
178
+ LIMIT :k
179
+ """)
180
+
181
+ async with _pgvector_engine.connect() as conn:
182
+ result = await conn.execute(sql, {"query": query, "user_id": user_id, "k": k})
183
+ rows = result.fetchall()
184
+
185
+ return [
186
+ RetrievalResult(
187
+ content=row.document,
188
+ metadata=row.cmetadata,
189
+ score=float(row.rank),
190
+ source_type="document",
191
+ )
192
+ for row in rows
193
+ ]
194
+
195
  def _rrf_merge(
196
  self,
197
  *ranked_lists: list[RetrievalResult],
198
  k_rrf: int = 60,
199
  top_k: int = 5,
200
  ) -> list[RetrievalResult]:
201
+ """Reciprocal Rank Fusion — combines ranked lists using rank positions only."""
202
+ scores: dict[tuple, float] = {}
203
+ index: dict[tuple, RetrievalResult] = {}
 
 
 
 
204
 
205
  for ranked in ranked_lists:
206
  for rank, result in enumerate(ranked):
207
  data = result.metadata.get("data", {})
208
  key = (data.get("table_name"), data.get("column_name") or data.get("filename"))
209
  scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
 
210
  if key not in index or result.score > index[key].score:
211
  index[key] = result
212
 
 
218
  return merged[:top_k]
219
 
220
  def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
221
+ """Deduplicate by (table_name, column_name), keeping highest score per unique column."""
 
 
 
 
222
  seen: dict[tuple, RetrievalResult] = {}
223
  for r in results:
224
  data = r.metadata.get("data", {})
 
228
  return sorted(seen.values(), key=lambda r: r.score, reverse=True)
229
 
230
  # ------------------------------------------------------------------
231
+ # Named strategies — one embed call each, legs run in parallel
232
  # ------------------------------------------------------------------
233
 
234
+ async def dense_no_threshold(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
235
+ """Cosine similarity, no score cutoff always returns k chunks."""
236
+ embedding = await self._embed_query(query)
237
  db_results, tabular_results = await asyncio.gather(
238
+ self._search_db(embedding, user_id, k),
239
+ self._search_tabular(embedding, user_id, k),
240
  )
241
+ return self._dedup(db_results + tabular_results)[:k]
 
242
 
243
+ async def dense_dot(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
244
+ """Inner product similarity (<#>).
245
 
246
+ For L2-normalized embeddings (OpenAI), ranking is identical to cosine.
247
+ Score = raw inner product (not bounded to [0,1]).
248
  """
249
+ embedding = await self._embed_query(query)
250
  db_results, tabular_results = await asyncio.gather(
251
+ self._search_db(embedding, user_id, k, "<#>"),
252
+ self._search_tabular(embedding, user_id, k, "<#>"),
253
  )
254
+ return self._dedup(db_results + tabular_results)[:k]
 
255
 
256
+ async def dense_l2(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
257
+ """L2 (Euclidean) distance similarity (<->).
258
 
259
+ For L2-normalized embeddings (OpenAI), ranking order matches cosine.
260
+ Score = 1 / (1 + l2_distance), bounded to (0, 1].
261
  """
262
+ embedding = await self._embed_query(query)
263
  db_results, tabular_results = await asyncio.gather(
264
+ self._search_db(embedding, user_id, k, "<->"),
265
+ self._search_tabular(embedding, user_id, k, "<->"),
266
  )
267
  return self._dedup(db_results + tabular_results)[:k]
268
 
269
  async def hybrid(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
270
+ """RRF merge of dense + FTS over both database and tabular sources.
 
 
 
 
271
 
272
+ Embeds once, then runs all four legs (dense db, dense tabular, fts db,
273
+ fts tabular) in a single asyncio.gather.
274
  """
275
+ embedding = await self._embed_query(query)
276
+ db_results, tabular_results, fts_db, fts_tabular = await asyncio.gather(
277
+ self._search_db(embedding, user_id, k),
278
+ self._search_tabular(embedding, user_id, k),
279
+ self._search_fts_db(query, user_id, k * 4),
280
+ self._search_fts_tabular(query, user_id, k * 4),
281
  )
282
+ dense = self._dedup(db_results + tabular_results)[:k]
283
+ fts_all = self._dedup(fts_db + fts_tabular)
284
+ return self._rrf_merge(dense, fts_all, top_k=k)
285
 
286
  async def hybrid_bm25(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
287
+ """RRF merge of dense + FTS (database chunks only).
 
 
 
 
288
 
289
+ Embeds once, then runs dense db, dense tabular, and fts db legs in parallel.
 
 
290
  """
291
+ embedding = await self._embed_query(query)
292
+ db_results, tabular_results, fts_results = await asyncio.gather(
293
+ self._search_db(embedding, user_id, k),
294
+ self._search_tabular(embedding, user_id, k),
295
  self._search_fts_db(query, user_id, k * 4),
296
  )
297
+ dense = self._dedup(db_results + tabular_results)[:k]
298
+ return self._rrf_merge(dense, self._dedup(fts_results), top_k=k)
299
 
300
  # ------------------------------------------------------------------
301
  # Public interface — called by the router
 
318
  k: int = 5,
319
  strategies: list[Strategy] | None = None,
320
  ) -> dict[str, dict]:
321
+ """Run multiple strategies against the same query and return timing + results."""
 
 
 
 
 
 
 
 
322
  retriever = SchemaRetriever()
323
+ targets: list[Strategy] = strategies or [
324
+ "dense_no_threshold",
325
+ "dense_dot",
326
+ "dense_l2",
327
+ "hybrid",
328
+ "hybrid_bm25",
329
+ ]
330
  report: dict[str, dict] = {}
331
 
332
  for name in targets: