frederikhendrix
commited on
Commit
·
f43436e
1
Parent(s):
a450594
get_node added and all to base.py and to neo4j_impl.py file
Browse files- lightrag/base.py +20 -0
- lightrag/kg/neo4j_impl.py +156 -0
- lightrag/operate.py +60 -51
lightrag/base.py
CHANGED
@@ -309,6 +309,26 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|
309 |
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
310 |
"""Upsert a node into the graph."""
|
311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
@abstractmethod
|
313 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
314 |
"""Upsert an edge into the graph."""
|
|
|
309 |
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
310 |
"""Upsert a node into the graph."""
|
311 |
|
312 |
+
@abstractmethod
|
313 |
+
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
314 |
+
"""Get nodes as a batch using UNWIND"""
|
315 |
+
|
316 |
+
@abstractmethod
|
317 |
+
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
318 |
+
"""Node degrees as a batch using UNWIND"""
|
319 |
+
|
320 |
+
@abstractmethod
|
321 |
+
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
|
322 |
+
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch"""
|
323 |
+
|
324 |
+
@abstractmethod
|
325 |
+
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
|
326 |
+
"""Get edges as a batch using UNWIND"""
|
327 |
+
|
328 |
+
@abstractmethod
|
329 |
+
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
|
330 |
+
""""Get nodes edges as a batch using UNWIND"""
|
331 |
+
|
332 |
@abstractmethod
|
333 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
334 |
"""Upsert an edge into the graph."""
|
lightrag/kg/neo4j_impl.py
CHANGED
@@ -314,6 +314,37 @@ class Neo4JStorage(BaseGraphStorage):
|
|
314 |
logger.error(f"Error getting node for {node_id}: {str(e)}")
|
315 |
raise
|
316 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
317 |
async def node_degree(self, node_id: str) -> int:
|
318 |
"""Get the degree (number of relationships) of a node with the given label.
|
319 |
If multiple nodes have the same label, returns the degree of the first node.
|
@@ -357,6 +388,41 @@ class Neo4JStorage(BaseGraphStorage):
|
|
357 |
logger.error(f"Error getting node degree for {node_id}: {str(e)}")
|
358 |
raise
|
359 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
360 |
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
361 |
"""Get the total degree (sum of relationships) of two nodes.
|
362 |
|
@@ -376,6 +442,30 @@ class Neo4JStorage(BaseGraphStorage):
|
|
376 |
|
377 |
degrees = int(src_degree) + int(trg_degree)
|
378 |
return degrees
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
379 |
|
380 |
async def get_edge(
|
381 |
self, source_node_id: str, target_node_id: str
|
@@ -463,6 +553,43 @@ class Neo4JStorage(BaseGraphStorage):
|
|
463 |
)
|
464 |
raise
|
465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
467 |
"""Retrieves all edges (relationships) for a particular node identified by its label.
|
468 |
|
@@ -523,6 +650,35 @@ class Neo4JStorage(BaseGraphStorage):
|
|
523 |
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
|
524 |
raise
|
525 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
526 |
@retry(
|
527 |
stop=stop_after_attempt(3),
|
528 |
wait=wait_exponential(multiplier=1, min=4, max=10),
|
|
|
314 |
logger.error(f"Error getting node for {node_id}: {str(e)}")
|
315 |
raise
|
316 |
|
317 |
+
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
318 |
+
"""
|
319 |
+
Retrieve multiple nodes in one query using UNWIND.
|
320 |
+
|
321 |
+
Args:
|
322 |
+
node_ids: List of node entity IDs to fetch.
|
323 |
+
|
324 |
+
Returns:
|
325 |
+
A dictionary mapping each node_id to its node data (or None if not found).
|
326 |
+
"""
|
327 |
+
async with self._driver.session(
|
328 |
+
database=self._DATABASE, default_access_mode="READ"
|
329 |
+
) as session:
|
330 |
+
query = """
|
331 |
+
UNWIND $node_ids AS id
|
332 |
+
MATCH (n:base {entity_id: id})
|
333 |
+
RETURN n.entity_id AS entity_id, n
|
334 |
+
"""
|
335 |
+
result = await session.run(query, node_ids=node_ids)
|
336 |
+
nodes = {}
|
337 |
+
async for record in result:
|
338 |
+
entity_id = record["entity_id"]
|
339 |
+
node = record["n"]
|
340 |
+
node_dict = dict(node)
|
341 |
+
# Remove the 'base' label if present in a 'labels' property
|
342 |
+
if "labels" in node_dict:
|
343 |
+
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
|
344 |
+
nodes[entity_id] = node_dict
|
345 |
+
await result.consume() # Make sure to consume the result fully
|
346 |
+
return nodes
|
347 |
+
|
348 |
async def node_degree(self, node_id: str) -> int:
|
349 |
"""Get the degree (number of relationships) of a node with the given label.
|
350 |
If multiple nodes have the same label, returns the degree of the first node.
|
|
|
388 |
logger.error(f"Error getting node degree for {node_id}: {str(e)}")
|
389 |
raise
|
390 |
|
391 |
+
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
392 |
+
"""
|
393 |
+
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
394 |
+
|
395 |
+
Args:
|
396 |
+
node_ids: List of node labels (entity_id values) to look up.
|
397 |
+
|
398 |
+
Returns:
|
399 |
+
A dictionary mapping each node_id to its degree (number of relationships).
|
400 |
+
If a node is not found, its degree will be set to 0.
|
401 |
+
"""
|
402 |
+
async with self._driver.session(
|
403 |
+
database=self._DATABASE, default_access_mode="READ"
|
404 |
+
) as session:
|
405 |
+
query = """
|
406 |
+
UNWIND $node_ids AS id
|
407 |
+
MATCH (n:base {entity_id: id})
|
408 |
+
RETURN n.entity_id AS entity_id, count { (n)--() } AS degree;
|
409 |
+
"""
|
410 |
+
result = await session.run(query, node_ids=node_ids)
|
411 |
+
degrees = {}
|
412 |
+
async for record in result:
|
413 |
+
entity_id = record["entity_id"]
|
414 |
+
degrees[entity_id] = record["degree"]
|
415 |
+
await result.consume() # Ensure result is fully consumed
|
416 |
+
|
417 |
+
# For any node_id that did not return a record, set degree to 0.
|
418 |
+
for nid in node_ids:
|
419 |
+
if nid not in degrees:
|
420 |
+
logger.warning(f"No node found with label '{nid}'")
|
421 |
+
degrees[nid] = 0
|
422 |
+
|
423 |
+
logger.debug(f"Neo4j batch node degree query returned: {degrees}")
|
424 |
+
return degrees
|
425 |
+
|
426 |
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
427 |
"""Get the total degree (sum of relationships) of two nodes.
|
428 |
|
|
|
442 |
|
443 |
degrees = int(src_degree) + int(trg_degree)
|
444 |
return degrees
|
445 |
+
|
446 |
+
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
|
447 |
+
"""
|
448 |
+
Calculate the combined degree for each edge (sum of the source and target node degrees)
|
449 |
+
in batch using the already implemented node_degrees_batch.
|
450 |
+
|
451 |
+
Args:
|
452 |
+
edge_pairs: List of (src, tgt) tuples.
|
453 |
+
|
454 |
+
Returns:
|
455 |
+
A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
|
456 |
+
"""
|
457 |
+
# Collect unique node IDs from all edge pairs.
|
458 |
+
unique_node_ids = {src for src, _ in edge_pairs}
|
459 |
+
unique_node_ids.update({tgt for _, tgt in edge_pairs})
|
460 |
+
|
461 |
+
# Get degrees for all nodes in one go.
|
462 |
+
degrees = await self.node_degrees_batch(list(unique_node_ids))
|
463 |
+
|
464 |
+
# Sum up degrees for each edge pair.
|
465 |
+
edge_degrees = {}
|
466 |
+
for src, tgt in edge_pairs:
|
467 |
+
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
|
468 |
+
return edge_degrees
|
469 |
|
470 |
async def get_edge(
|
471 |
self, source_node_id: str, target_node_id: str
|
|
|
553 |
)
|
554 |
raise
|
555 |
|
556 |
+
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
|
557 |
+
"""
|
558 |
+
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
559 |
+
|
560 |
+
Args:
|
561 |
+
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
562 |
+
|
563 |
+
Returns:
|
564 |
+
A dictionary mapping (src, tgt) tuples to their edge properties.
|
565 |
+
"""
|
566 |
+
async with self._driver.session(
|
567 |
+
database=self._DATABASE, default_access_mode="READ"
|
568 |
+
) as session:
|
569 |
+
query = """
|
570 |
+
UNWIND $pairs AS pair
|
571 |
+
MATCH (start:base {entity_id: pair.src})-[r:DIRECTED]-(end:base {entity_id: pair.tgt})
|
572 |
+
RETURN pair.src AS src_id, pair.tgt AS tgt_id, collect(properties(r)) AS edges
|
573 |
+
"""
|
574 |
+
result = await session.run(query, pairs=pairs)
|
575 |
+
edges_dict = {}
|
576 |
+
async for record in result:
|
577 |
+
src = record["src_id"]
|
578 |
+
tgt = record["tgt_id"]
|
579 |
+
edges = record["edges"]
|
580 |
+
if edges and len(edges) > 0:
|
581 |
+
edge_props = edges[0] # choose the first if multiple exist
|
582 |
+
# Ensure required keys exist with defaults
|
583 |
+
for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items():
|
584 |
+
if key not in edge_props:
|
585 |
+
edge_props[key] = default
|
586 |
+
edges_dict[(src, tgt)] = edge_props
|
587 |
+
else:
|
588 |
+
# No edge found – set default edge properties
|
589 |
+
edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None}
|
590 |
+
await result.consume()
|
591 |
+
return edges_dict
|
592 |
+
|
593 |
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
594 |
"""Retrieves all edges (relationships) for a particular node identified by its label.
|
595 |
|
|
|
650 |
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
|
651 |
raise
|
652 |
|
653 |
+
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
|
654 |
+
"""
|
655 |
+
Batch retrieve edges for multiple nodes in one query using UNWIND.
|
656 |
+
|
657 |
+
Args:
|
658 |
+
node_ids: List of node IDs (entity_id) for which to retrieve edges.
|
659 |
+
|
660 |
+
Returns:
|
661 |
+
A dictionary mapping each node ID to its list of edge tuples (source, target).
|
662 |
+
"""
|
663 |
+
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
664 |
+
query = """
|
665 |
+
UNWIND $node_ids AS id
|
666 |
+
MATCH (n:base {entity_id: id})
|
667 |
+
OPTIONAL MATCH (n)-[r]-(connected:base)
|
668 |
+
RETURN id AS queried_id, n.entity_id AS source_entity_id, connected.entity_id AS target_entity_id
|
669 |
+
"""
|
670 |
+
result = await session.run(query, node_ids=node_ids)
|
671 |
+
# Initialize the dictionary with empty lists for each node ID
|
672 |
+
edges_dict = {node_id: [] for node_id in node_ids}
|
673 |
+
async for record in result:
|
674 |
+
queried_id = record["queried_id"]
|
675 |
+
source_label = record["source_entity_id"]
|
676 |
+
target_label = record["target_entity_id"]
|
677 |
+
if source_label and target_label:
|
678 |
+
edges_dict[queried_id].append((source_label, target_label))
|
679 |
+
await result.consume() # Ensure results are fully consumed
|
680 |
+
return edges_dict
|
681 |
+
|
682 |
@retry(
|
683 |
stop=stop_after_attempt(3),
|
684 |
wait=wait_exponential(multiplier=1, min=4, max=10),
|
lightrag/operate.py
CHANGED
@@ -1233,16 +1233,20 @@ async def _get_node_data(
|
|
1233 |
|
1234 |
if not len(results):
|
1235 |
return "", "", ""
|
1236 |
-
|
1237 |
-
|
1238 |
-
|
1239 |
-
|
1240 |
-
|
1241 |
-
|
1242 |
-
|
1243 |
-
)
|
1244 |
)
|
1245 |
|
|
|
|
|
|
|
|
|
1246 |
if not all([n is not None for n in node_datas]):
|
1247 |
logger.warning("Some nodes are missing, maybe the storage is damaged")
|
1248 |
|
@@ -1374,9 +1378,10 @@ async def _find_most_related_text_unit_from_entities(
|
|
1374 |
all_one_hop_nodes.update([e[1] for e in this_edges])
|
1375 |
|
1376 |
all_one_hop_nodes = list(all_one_hop_nodes)
|
1377 |
-
|
1378 |
-
|
1379 |
-
)
|
|
|
1380 |
|
1381 |
# Add null check for node data
|
1382 |
all_one_hop_text_units_lookup = {
|
@@ -1512,29 +1517,34 @@ async def _get_edge_data(
|
|
1512 |
if not len(results):
|
1513 |
return "", "", ""
|
1514 |
|
1515 |
-
|
1516 |
-
|
1517 |
-
|
1518 |
-
|
1519 |
-
|
1520 |
-
|
1521 |
-
|
1522 |
-
|
1523 |
-
|
1524 |
-
)
|
1525 |
-
)
|
1526 |
-
|
1527 |
-
edge_datas
|
1528 |
-
|
1529 |
-
|
1530 |
-
|
1531 |
-
|
1532 |
-
|
1533 |
-
|
1534 |
-
|
1535 |
-
|
1536 |
-
|
1537 |
-
|
|
|
|
|
|
|
|
|
|
|
1538 |
edge_datas = sorted(
|
1539 |
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
1540 |
)
|
@@ -1640,24 +1650,23 @@ async def _find_most_related_entities_from_relationships(
|
|
1640 |
entity_names.append(e["tgt_id"])
|
1641 |
seen.add(e["tgt_id"])
|
1642 |
|
1643 |
-
|
1644 |
-
|
1645 |
-
|
1646 |
-
|
1647 |
-
for entity_name in entity_names
|
1648 |
-
]
|
1649 |
-
),
|
1650 |
-
asyncio.gather(
|
1651 |
-
*[
|
1652 |
-
knowledge_graph_inst.node_degree(entity_name)
|
1653 |
-
for entity_name in entity_names
|
1654 |
-
]
|
1655 |
-
),
|
1656 |
)
|
1657 |
-
|
1658 |
-
|
1659 |
-
|
1660 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1661 |
|
1662 |
len_node_datas = len(node_datas)
|
1663 |
node_datas = truncate_list_by_token_size(
|
|
|
1233 |
|
1234 |
if not len(results):
|
1235 |
return "", "", ""
|
1236 |
+
|
1237 |
+
# Extract all entity IDs from your results list
|
1238 |
+
node_ids = [r["entity_name"] for r in results]
|
1239 |
+
|
1240 |
+
# Call the batch node retrieval and degree functions concurrently.
|
1241 |
+
nodes_dict, degrees_dict = await asyncio.gather(
|
1242 |
+
knowledge_graph_inst.get_nodes_batch(node_ids),
|
1243 |
+
knowledge_graph_inst.node_degrees_batch(node_ids)
|
1244 |
)
|
1245 |
|
1246 |
+
# Now, if you need the node data and degree in order:
|
1247 |
+
node_datas = [nodes_dict.get(nid) for nid in node_ids]
|
1248 |
+
node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids]
|
1249 |
+
|
1250 |
if not all([n is not None for n in node_datas]):
|
1251 |
logger.warning("Some nodes are missing, maybe the storage is damaged")
|
1252 |
|
|
|
1378 |
all_one_hop_nodes.update([e[1] for e in this_edges])
|
1379 |
|
1380 |
all_one_hop_nodes = list(all_one_hop_nodes)
|
1381 |
+
|
1382 |
+
# Batch retrieve one-hop node data using get_nodes_batch
|
1383 |
+
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
|
1384 |
+
all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes]
|
1385 |
|
1386 |
# Add null check for node data
|
1387 |
all_one_hop_text_units_lookup = {
|
|
|
1517 |
if not len(results):
|
1518 |
return "", "", ""
|
1519 |
|
1520 |
+
# Prepare edge pairs in two forms:
|
1521 |
+
# For the batch edge properties function, use dicts.
|
1522 |
+
edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
|
1523 |
+
# For edge degrees, use tuples.
|
1524 |
+
edge_pairs_tuples = [(r["src_id"], r["tgt_id"]) for r in results]
|
1525 |
+
|
1526 |
+
# Call the batched functions concurrently.
|
1527 |
+
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
1528 |
+
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
1529 |
+
knowledge_graph_inst.get_edges_degree_batch(edge_pairs_tuples)
|
1530 |
+
)
|
1531 |
+
|
1532 |
+
# Reconstruct edge_datas list in the same order as results.
|
1533 |
+
edge_datas = []
|
1534 |
+
for k in results:
|
1535 |
+
pair = (k["src_id"], k["tgt_id"])
|
1536 |
+
edge_props = edge_data_dict.get(pair)
|
1537 |
+
if edge_props is not None:
|
1538 |
+
# Use edge degree from the batch as rank.
|
1539 |
+
combined = {
|
1540 |
+
"src_id": k["src_id"],
|
1541 |
+
"tgt_id": k["tgt_id"],
|
1542 |
+
"rank": edge_degrees_dict.get(pair, k.get("rank", 0)),
|
1543 |
+
"created_at": k.get("__created_at__", None),
|
1544 |
+
**edge_props,
|
1545 |
+
}
|
1546 |
+
edge_datas.append(combined)
|
1547 |
+
|
1548 |
edge_datas = sorted(
|
1549 |
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
1550 |
)
|
|
|
1650 |
entity_names.append(e["tgt_id"])
|
1651 |
seen.add(e["tgt_id"])
|
1652 |
|
1653 |
+
# Batch approach: Retrieve nodes and their degrees concurrently with one query each.
|
1654 |
+
nodes_dict, degrees_dict = await asyncio.gather(
|
1655 |
+
knowledge_graph_inst.get_nodes_batch(entity_names),
|
1656 |
+
knowledge_graph_inst.get_node_degrees_batch(entity_names)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1657 |
)
|
1658 |
+
|
1659 |
+
# Rebuild the list in the same order as entity_names
|
1660 |
+
node_datas = []
|
1661 |
+
for entity_name in entity_names:
|
1662 |
+
node = nodes_dict.get(entity_name)
|
1663 |
+
degree = degrees_dict.get(entity_name, 0)
|
1664 |
+
if node is None:
|
1665 |
+
logger.warning(f"Node '{entity_name}' not found in batch retrieval.")
|
1666 |
+
continue
|
1667 |
+
# Combine the node data with the entity name and computed degree (as rank)
|
1668 |
+
combined = {**node, "entity_name": entity_name, "rank": degree}
|
1669 |
+
node_datas.append(combined)
|
1670 |
|
1671 |
len_node_datas = len(node_datas)
|
1672 |
node_datas = truncate_list_by_token_size(
|