Roy committed on
Commit
d2c65d1
·
1 Parent(s): 1aa0610

Refactor vector query methods to support optional ID filtering

Browse files

- Updated BaseVectorStorage query method signature to accept optional IDs
- Modified operate.py to pass query parameter IDs to vector storage queries
- Updated PostgreSQL vector storage SQL templates to filter results by document IDs
- Removed unused parameters and simplified query logic across multiple files

lightrag/base.py CHANGED
@@ -108,9 +108,8 @@ class BaseVectorStorage(StorageNameSpace, ABC):
108
  embedding_func: EmbeddingFunc
109
  cosine_better_than_threshold: float = field(default=0.2)
110
  meta_fields: set[str] = field(default_factory=set)
111
-
112
  @abstractmethod
113
- async def query(self, query: str, top_k: int, ids: list[str] = None) -> list[dict[str, Any]]:
114
  """Query the vector storage and retrieve top_k results."""
115
 
116
  @abstractmethod
 
108
  embedding_func: EmbeddingFunc
109
  cosine_better_than_threshold: float = field(default=0.2)
110
  meta_fields: set[str] = field(default_factory=set)
 
111
  @abstractmethod
112
+ async def query(self, query: str, top_k: int, ids: list[str] | None = None) -> list[dict[str, Any]]:
113
  """Query the vector storage and retrieve top_k results."""
114
 
115
  @abstractmethod
lightrag/kg/postgres_impl.py CHANGED
@@ -439,6 +439,7 @@ class PGVectorStorage(BaseVectorStorage):
439
  "content": item["content"],
440
  "content_vector": json.dumps(item["__vector__"].tolist()),
441
  "chunk_id": item["source_id"],
 
442
  }
443
  return upsert_sql, data
444
 
@@ -452,6 +453,7 @@ class PGVectorStorage(BaseVectorStorage):
452
  "content": item["content"],
453
  "content_vector": json.dumps(item["__vector__"].tolist()),
454
  "chunk_id": item["source_id"]
 
455
  }
456
  return upsert_sql, data
457
 
@@ -494,13 +496,19 @@ class PGVectorStorage(BaseVectorStorage):
494
  await self.db.execute(upsert_sql, data)
495
 
496
  #################### query method ###############
497
- async def query(self, query: str, top_k: int, ids: list[str] = None) -> list[dict[str, Any]]:
498
  embeddings = await self.embedding_func([query])
499
  embedding = embeddings[0]
500
  embedding_string = ",".join(map(str, embedding))
501
 
 
 
 
 
 
502
  sql = SQL_TEMPLATES[self.base_namespace].format(
503
- embedding_string=embedding_string
 
504
  )
505
  params = {
506
  "workspace": self.db.workspace,
@@ -1389,7 +1397,6 @@ TABLES = {
1389
  content_vector VECTOR,
1390
  create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
1391
  update_time TIMESTAMP,
1392
- document_id VARCHAR(255) NULL,
1393
  chunk_id VARCHAR(255) NULL,
1394
  CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
1395
  )"""
@@ -1404,7 +1411,6 @@ TABLES = {
1404
  content_vector VECTOR,
1405
  create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
1406
  update_time TIMESTAMP,
1407
- document_id VARCHAR(255) NULL,
1408
  chunk_id VARCHAR(255) NULL,
1409
  CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
1410
  )"""
@@ -1507,21 +1513,21 @@ SQL_TEMPLATES = {
1507
  content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP
1508
  """,
1509
  # SQL for VectorStorage
1510
- "entities": """SELECT entity_name FROM
1511
- (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1512
- FROM LIGHTRAG_VDB_ENTITY where workspace=$1)
1513
- WHERE distance>$2 ORDER BY distance DESC LIMIT $3
1514
- """,
1515
- "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
1516
- (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1517
- FROM LIGHTRAG_VDB_RELATION where workspace=$1)
1518
- WHERE distance>$2 ORDER BY distance DESC LIMIT $3
1519
- """,
1520
- "chunks": """SELECT id FROM
1521
- (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1522
- FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
1523
- WHERE distance>$2 ORDER BY distance DESC LIMIT $3
1524
- """,
1525
  # DROP tables
1526
  "drop_all": """
1527
  DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
@@ -1545,4 +1551,56 @@ SQL_TEMPLATES = {
1545
  "drop_vdb_relation": """
1546
  DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
1547
  """,
1548
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  "content": item["content"],
440
  "content_vector": json.dumps(item["__vector__"].tolist()),
441
  "chunk_id": item["source_id"],
442
+ #TODO: add document_id
443
  }
444
  return upsert_sql, data
445
 
 
453
  "content": item["content"],
454
  "content_vector": json.dumps(item["__vector__"].tolist()),
455
  "chunk_id": item["source_id"]
456
+ #TODO: add document_id
457
  }
458
  return upsert_sql, data
459
 
 
496
  await self.db.execute(upsert_sql, data)
497
 
498
  #################### query method ###############
499
+ async def query(self, query: str, top_k: int, ids: list[str] | None = None) -> list[dict[str, Any]]:
500
  embeddings = await self.embedding_func([query])
501
  embedding = embeddings[0]
502
  embedding_string = ",".join(map(str, embedding))
503
 
504
+ if ids:
505
+ formatted_ids = ",".join(f"'{id}'" for id in ids)
506
+ else:
507
+ formatted_ids = "NULL"
508
+
509
  sql = SQL_TEMPLATES[self.base_namespace].format(
510
+ embedding_string=embedding_string,
511
+ doc_ids=formatted_ids
512
  )
513
  params = {
514
  "workspace": self.db.workspace,
 
1397
  content_vector VECTOR,
1398
  create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
1399
  update_time TIMESTAMP,
 
1400
  chunk_id VARCHAR(255) NULL,
1401
  CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
1402
  )"""
 
1411
  content_vector VECTOR,
1412
  create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
1413
  update_time TIMESTAMP,
 
1414
  chunk_id VARCHAR(255) NULL,
1415
  CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
1416
  )"""
 
1513
  content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP
1514
  """,
1515
  # SQL for VectorStorage
1516
+ # "entities": """SELECT entity_name FROM
1517
+ # (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1518
+ # FROM LIGHTRAG_VDB_ENTITY where workspace=$1)
1519
+ # WHERE distance>$2 ORDER BY distance DESC LIMIT $3
1520
+ # """,
1521
+ # "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
1522
+ # (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1523
+ # FROM LIGHTRAG_VDB_RELATION where workspace=$1)
1524
+ # WHERE distance>$2 ORDER BY distance DESC LIMIT $3
1525
+ # """,
1526
+ # "chunks": """SELECT id FROM
1527
+ # (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1528
+ # FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
1529
+ # WHERE distance>$2 ORDER BY distance DESC LIMIT $3
1530
+ # """,
1531
  # DROP tables
1532
  "drop_all": """
1533
  DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
 
1551
  "drop_vdb_relation": """
1552
  DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
1553
  """,
1554
+ "relationships": """
1555
+ WITH relevant_chunks AS (
1556
+ SELECT id as chunk_id
1557
+ FROM LIGHTRAG_DOC_CHUNKS
1558
+ WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
1559
+ )
1560
+ SELECT source_id as src_id, target_id as tgt_id
1561
+ FROM (
1562
+ SELECT r.id, r.source_id, r.target_id, 1 - (r.content_vector <=> '[{embedding_string}]'::vector) as distance
1563
+ FROM LIGHTRAG_VDB_RELATION r
1564
+ WHERE r.workspace=$1
1565
+ AND r.chunk_id IN (SELECT chunk_id FROM relevant_chunks)
1566
+ ) filtered
1567
+ WHERE distance>$2
1568
+ ORDER BY distance DESC
1569
+ LIMIT $3
1570
+ """,
1571
+ "entities":
1572
+ '''
1573
+ WITH relevant_chunks AS (
1574
+ SELECT id as chunk_id
1575
+ FROM LIGHTRAG_DOC_CHUNKS
1576
+ WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
1577
+ )
1578
+ SELECT entity_name FROM
1579
+ (
1580
+ SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1581
+ FROM LIGHTRAG_VDB_ENTITY
1582
+ where workspace=$1
1583
+ AND chunk_id IN (SELECT chunk_id FROM relevant_chunks)
1584
+ )
1585
+ WHERE distance>$2
1586
+ ORDER BY distance DESC
1587
+ LIMIT $3
1588
+ ''',
1589
+ 'chunks': """
1590
+ WITH relevant_chunks AS (
1591
+ SELECT id as chunk_id
1592
+ FROM LIGHTRAG_DOC_CHUNKS
1593
+ WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
1594
+ )
1595
+ SELECT id FROM
1596
+ (
1597
+ SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1598
+ FROM LIGHTRAG_DOC_CHUNKS
1599
+ where workspace=$1
1600
+ AND chunk_id IN (SELECT chunk_id FROM relevant_chunks)
1601
+ )
1602
+ WHERE distance>$2
1603
+ ORDER BY distance DESC
1604
+ LIMIT $3
1605
+ """
1606
+ }
lightrag/lightrag.py CHANGED
@@ -1243,7 +1243,6 @@ class LightRAG:
1243
  embedding_func=self.embedding_func,
1244
  ),
1245
  system_prompt=system_prompt,
1246
- ids = param.ids
1247
  )
1248
  elif param.mode == "naive":
1249
  response = await naive_query(
 
1243
  embedding_func=self.embedding_func,
1244
  ),
1245
  system_prompt=system_prompt,
 
1246
  )
1247
  elif param.mode == "naive":
1248
  response = await naive_query(
lightrag/operate.py CHANGED
@@ -602,7 +602,6 @@ async def kg_query(
602
  global_config: dict[str, str],
603
  hashing_kv: BaseKVStorage | None = None,
604
  system_prompt: str | None = None,
605
- ids: list[str] | None = None,
606
  ) -> str | AsyncIterator[str]:
607
  # Handle cache
608
  use_model_func = global_config["llm_model_func"]
@@ -650,7 +649,6 @@ async def kg_query(
650
  relationships_vdb,
651
  text_chunks_db,
652
  query_param,
653
- ids
654
  )
655
 
656
  if query_param.only_need_context:
@@ -1035,7 +1033,6 @@ async def _build_query_context(
1035
  relationships_vdb,
1036
  text_chunks_db,
1037
  query_param,
1038
- ids = ids
1039
  )
1040
  else: # hybrid mode
1041
  ll_data, hl_data = await asyncio.gather(
@@ -1104,7 +1101,9 @@ async def _get_node_data(
1104
  logger.info(
1105
  f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
1106
  )
1107
- results = await entities_vdb.query(query, top_k=query_param.top_k)
 
 
1108
  if not len(results):
1109
  return "", "", ""
1110
  # get entity information
@@ -1352,16 +1351,12 @@ async def _get_edge_data(
1352
  relationships_vdb: BaseVectorStorage,
1353
  text_chunks_db: BaseKVStorage,
1354
  query_param: QueryParam,
1355
- ids: list[str] | None = None,
1356
  ):
1357
  logger.info(
1358
  f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
1359
  )
1360
- if ids:
1361
- #TODO: add ids to the query
1362
- results = await relationships_vdb.query(keywords, top_k = query_param.top_k, ids = ids)
1363
- else:
1364
- results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
1365
 
1366
  if not len(results):
1367
  return "", "", ""
@@ -1610,7 +1605,7 @@ async def naive_query(
1610
  if cached_response is not None:
1611
  return cached_response
1612
 
1613
- results = await chunks_vdb.query(query, top_k=query_param.top_k)
1614
  if not len(results):
1615
  return PROMPTS["fail_response"]
1616
 
 
602
  global_config: dict[str, str],
603
  hashing_kv: BaseKVStorage | None = None,
604
  system_prompt: str | None = None,
 
605
  ) -> str | AsyncIterator[str]:
606
  # Handle cache
607
  use_model_func = global_config["llm_model_func"]
 
649
  relationships_vdb,
650
  text_chunks_db,
651
  query_param,
 
652
  )
653
 
654
  if query_param.only_need_context:
 
1033
  relationships_vdb,
1034
  text_chunks_db,
1035
  query_param,
 
1036
  )
1037
  else: # hybrid mode
1038
  ll_data, hl_data = await asyncio.gather(
 
1101
  logger.info(
1102
  f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
1103
  )
1104
+
1105
+ results = await entities_vdb.query(query, top_k=query_param.top_k, ids = query_param.ids)
1106
+
1107
  if not len(results):
1108
  return "", "", ""
1109
  # get entity information
 
1351
  relationships_vdb: BaseVectorStorage,
1352
  text_chunks_db: BaseKVStorage,
1353
  query_param: QueryParam,
 
1354
  ):
1355
  logger.info(
1356
  f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
1357
  )
1358
+
1359
+ results = await relationships_vdb.query(keywords, top_k = query_param.top_k, ids = query_param.ids)
 
 
 
1360
 
1361
  if not len(results):
1362
  return "", "", ""
 
1605
  if cached_response is not None:
1606
  return cached_response
1607
 
1608
+ results = await chunks_vdb.query(query, top_k=query_param.top_k, ids = query_param.ids)
1609
  if not len(results):
1610
  return PROMPTS["fail_response"]
1611