Rifqi Hafizuddin committed on
Commit
8218650
·
1 Parent(s): ac6b78d

[KM-507] add changes to methods

Browse files
src/db/postgres/init_db.py CHANGED
@@ -29,6 +29,26 @@ async def init_db():
29
  "ALTER TABLE rooms ADD COLUMN IF NOT EXISTS status VARCHAR NOT NULL DEFAULT 'active'"
30
  ))
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  # GIN index for FTS on schema chunks — only created if table exists
33
  # (langchain_pg_embedding is created by PGVector on first use, not by create_all)
34
  await conn.execute(text("""
 
29
  "ALTER TABLE rooms ADD COLUMN IF NOT EXISTS status VARCHAR NOT NULL DEFAULT 'active'"
30
  ))
31
 
32
+ # HNSW index for fast approximate vector similarity search
33
+ # Only created when the embedding column has explicit dimensions (HNSW requirement).
34
+ # atttypmod > 0 means the vector column was created with a dimension (e.g. vector(1536));
35
+ # atttypmod = -1 means dimensionless — HNSW would fail with "column does not have dimensions".
36
+ await conn.execute(text("""
37
+ DO $$
38
+ BEGIN
39
+ IF EXISTS (
40
+ SELECT FROM pg_attribute a
41
+ JOIN pg_class c ON c.oid = a.attrelid
42
+ WHERE c.relname = 'langchain_pg_embedding'
43
+ AND a.attname = 'embedding'
44
+ AND a.atttypmod > 0
45
+ ) THEN
46
+ CREATE INDEX IF NOT EXISTS idx_langchain_pg_embedding_hnsw
47
+ ON langchain_pg_embedding USING hnsw (embedding vector_cosine_ops);
48
+ END IF;
49
+ END $$
50
+ """))
51
+
52
  # GIN index for FTS on schema chunks — only created if table exists
53
  # (langchain_pg_embedding is created by PGVector on first use, not by create_all)
54
  await conn.execute(text("""
src/rag/retrievers/schema.py CHANGED
@@ -19,11 +19,11 @@ from src.rag.base import BaseRetriever, RetrievalResult
19
 
20
  logger = get_logger("schema_retriever")
21
 
22
- _SCORE_THRESHOLD = 0.60 # cosine distance — discard above this value (score < 0.40)
23
  _TABULAR_FILE_TYPES = ("csv", "xlsx")
24
 
25
  Strategy = Literal["dense", "dense_no_threshold", "mmr", "hybrid", "hybrid_bm25"]
26
- ACTIVE_STRATEGY: Strategy = "hybrid_bm25"
27
 
28
 
29
  class SchemaRetriever(BaseRetriever):
@@ -171,11 +171,18 @@ class SchemaRetriever(BaseRetriever):
171
 
172
  for ranked in ranked_lists:
173
  for rank, result in enumerate(ranked):
174
- key = result.content[:120]
 
175
  scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
176
- index[key] = result
 
 
177
 
178
- merged = sorted(index.values(), key=lambda r: scores[r.content[:120]], reverse=True)
 
 
 
 
179
  return merged[:top_k]
180
 
181
  def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
 
19
 
20
  logger = get_logger("schema_retriever")
21
 
22
+ _SCORE_THRESHOLD = 0.75 # cosine distance — discard above this value (score < 0.25)
23
  _TABULAR_FILE_TYPES = ("csv", "xlsx")
24
 
25
  Strategy = Literal["dense", "dense_no_threshold", "mmr", "hybrid", "hybrid_bm25"]
26
+ ACTIVE_STRATEGY: Strategy = "dense_no_threshold"
27
 
28
 
29
  class SchemaRetriever(BaseRetriever):
 
171
 
172
  for ranked in ranked_lists:
173
  for rank, result in enumerate(ranked):
174
+ data = result.metadata.get("data", {})
175
+ key = (data.get("table_name"), data.get("column_name") or data.get("filename"))
176
  scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
177
+ # prefer the result with a real cosine score (dense leg) over ts_rank (FTS leg)
178
+ if key not in index or result.score > index[key].score:
179
+ index[key] = result
180
 
181
+ def _key(r: RetrievalResult) -> tuple:
182
+ d = r.metadata.get("data", {})
183
+ return (d.get("table_name"), d.get("column_name") or d.get("filename"))
184
+
185
+ merged = sorted(index.values(), key=lambda r: scores[_key(r)], reverse=True)
186
  return merged[:top_k]
187
 
188
  def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
src/rag/router.py CHANGED
@@ -38,7 +38,7 @@ class RetrievalRouter:
38
  query: str,
39
  user_id: str,
40
  source_hint: SourceHint = "both",
41
- k: int = 5,
42
  ) -> list[RetrievalResult]:
43
  redis = await get_redis()
44
  query_hash = hashlib.md5(query.encode()).hexdigest()
 
38
  query: str,
39
  user_id: str,
40
  source_hint: SourceHint = "both",
41
+ k: int = 10,
42
  ) -> list[RetrievalResult]:
43
  redis = await get_redis()
44
  query_hash = hashlib.md5(query.encode()).hexdigest()