Rifqi Hafizuddin commited on
Commit ·
8218650
1
Parent(s): ac6b78d
[KM-507] add changes to methods
Browse files- src/db/postgres/init_db.py +20 -0
- src/rag/retrievers/schema.py +12 -5
- src/rag/router.py +1 -1
src/db/postgres/init_db.py
CHANGED
|
@@ -29,6 +29,26 @@ async def init_db():
|
|
| 29 |
"ALTER TABLE rooms ADD COLUMN IF NOT EXISTS status VARCHAR NOT NULL DEFAULT 'active'"
|
| 30 |
))
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
# GIN index for FTS on schema chunks — only created if table exists
|
| 33 |
# (langchain_pg_embedding is created by PGVector on first use, not by create_all)
|
| 34 |
await conn.execute(text("""
|
|
|
|
| 29 |
"ALTER TABLE rooms ADD COLUMN IF NOT EXISTS status VARCHAR NOT NULL DEFAULT 'active'"
|
| 30 |
))
|
| 31 |
|
| 32 |
+
# HNSW index for fast approximate vector similarity search
|
| 33 |
+
# Only created when the embedding column has explicit dimensions (HNSW requirement).
|
| 34 |
+
# atttypmod > 0 means the vector column was created with a dimension (e.g. vector(1536));
|
| 35 |
+
# atttypmod = -1 means dimensionless — HNSW would fail with "column does not have dimensions".
|
| 36 |
+
await conn.execute(text("""
|
| 37 |
+
DO $$
|
| 38 |
+
BEGIN
|
| 39 |
+
IF EXISTS (
|
| 40 |
+
SELECT FROM pg_attribute a
|
| 41 |
+
JOIN pg_class c ON c.oid = a.attrelid
|
| 42 |
+
WHERE c.relname = 'langchain_pg_embedding'
|
| 43 |
+
AND a.attname = 'embedding'
|
| 44 |
+
AND a.atttypmod > 0
|
| 45 |
+
) THEN
|
| 46 |
+
CREATE INDEX IF NOT EXISTS idx_langchain_pg_embedding_hnsw
|
| 47 |
+
ON langchain_pg_embedding USING hnsw (embedding vector_cosine_ops);
|
| 48 |
+
END IF;
|
| 49 |
+
END $$
|
| 50 |
+
"""))
|
| 51 |
+
|
| 52 |
# GIN index for FTS on schema chunks — only created if table exists
|
| 53 |
# (langchain_pg_embedding is created by PGVector on first use, not by create_all)
|
| 54 |
await conn.execute(text("""
|
src/rag/retrievers/schema.py
CHANGED
|
@@ -19,11 +19,11 @@ from src.rag.base import BaseRetriever, RetrievalResult
|
|
| 19 |
|
| 20 |
logger = get_logger("schema_retriever")
|
| 21 |
|
| 22 |
-
_SCORE_THRESHOLD = 0.
|
| 23 |
_TABULAR_FILE_TYPES = ("csv", "xlsx")
|
| 24 |
|
| 25 |
Strategy = Literal["dense", "dense_no_threshold", "mmr", "hybrid", "hybrid_bm25"]
|
| 26 |
-
ACTIVE_STRATEGY: Strategy = "
|
| 27 |
|
| 28 |
|
| 29 |
class SchemaRetriever(BaseRetriever):
|
|
@@ -171,11 +171,18 @@ class SchemaRetriever(BaseRetriever):
|
|
| 171 |
|
| 172 |
for ranked in ranked_lists:
|
| 173 |
for rank, result in enumerate(ranked):
|
| 174 |
-
|
|
|
|
| 175 |
scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
|
| 176 |
-
|
|
|
|
|
|
|
| 177 |
|
| 178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
return merged[:top_k]
|
| 180 |
|
| 181 |
def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
|
|
|
|
| 19 |
|
| 20 |
logger = get_logger("schema_retriever")
|
| 21 |
|
| 22 |
+
_SCORE_THRESHOLD = 0.75 # cosine distance — discard above this value (score < 0.25)
|
| 23 |
_TABULAR_FILE_TYPES = ("csv", "xlsx")
|
| 24 |
|
| 25 |
Strategy = Literal["dense", "dense_no_threshold", "mmr", "hybrid", "hybrid_bm25"]
|
| 26 |
+
ACTIVE_STRATEGY: Strategy = "dense_no_threshold"
|
| 27 |
|
| 28 |
|
| 29 |
class SchemaRetriever(BaseRetriever):
|
|
|
|
| 171 |
|
| 172 |
for ranked in ranked_lists:
|
| 173 |
for rank, result in enumerate(ranked):
|
| 174 |
+
data = result.metadata.get("data", {})
|
| 175 |
+
key = (data.get("table_name"), data.get("column_name") or data.get("filename"))
|
| 176 |
scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
|
| 177 |
+
# prefer the result with a real cosine score (dense leg) over ts_rank (FTS leg)
|
| 178 |
+
if key not in index or result.score > index[key].score:
|
| 179 |
+
index[key] = result
|
| 180 |
|
| 181 |
+
def _key(r: RetrievalResult) -> tuple:
|
| 182 |
+
d = r.metadata.get("data", {})
|
| 183 |
+
return (d.get("table_name"), d.get("column_name") or d.get("filename"))
|
| 184 |
+
|
| 185 |
+
merged = sorted(index.values(), key=lambda r: scores[_key(r)], reverse=True)
|
| 186 |
return merged[:top_k]
|
| 187 |
|
| 188 |
def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
|
src/rag/router.py
CHANGED
|
@@ -38,7 +38,7 @@ class RetrievalRouter:
|
|
| 38 |
query: str,
|
| 39 |
user_id: str,
|
| 40 |
source_hint: SourceHint = "both",
|
| 41 |
-
k: int =
|
| 42 |
) -> list[RetrievalResult]:
|
| 43 |
redis = await get_redis()
|
| 44 |
query_hash = hashlib.md5(query.encode()).hexdigest()
|
|
|
|
| 38 |
query: str,
|
| 39 |
user_id: str,
|
| 40 |
source_hint: SourceHint = "both",
|
| 41 |
+
k: int = 10,
|
| 42 |
) -> list[RetrievalResult]:
|
| 43 |
redis = await get_redis()
|
| 44 |
query_hash = hashlib.md5(query.encode()).hexdigest()
|