Roy
committed on
Commit
·
d2c65d1
1
Parent(s):
1aa0610
Refactor vector query methods to support optional ID filtering
Browse files
- Updated BaseVectorStorage query method signature to accept optional IDs
- Modified operate.py to pass query parameter IDs to vector storage queries
- Updated PostgreSQL vector storage SQL templates to filter results by document IDs
- Removed unused parameters and simplified query logic across multiple files
- lightrag/base.py +1 -2
- lightrag/kg/postgres_impl.py +78 -20
- lightrag/lightrag.py +0 -1
- lightrag/operate.py +6 -11
lightrag/base.py
CHANGED
@@ -108,9 +108,8 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
|
108 |
embedding_func: EmbeddingFunc
|
109 |
cosine_better_than_threshold: float = field(default=0.2)
|
110 |
meta_fields: set[str] = field(default_factory=set)
|
111 |
-
|
112 |
@abstractmethod
|
113 |
-
async def query(self, query: str, top_k: int, ids: list[str] = None) -> list[dict[str, Any]]:
|
114 |
"""Query the vector storage and retrieve top_k results."""
|
115 |
|
116 |
@abstractmethod
|
|
|
108 |
embedding_func: EmbeddingFunc
|
109 |
cosine_better_than_threshold: float = field(default=0.2)
|
110 |
meta_fields: set[str] = field(default_factory=set)
|
|
|
111 |
@abstractmethod
|
112 |
+
async def query(self, query: str, top_k: int, ids: list[str] | None = None) -> list[dict[str, Any]]:
|
113 |
"""Query the vector storage and retrieve top_k results."""
|
114 |
|
115 |
@abstractmethod
|
lightrag/kg/postgres_impl.py
CHANGED
@@ -439,6 +439,7 @@ class PGVectorStorage(BaseVectorStorage):
|
|
439 |
"content": item["content"],
|
440 |
"content_vector": json.dumps(item["__vector__"].tolist()),
|
441 |
"chunk_id": item["source_id"],
|
|
|
442 |
}
|
443 |
return upsert_sql, data
|
444 |
|
@@ -452,6 +453,7 @@ class PGVectorStorage(BaseVectorStorage):
|
|
452 |
"content": item["content"],
|
453 |
"content_vector": json.dumps(item["__vector__"].tolist()),
|
454 |
"chunk_id": item["source_id"]
|
|
|
455 |
}
|
456 |
return upsert_sql, data
|
457 |
|
@@ -494,13 +496,19 @@ class PGVectorStorage(BaseVectorStorage):
|
|
494 |
await self.db.execute(upsert_sql, data)
|
495 |
|
496 |
#################### query method ###############
|
497 |
-
async def query(self, query: str, top_k: int, ids: list[str] = None) -> list[dict[str, Any]]:
|
498 |
embeddings = await self.embedding_func([query])
|
499 |
embedding = embeddings[0]
|
500 |
embedding_string = ",".join(map(str, embedding))
|
501 |
|
|
|
|
|
|
|
|
|
|
|
502 |
sql = SQL_TEMPLATES[self.base_namespace].format(
|
503 |
-
embedding_string=embedding_string
|
|
|
504 |
)
|
505 |
params = {
|
506 |
"workspace": self.db.workspace,
|
@@ -1389,7 +1397,6 @@ TABLES = {
|
|
1389 |
content_vector VECTOR,
|
1390 |
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
1391 |
update_time TIMESTAMP,
|
1392 |
-
document_id VARCHAR(255) NULL,
|
1393 |
chunk_id VARCHAR(255) NULL,
|
1394 |
CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
|
1395 |
)"""
|
@@ -1404,7 +1411,6 @@ TABLES = {
|
|
1404 |
content_vector VECTOR,
|
1405 |
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
1406 |
update_time TIMESTAMP,
|
1407 |
-
document_id VARCHAR(255) NULL,
|
1408 |
chunk_id VARCHAR(255) NULL,
|
1409 |
CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
|
1410 |
)"""
|
@@ -1507,21 +1513,21 @@ SQL_TEMPLATES = {
|
|
1507 |
content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP
|
1508 |
""",
|
1509 |
# SQL for VectorStorage
|
1510 |
-
"entities": """SELECT entity_name FROM
|
1511 |
-
|
1512 |
-
|
1513 |
-
|
1514 |
-
|
1515 |
-
"relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
|
1516 |
-
|
1517 |
-
|
1518 |
-
|
1519 |
-
|
1520 |
-
"chunks": """SELECT id FROM
|
1521 |
-
|
1522 |
-
|
1523 |
-
|
1524 |
-
|
1525 |
# DROP tables
|
1526 |
"drop_all": """
|
1527 |
DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
|
@@ -1545,4 +1551,56 @@ SQL_TEMPLATES = {
|
|
1545 |
"drop_vdb_relation": """
|
1546 |
DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
|
1547 |
""",
|
1548 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
439 |
"content": item["content"],
|
440 |
"content_vector": json.dumps(item["__vector__"].tolist()),
|
441 |
"chunk_id": item["source_id"],
|
442 |
+
#TODO: add document_id
|
443 |
}
|
444 |
return upsert_sql, data
|
445 |
|
|
|
453 |
"content": item["content"],
|
454 |
"content_vector": json.dumps(item["__vector__"].tolist()),
|
455 |
"chunk_id": item["source_id"]
|
456 |
+
#TODO: add document_id
|
457 |
}
|
458 |
return upsert_sql, data
|
459 |
|
|
|
496 |
await self.db.execute(upsert_sql, data)
|
497 |
|
498 |
#################### query method ###############
|
499 |
+
async def query(self, query: str, top_k: int, ids: list[str] | None = None) -> list[dict[str, Any]]:
|
500 |
embeddings = await self.embedding_func([query])
|
501 |
embedding = embeddings[0]
|
502 |
embedding_string = ",".join(map(str, embedding))
|
503 |
|
504 |
+
if ids:
|
505 |
+
formatted_ids = ",".join(f"'{id}'" for id in ids)
|
506 |
+
else:
|
507 |
+
formatted_ids = "NULL"
|
508 |
+
|
509 |
sql = SQL_TEMPLATES[self.base_namespace].format(
|
510 |
+
embedding_string=embedding_string,
|
511 |
+
doc_ids=formatted_ids
|
512 |
)
|
513 |
params = {
|
514 |
"workspace": self.db.workspace,
|
|
|
1397 |
content_vector VECTOR,
|
1398 |
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
1399 |
update_time TIMESTAMP,
|
|
|
1400 |
chunk_id VARCHAR(255) NULL,
|
1401 |
CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
|
1402 |
)"""
|
|
|
1411 |
content_vector VECTOR,
|
1412 |
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
1413 |
update_time TIMESTAMP,
|
|
|
1414 |
chunk_id VARCHAR(255) NULL,
|
1415 |
CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
|
1416 |
)"""
|
|
|
1513 |
content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP
|
1514 |
""",
|
1515 |
# SQL for VectorStorage
|
1516 |
+
# "entities": """SELECT entity_name FROM
|
1517 |
+
# (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
1518 |
+
# FROM LIGHTRAG_VDB_ENTITY where workspace=$1)
|
1519 |
+
# WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
1520 |
+
# """,
|
1521 |
+
# "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
|
1522 |
+
# (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
1523 |
+
# FROM LIGHTRAG_VDB_RELATION where workspace=$1)
|
1524 |
+
# WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
1525 |
+
# """,
|
1526 |
+
# "chunks": """SELECT id FROM
|
1527 |
+
# (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
1528 |
+
# FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
|
1529 |
+
# WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
1530 |
+
# """,
|
1531 |
# DROP tables
|
1532 |
"drop_all": """
|
1533 |
DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
|
|
|
1551 |
"drop_vdb_relation": """
|
1552 |
DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
|
1553 |
""",
|
1554 |
+
"relationships": """
|
1555 |
+
WITH relevant_chunks AS (
|
1556 |
+
SELECT id as chunk_id
|
1557 |
+
FROM LIGHTRAG_DOC_CHUNKS
|
1558 |
+
WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
|
1559 |
+
)
|
1560 |
+
SELECT source_id as src_id, target_id as tgt_id
|
1561 |
+
FROM (
|
1562 |
+
SELECT r.id, r.source_id, r.target_id, 1 - (r.content_vector <=> '[{embedding_string}]'::vector) as distance
|
1563 |
+
FROM LIGHTRAG_VDB_RELATION r
|
1564 |
+
WHERE r.workspace=$1
|
1565 |
+
AND r.chunk_id IN (SELECT chunk_id FROM relevant_chunks)
|
1566 |
+
) filtered
|
1567 |
+
WHERE distance>$2
|
1568 |
+
ORDER BY distance DESC
|
1569 |
+
LIMIT $3
|
1570 |
+
""",
|
1571 |
+
"entities":
|
1572 |
+
'''
|
1573 |
+
WITH relevant_chunks AS (
|
1574 |
+
SELECT id as chunk_id
|
1575 |
+
FROM LIGHTRAG_DOC_CHUNKS
|
1576 |
+
WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
|
1577 |
+
)
|
1578 |
+
SELECT entity_name FROM
|
1579 |
+
(
|
1580 |
+
SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
1581 |
+
FROM LIGHTRAG_VDB_ENTITY
|
1582 |
+
where workspace=$1
|
1583 |
+
AND chunk_id IN (SELECT chunk_id FROM relevant_chunks)
|
1584 |
+
)
|
1585 |
+
WHERE distance>$2
|
1586 |
+
ORDER BY distance DESC
|
1587 |
+
LIMIT $3
|
1588 |
+
''',
|
1589 |
+
'chunks': """
|
1590 |
+
WITH relevant_chunks AS (
|
1591 |
+
SELECT id as chunk_id
|
1592 |
+
FROM LIGHTRAG_DOC_CHUNKS
|
1593 |
+
WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
|
1594 |
+
)
|
1595 |
+
SELECT id FROM
|
1596 |
+
(
|
1597 |
+
SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
1598 |
+
FROM LIGHTRAG_DOC_CHUNKS
|
1599 |
+
where workspace=$1
|
1600 |
+
AND chunk_id IN (SELECT chunk_id FROM relevant_chunks)
|
1601 |
+
)
|
1602 |
+
WHERE distance>$2
|
1603 |
+
ORDER BY distance DESC
|
1604 |
+
LIMIT $3
|
1605 |
+
"""
|
1606 |
+
}
|
lightrag/lightrag.py
CHANGED
@@ -1243,7 +1243,6 @@ class LightRAG:
|
|
1243 |
embedding_func=self.embedding_func,
|
1244 |
),
|
1245 |
system_prompt=system_prompt,
|
1246 |
-
ids = param.ids
|
1247 |
)
|
1248 |
elif param.mode == "naive":
|
1249 |
response = await naive_query(
|
|
|
1243 |
embedding_func=self.embedding_func,
|
1244 |
),
|
1245 |
system_prompt=system_prompt,
|
|
|
1246 |
)
|
1247 |
elif param.mode == "naive":
|
1248 |
response = await naive_query(
|
lightrag/operate.py
CHANGED
@@ -602,7 +602,6 @@ async def kg_query(
|
|
602 |
global_config: dict[str, str],
|
603 |
hashing_kv: BaseKVStorage | None = None,
|
604 |
system_prompt: str | None = None,
|
605 |
-
ids: list[str] | None = None,
|
606 |
) -> str | AsyncIterator[str]:
|
607 |
# Handle cache
|
608 |
use_model_func = global_config["llm_model_func"]
|
@@ -650,7 +649,6 @@ async def kg_query(
|
|
650 |
relationships_vdb,
|
651 |
text_chunks_db,
|
652 |
query_param,
|
653 |
-
ids
|
654 |
)
|
655 |
|
656 |
if query_param.only_need_context:
|
@@ -1035,7 +1033,6 @@ async def _build_query_context(
|
|
1035 |
relationships_vdb,
|
1036 |
text_chunks_db,
|
1037 |
query_param,
|
1038 |
-
ids = ids
|
1039 |
)
|
1040 |
else: # hybrid mode
|
1041 |
ll_data, hl_data = await asyncio.gather(
|
@@ -1104,7 +1101,9 @@ async def _get_node_data(
|
|
1104 |
logger.info(
|
1105 |
f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
|
1106 |
)
|
1107 |
-
|
|
|
|
|
1108 |
if not len(results):
|
1109 |
return "", "", ""
|
1110 |
# get entity information
|
@@ -1352,16 +1351,12 @@ async def _get_edge_data(
|
|
1352 |
relationships_vdb: BaseVectorStorage,
|
1353 |
text_chunks_db: BaseKVStorage,
|
1354 |
query_param: QueryParam,
|
1355 |
-
ids: list[str] | None = None,
|
1356 |
):
|
1357 |
logger.info(
|
1358 |
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
|
1359 |
)
|
1360 |
-
|
1361 |
-
|
1362 |
-
results = await relationships_vdb.query(keywords, top_k = query_param.top_k, ids = ids)
|
1363 |
-
else:
|
1364 |
-
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
1365 |
|
1366 |
if not len(results):
|
1367 |
return "", "", ""
|
@@ -1610,7 +1605,7 @@ async def naive_query(
|
|
1610 |
if cached_response is not None:
|
1611 |
return cached_response
|
1612 |
|
1613 |
-
results = await chunks_vdb.query(query, top_k=query_param.top_k)
|
1614 |
if not len(results):
|
1615 |
return PROMPTS["fail_response"]
|
1616 |
|
|
|
602 |
global_config: dict[str, str],
|
603 |
hashing_kv: BaseKVStorage | None = None,
|
604 |
system_prompt: str | None = None,
|
|
|
605 |
) -> str | AsyncIterator[str]:
|
606 |
# Handle cache
|
607 |
use_model_func = global_config["llm_model_func"]
|
|
|
649 |
relationships_vdb,
|
650 |
text_chunks_db,
|
651 |
query_param,
|
|
|
652 |
)
|
653 |
|
654 |
if query_param.only_need_context:
|
|
|
1033 |
relationships_vdb,
|
1034 |
text_chunks_db,
|
1035 |
query_param,
|
|
|
1036 |
)
|
1037 |
else: # hybrid mode
|
1038 |
ll_data, hl_data = await asyncio.gather(
|
|
|
1101 |
logger.info(
|
1102 |
f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
|
1103 |
)
|
1104 |
+
|
1105 |
+
results = await entities_vdb.query(query, top_k=query_param.top_k, ids = query_param.ids)
|
1106 |
+
|
1107 |
if not len(results):
|
1108 |
return "", "", ""
|
1109 |
# get entity information
|
|
|
1351 |
relationships_vdb: BaseVectorStorage,
|
1352 |
text_chunks_db: BaseKVStorage,
|
1353 |
query_param: QueryParam,
|
|
|
1354 |
):
|
1355 |
logger.info(
|
1356 |
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
|
1357 |
)
|
1358 |
+
|
1359 |
+
results = await relationships_vdb.query(keywords, top_k = query_param.top_k, ids = query_param.ids)
|
|
|
|
|
|
|
1360 |
|
1361 |
if not len(results):
|
1362 |
return "", "", ""
|
|
|
1605 |
if cached_response is not None:
|
1606 |
return cached_response
|
1607 |
|
1608 |
+
results = await chunks_vdb.query(query, top_k=query_param.top_k, ids = query_param.ids)
|
1609 |
if not len(results):
|
1610 |
return PROMPTS["fail_response"]
|
1611 |
|