frederikhendrix committed on
Commit
f43436e
·
1 Parent(s): a450594

Added get_nodes_batch and the other batch methods to base.py and neo4j_impl.py

Browse files
Files changed (3) hide show
  1. lightrag/base.py +20 -0
  2. lightrag/kg/neo4j_impl.py +156 -0
  3. lightrag/operate.py +60 -51
lightrag/base.py CHANGED
@@ -309,6 +309,26 @@ class BaseGraphStorage(StorageNameSpace, ABC):
309
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
310
  """Upsert a node into the graph."""
311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  @abstractmethod
313
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
314
  """Upsert an edge into the graph."""
 
309
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
310
  """Upsert a node into the graph."""
311
 
312
+ @abstractmethod
313
+ async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
314
+ """Get nodes as a batch using UNWIND"""
315
+
316
+ @abstractmethod
317
+ async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
318
+ """Node degrees as a batch using UNWIND"""
319
+
320
+ @abstractmethod
321
+ async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
322
+ """Edge degrees as a batch using UNWIND also uses node_degrees_batch"""
323
+
324
+ @abstractmethod
325
+ async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
326
+ """Get edges as a batch using UNWIND"""
327
+
328
+ @abstractmethod
329
+ async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
330
+ """"Get nodes edges as a batch using UNWIND"""
331
+
332
  @abstractmethod
333
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
334
  """Upsert an edge into the graph."""
lightrag/kg/neo4j_impl.py CHANGED
@@ -314,6 +314,37 @@ class Neo4JStorage(BaseGraphStorage):
314
  logger.error(f"Error getting node for {node_id}: {str(e)}")
315
  raise
316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  async def node_degree(self, node_id: str) -> int:
318
  """Get the degree (number of relationships) of a node with the given label.
319
  If multiple nodes have the same label, returns the degree of the first node.
@@ -357,6 +388,41 @@ class Neo4JStorage(BaseGraphStorage):
357
  logger.error(f"Error getting node degree for {node_id}: {str(e)}")
358
  raise
359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
361
  """Get the total degree (sum of relationships) of two nodes.
362
 
@@ -376,6 +442,30 @@ class Neo4JStorage(BaseGraphStorage):
376
 
377
  degrees = int(src_degree) + int(trg_degree)
378
  return degrees
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
 
380
  async def get_edge(
381
  self, source_node_id: str, target_node_id: str
@@ -463,6 +553,43 @@ class Neo4JStorage(BaseGraphStorage):
463
  )
464
  raise
465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
467
  """Retrieves all edges (relationships) for a particular node identified by its label.
468
 
@@ -523,6 +650,35 @@ class Neo4JStorage(BaseGraphStorage):
523
  logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
524
  raise
525
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  @retry(
527
  stop=stop_after_attempt(3),
528
  wait=wait_exponential(multiplier=1, min=4, max=10),
 
314
  logger.error(f"Error getting node for {node_id}: {str(e)}")
315
  raise
316
 
317
+ async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
318
+ """
319
+ Retrieve multiple nodes in one query using UNWIND.
320
+
321
+ Args:
322
+ node_ids: List of node entity IDs to fetch.
323
+
324
+ Returns:
325
+ A dictionary mapping each node_id to its node data (or None if not found).
326
+ """
327
+ async with self._driver.session(
328
+ database=self._DATABASE, default_access_mode="READ"
329
+ ) as session:
330
+ query = """
331
+ UNWIND $node_ids AS id
332
+ MATCH (n:base {entity_id: id})
333
+ RETURN n.entity_id AS entity_id, n
334
+ """
335
+ result = await session.run(query, node_ids=node_ids)
336
+ nodes = {}
337
+ async for record in result:
338
+ entity_id = record["entity_id"]
339
+ node = record["n"]
340
+ node_dict = dict(node)
341
+ # Remove the 'base' label if present in a 'labels' property
342
+ if "labels" in node_dict:
343
+ node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
344
+ nodes[entity_id] = node_dict
345
+ await result.consume() # Make sure to consume the result fully
346
+ return nodes
347
+
348
  async def node_degree(self, node_id: str) -> int:
349
  """Get the degree (number of relationships) of a node with the given label.
350
  If multiple nodes have the same label, returns the degree of the first node.
 
388
  logger.error(f"Error getting node degree for {node_id}: {str(e)}")
389
  raise
390
 
391
+ async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
392
+ """
393
+ Retrieve the degree for multiple nodes in a single query using UNWIND.
394
+
395
+ Args:
396
+ node_ids: List of node labels (entity_id values) to look up.
397
+
398
+ Returns:
399
+ A dictionary mapping each node_id to its degree (number of relationships).
400
+ If a node is not found, its degree will be set to 0.
401
+ """
402
+ async with self._driver.session(
403
+ database=self._DATABASE, default_access_mode="READ"
404
+ ) as session:
405
+ query = """
406
+ UNWIND $node_ids AS id
407
+ MATCH (n:base {entity_id: id})
408
+ RETURN n.entity_id AS entity_id, count { (n)--() } AS degree;
409
+ """
410
+ result = await session.run(query, node_ids=node_ids)
411
+ degrees = {}
412
+ async for record in result:
413
+ entity_id = record["entity_id"]
414
+ degrees[entity_id] = record["degree"]
415
+ await result.consume() # Ensure result is fully consumed
416
+
417
+ # For any node_id that did not return a record, set degree to 0.
418
+ for nid in node_ids:
419
+ if nid not in degrees:
420
+ logger.warning(f"No node found with label '{nid}'")
421
+ degrees[nid] = 0
422
+
423
+ logger.debug(f"Neo4j batch node degree query returned: {degrees}")
424
+ return degrees
425
+
426
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
427
  """Get the total degree (sum of relationships) of two nodes.
428
 
 
442
 
443
  degrees = int(src_degree) + int(trg_degree)
444
  return degrees
445
+
446
+ async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
447
+ """
448
+ Calculate the combined degree for each edge (sum of the source and target node degrees)
449
+ in batch using the already implemented node_degrees_batch.
450
+
451
+ Args:
452
+ edge_pairs: List of (src, tgt) tuples.
453
+
454
+ Returns:
455
+ A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
456
+ """
457
+ # Collect unique node IDs from all edge pairs.
458
+ unique_node_ids = {src for src, _ in edge_pairs}
459
+ unique_node_ids.update({tgt for _, tgt in edge_pairs})
460
+
461
+ # Get degrees for all nodes in one go.
462
+ degrees = await self.node_degrees_batch(list(unique_node_ids))
463
+
464
+ # Sum up degrees for each edge pair.
465
+ edge_degrees = {}
466
+ for src, tgt in edge_pairs:
467
+ edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
468
+ return edge_degrees
469
 
470
  async def get_edge(
471
  self, source_node_id: str, target_node_id: str
 
553
  )
554
  raise
555
 
556
+ async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
557
+ """
558
+ Retrieve edge properties for multiple (src, tgt) pairs in one query.
559
+
560
+ Args:
561
+ pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
562
+
563
+ Returns:
564
+ A dictionary mapping (src, tgt) tuples to their edge properties.
565
+ """
566
+ async with self._driver.session(
567
+ database=self._DATABASE, default_access_mode="READ"
568
+ ) as session:
569
+ query = """
570
+ UNWIND $pairs AS pair
571
+ MATCH (start:base {entity_id: pair.src})-[r:DIRECTED]-(end:base {entity_id: pair.tgt})
572
+ RETURN pair.src AS src_id, pair.tgt AS tgt_id, collect(properties(r)) AS edges
573
+ """
574
+ result = await session.run(query, pairs=pairs)
575
+ edges_dict = {}
576
+ async for record in result:
577
+ src = record["src_id"]
578
+ tgt = record["tgt_id"]
579
+ edges = record["edges"]
580
+ if edges and len(edges) > 0:
581
+ edge_props = edges[0] # choose the first if multiple exist
582
+ # Ensure required keys exist with defaults
583
+ for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items():
584
+ if key not in edge_props:
585
+ edge_props[key] = default
586
+ edges_dict[(src, tgt)] = edge_props
587
+ else:
588
+ # No edge found – set default edge properties
589
+ edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None}
590
+ await result.consume()
591
+ return edges_dict
592
+
593
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
594
  """Retrieves all edges (relationships) for a particular node identified by its label.
595
 
 
650
  logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
651
  raise
652
 
653
+ async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
654
+ """
655
+ Batch retrieve edges for multiple nodes in one query using UNWIND.
656
+
657
+ Args:
658
+ node_ids: List of node IDs (entity_id) for which to retrieve edges.
659
+
660
+ Returns:
661
+ A dictionary mapping each node ID to its list of edge tuples (source, target).
662
+ """
663
+ async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
664
+ query = """
665
+ UNWIND $node_ids AS id
666
+ MATCH (n:base {entity_id: id})
667
+ OPTIONAL MATCH (n)-[r]-(connected:base)
668
+ RETURN id AS queried_id, n.entity_id AS source_entity_id, connected.entity_id AS target_entity_id
669
+ """
670
+ result = await session.run(query, node_ids=node_ids)
671
+ # Initialize the dictionary with empty lists for each node ID
672
+ edges_dict = {node_id: [] for node_id in node_ids}
673
+ async for record in result:
674
+ queried_id = record["queried_id"]
675
+ source_label = record["source_entity_id"]
676
+ target_label = record["target_entity_id"]
677
+ if source_label and target_label:
678
+ edges_dict[queried_id].append((source_label, target_label))
679
+ await result.consume() # Ensure results are fully consumed
680
+ return edges_dict
681
+
682
  @retry(
683
  stop=stop_after_attempt(3),
684
  wait=wait_exponential(multiplier=1, min=4, max=10),
lightrag/operate.py CHANGED
@@ -1233,16 +1233,20 @@ async def _get_node_data(
1233
 
1234
  if not len(results):
1235
  return "", "", ""
1236
- # get entity information
1237
- node_datas, node_degrees = await asyncio.gather(
1238
- asyncio.gather(
1239
- *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
1240
- ),
1241
- asyncio.gather(
1242
- *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
1243
- ),
1244
  )
1245
 
 
 
 
 
1246
  if not all([n is not None for n in node_datas]):
1247
  logger.warning("Some nodes are missing, maybe the storage is damaged")
1248
 
@@ -1374,9 +1378,10 @@ async def _find_most_related_text_unit_from_entities(
1374
  all_one_hop_nodes.update([e[1] for e in this_edges])
1375
 
1376
  all_one_hop_nodes = list(all_one_hop_nodes)
1377
- all_one_hop_nodes_data = await asyncio.gather(
1378
- *[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
1379
- )
 
1380
 
1381
  # Add null check for node data
1382
  all_one_hop_text_units_lookup = {
@@ -1512,29 +1517,34 @@ async def _get_edge_data(
1512
  if not len(results):
1513
  return "", "", ""
1514
 
1515
- edge_datas, edge_degree = await asyncio.gather(
1516
- asyncio.gather(
1517
- *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
1518
- ),
1519
- asyncio.gather(
1520
- *[
1521
- knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"])
1522
- for r in results
1523
- ]
1524
- ),
1525
- )
1526
-
1527
- edge_datas = [
1528
- {
1529
- "src_id": k["src_id"],
1530
- "tgt_id": k["tgt_id"],
1531
- "rank": d,
1532
- "created_at": k.get("__created_at__", None),
1533
- **v,
1534
- }
1535
- for k, v, d in zip(results, edge_datas, edge_degree)
1536
- if v is not None
1537
- ]
 
 
 
 
 
1538
  edge_datas = sorted(
1539
  edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
1540
  )
@@ -1640,24 +1650,23 @@ async def _find_most_related_entities_from_relationships(
1640
  entity_names.append(e["tgt_id"])
1641
  seen.add(e["tgt_id"])
1642
 
1643
- node_datas, node_degrees = await asyncio.gather(
1644
- asyncio.gather(
1645
- *[
1646
- knowledge_graph_inst.get_node(entity_name)
1647
- for entity_name in entity_names
1648
- ]
1649
- ),
1650
- asyncio.gather(
1651
- *[
1652
- knowledge_graph_inst.node_degree(entity_name)
1653
- for entity_name in entity_names
1654
- ]
1655
- ),
1656
  )
1657
- node_datas = [
1658
- {**n, "entity_name": k, "rank": d}
1659
- for k, n, d in zip(entity_names, node_datas, node_degrees)
1660
- ]
 
 
 
 
 
 
 
 
1661
 
1662
  len_node_datas = len(node_datas)
1663
  node_datas = truncate_list_by_token_size(
 
1233
 
1234
  if not len(results):
1235
  return "", "", ""
1236
+
1237
+ # Extract all entity IDs from your results list
1238
+ node_ids = [r["entity_name"] for r in results]
1239
+
1240
+ # Call the batch node retrieval and degree functions concurrently.
1241
+ nodes_dict, degrees_dict = await asyncio.gather(
1242
+ knowledge_graph_inst.get_nodes_batch(node_ids),
1243
+ knowledge_graph_inst.node_degrees_batch(node_ids)
1244
  )
1245
 
1246
+ # Now, if you need the node data and degree in order:
1247
+ node_datas = [nodes_dict.get(nid) for nid in node_ids]
1248
+ node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids]
1249
+
1250
  if not all([n is not None for n in node_datas]):
1251
  logger.warning("Some nodes are missing, maybe the storage is damaged")
1252
 
 
1378
  all_one_hop_nodes.update([e[1] for e in this_edges])
1379
 
1380
  all_one_hop_nodes = list(all_one_hop_nodes)
1381
+
1382
+ # Batch retrieve one-hop node data using get_nodes_batch
1383
+ all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
1384
+ all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes]
1385
 
1386
  # Add null check for node data
1387
  all_one_hop_text_units_lookup = {
 
1517
  if not len(results):
1518
  return "", "", ""
1519
 
1520
+ # Prepare edge pairs in two forms:
1521
+ # For the batch edge properties function, use dicts.
1522
+ edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
1523
+ # For edge degrees, use tuples.
1524
+ edge_pairs_tuples = [(r["src_id"], r["tgt_id"]) for r in results]
1525
+
1526
+ # Call the batched functions concurrently.
1527
+ edge_data_dict, edge_degrees_dict = await asyncio.gather(
1528
+ knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
1529
+ knowledge_graph_inst.get_edges_degree_batch(edge_pairs_tuples)
1530
+ )
1531
+
1532
+ # Reconstruct edge_datas list in the same order as results.
1533
+ edge_datas = []
1534
+ for k in results:
1535
+ pair = (k["src_id"], k["tgt_id"])
1536
+ edge_props = edge_data_dict.get(pair)
1537
+ if edge_props is not None:
1538
+ # Use edge degree from the batch as rank.
1539
+ combined = {
1540
+ "src_id": k["src_id"],
1541
+ "tgt_id": k["tgt_id"],
1542
+ "rank": edge_degrees_dict.get(pair, k.get("rank", 0)),
1543
+ "created_at": k.get("__created_at__", None),
1544
+ **edge_props,
1545
+ }
1546
+ edge_datas.append(combined)
1547
+
1548
  edge_datas = sorted(
1549
  edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
1550
  )
 
1650
  entity_names.append(e["tgt_id"])
1651
  seen.add(e["tgt_id"])
1652
 
1653
+ # Batch approach: Retrieve nodes and their degrees concurrently with one query each.
1654
+ nodes_dict, degrees_dict = await asyncio.gather(
1655
+ knowledge_graph_inst.get_nodes_batch(entity_names),
1656
+ knowledge_graph_inst.get_node_degrees_batch(entity_names)
 
 
 
 
 
 
 
 
 
1657
  )
1658
+
1659
+ # Rebuild the list in the same order as entity_names
1660
+ node_datas = []
1661
+ for entity_name in entity_names:
1662
+ node = nodes_dict.get(entity_name)
1663
+ degree = degrees_dict.get(entity_name, 0)
1664
+ if node is None:
1665
+ logger.warning(f"Node '{entity_name}' not found in batch retrieval.")
1666
+ continue
1667
+ # Combine the node data with the entity name and computed degree (as rank)
1668
+ combined = {**node, "entity_name": entity_name, "rank": degree}
1669
+ node_datas.append(combined)
1670
 
1671
  len_node_datas = len(node_datas)
1672
  node_datas = truncate_list_by_token_size(