yangdx
committed on
Commit
·
c883d27
1
Parent(s):
336a783
Add node limit and prioritization for knowledge graph retrieval
Browse files
• Add MAX_GRAPH_NODES limit from env var
• Prioritize nodes by label match & connection
- lightrag/kg/neo4j_impl.py +44 -12
- lightrag/kg/networkx_impl.py +43 -4
lightrag/kg/neo4j_impl.py
CHANGED
@@ -23,7 +23,7 @@ import pipmaster as pm
|
|
23 |
if not pm.is_installed("neo4j"):
|
24 |
pm.install("neo4j")
|
25 |
|
26 |
-
from neo4j import (
|
27 |
AsyncGraphDatabase,
|
28 |
exceptions as neo4jExceptions,
|
29 |
AsyncDriver,
|
@@ -34,6 +34,9 @@ from neo4j import (
|
|
34 |
config = configparser.ConfigParser()
|
35 |
config.read("config.ini", "utf-8")
|
36 |
|
|
|
|
|
|
|
37 |
|
38 |
@final
|
39 |
@dataclass
|
@@ -471,12 +474,17 @@ class Neo4JStorage(BaseGraphStorage):
|
|
471 |
) -> KnowledgeGraph:
|
472 |
"""
|
473 |
Get complete connected subgraph for specified node (including the starting node itself)
|
|
|
|
|
|
|
|
|
|
|
474 |
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
"""
|
481 |
label = node_label.strip('"')
|
482 |
result = KnowledgeGraph()
|
@@ -485,14 +493,22 @@ class Neo4JStorage(BaseGraphStorage):
|
|
485 |
|
486 |
async with self._driver.session(database=self._DATABASE) as session:
|
487 |
try:
|
488 |
-
main_query = ""
|
489 |
if label == "*":
|
490 |
main_query = """
|
491 |
MATCH (n)
|
492 |
-
|
493 |
-
|
494 |
-
|
|
|
|
|
|
|
|
|
|
|
495 |
"""
|
|
|
|
|
|
|
|
|
496 |
else:
|
497 |
# Critical debug step: first verify if starting node exists
|
498 |
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
|
@@ -512,9 +528,25 @@ class Neo4JStorage(BaseGraphStorage):
|
|
512 |
bfs: true
|
513 |
}})
|
514 |
YIELD nodes, relationships
|
515 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
516 |
"""
|
517 |
-
|
|
|
|
|
|
|
518 |
record = await result_set.single()
|
519 |
|
520 |
if record:
|
|
|
23 |
if not pm.is_installed("neo4j"):
|
24 |
pm.install("neo4j")
|
25 |
|
26 |
+
from neo4j import ( # type: ignore
|
27 |
AsyncGraphDatabase,
|
28 |
exceptions as neo4jExceptions,
|
29 |
AsyncDriver,
|
|
|
34 |
config = configparser.ConfigParser()
|
35 |
config.read("config.ini", "utf-8")
|
36 |
|
37 |
+
# Maximum number of graph nodes, read from the environment variable (default: 1000)
|
38 |
+
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
|
39 |
+
|
40 |
|
41 |
@final
|
42 |
@dataclass
|
|
|
474 |
) -> KnowledgeGraph:
|
475 |
"""
|
476 |
Get complete connected subgraph for specified node (including the starting node itself)
|
477 |
+
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
478 |
+
When reducing the number of nodes, the prioritization criteria are as follows:
|
479 |
+
1. Label matching nodes take precedence
|
480 |
+
2. Followed by nodes directly connected to the matching nodes
|
481 |
+
3. Finally, the degree of the nodes
|
482 |
|
483 |
+
Args:
|
484 |
+
node_label (str): Label of the starting node
|
485 |
+
max_depth (int, optional): Maximum depth of the graph. Defaults to 5.
|
486 |
+
Returns:
|
487 |
+
KnowledgeGraph: Complete connected subgraph for specified node
|
488 |
"""
|
489 |
label = node_label.strip('"')
|
490 |
result = KnowledgeGraph()
|
|
|
493 |
|
494 |
async with self._driver.session(database=self._DATABASE) as session:
|
495 |
try:
|
|
|
496 |
if label == "*":
|
497 |
main_query = """
|
498 |
MATCH (n)
|
499 |
+
OPTIONAL MATCH (n)-[r]-()
|
500 |
+
WITH n, count(r) AS degree
|
501 |
+
ORDER BY degree DESC
|
502 |
+
LIMIT $max_nodes
|
503 |
+
WITH collect(n) AS nodes
|
504 |
+
MATCH (a)-[r]->(b)
|
505 |
+
WHERE a IN nodes AND b IN nodes
|
506 |
+
RETURN nodes, collect(DISTINCT r) AS relationships
|
507 |
"""
|
508 |
+
result_set = await session.run(
|
509 |
+
main_query, {"max_nodes": MAX_GRAPH_NODES}
|
510 |
+
)
|
511 |
+
|
512 |
else:
|
513 |
# Critical debug step: first verify if starting node exists
|
514 |
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
|
|
|
528 |
bfs: true
|
529 |
}})
|
530 |
YIELD nodes, relationships
|
531 |
+
WITH start, nodes, relationships
|
532 |
+
UNWIND nodes AS node
|
533 |
+
OPTIONAL MATCH (node)-[r]-()
|
534 |
+
WITH node, count(r) AS degree, start, nodes, relationships,
|
535 |
+
CASE
|
536 |
+
WHEN id(node) = id(start) THEN 2
|
537 |
+
WHEN EXISTS((start)-->(node)) OR EXISTS((node)-->(start)) THEN 1
|
538 |
+
ELSE 0
|
539 |
+
END AS priority
|
540 |
+
ORDER BY priority DESC, degree DESC
|
541 |
+
LIMIT $max_nodes
|
542 |
+
WITH collect(node) AS filtered_nodes, nodes, relationships
|
543 |
+
RETURN filtered_nodes AS nodes,
|
544 |
+
[rel IN relationships WHERE startNode(rel) IN filtered_nodes AND endNode(rel) IN filtered_nodes] AS relationships
|
545 |
"""
|
546 |
+
result_set = await session.run(
|
547 |
+
main_query, {"max_nodes": MAX_GRAPH_NODES}
|
548 |
+
)
|
549 |
+
|
550 |
record = await result_set.single()
|
551 |
|
552 |
if record:
|
lightrag/kg/networkx_impl.py
CHANGED
@@ -236,7 +236,11 @@ class NetworkXStorage(BaseGraphStorage):
|
|
236 |
) -> KnowledgeGraph:
|
237 |
"""
|
238 |
Get complete connected subgraph for specified node (including the starting node itself)
|
239 |
-
Maximum number of nodes is
|
|
|
|
|
|
|
|
|
240 |
|
241 |
Args:
|
242 |
node_label: Label of the starting node
|
@@ -268,14 +272,49 @@ class NetworkXStorage(BaseGraphStorage):
|
|
268 |
logger.warning(f"No nodes found with label {node_label}")
|
269 |
return result
|
270 |
|
271 |
-
# Get subgraph using ego_graph
|
272 |
-
|
|
|
|
|
|
|
|
|
273 |
|
274 |
# Check if number of nodes exceeds max_graph_nodes
|
275 |
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
|
276 |
origin_nodes = len(subgraph.nodes())
|
|
|
|
|
277 |
node_degrees = dict(subgraph.degree())
|
278 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
:MAX_GRAPH_NODES
|
280 |
]
|
281 |
top_node_ids = [node[0] for node in top_nodes]
|
|
|
236 |
) -> KnowledgeGraph:
|
237 |
"""
|
238 |
Get complete connected subgraph for specified node (including the starting node itself)
|
239 |
+
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
240 |
+
When reducing the number of nodes, the prioritization criteria are as follows:
|
241 |
+
1. Label matching nodes take precedence
|
242 |
+
2. Followed by nodes directly connected to the matching nodes
|
243 |
+
3. Finally, the degree of the nodes
|
244 |
|
245 |
Args:
|
246 |
node_label: Label of the starting node
|
|
|
272 |
logger.warning(f"No nodes found with label {node_label}")
|
273 |
return result
|
274 |
|
275 |
+
# Get subgraph using ego_graph from all matching nodes
|
276 |
+
combined_subgraph = nx.Graph()
|
277 |
+
for start_node in nodes_to_explore:
|
278 |
+
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
|
279 |
+
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
|
280 |
+
subgraph = combined_subgraph
|
281 |
|
282 |
# Check if number of nodes exceeds max_graph_nodes
|
283 |
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
|
284 |
origin_nodes = len(subgraph.nodes())
|
285 |
+
|
286 |
+
# Get the degree of each node
|
287 |
node_degrees = dict(subgraph.degree())
|
288 |
+
|
289 |
+
# Mark the start nodes and the nodes directly connected to them
|
290 |
+
start_nodes = set()
|
291 |
+
direct_connected_nodes = set()
|
292 |
+
|
293 |
+
if node_label != "*" and nodes_to_explore:
|
294 |
+
# Every node in nodes_to_explore is a start node
|
295 |
+
start_nodes = set(nodes_to_explore)
|
296 |
+
|
297 |
+
# Collect the nodes directly connected to any start node
|
298 |
+
for start_node in start_nodes:
|
299 |
+
direct_connected_nodes.update(subgraph.neighbors(start_node))
|
300 |
+
|
301 |
+
# Remove the start nodes from the directly connected set (avoid duplicates)
|
302 |
+
direct_connected_nodes -= start_nodes
|
303 |
+
|
304 |
+
# Sort by priority, then by degree
|
305 |
+
def priority_key(node_item):
|
306 |
+
node, degree = node_item
|
307 |
+
# Priority order: start node (2) > directly connected (1) > other nodes (0)
|
308 |
+
if node in start_nodes:
|
309 |
+
priority = 2
|
310 |
+
elif node in direct_connected_nodes:
|
311 |
+
priority = 1
|
312 |
+
else:
|
313 |
+
priority = 0
|
314 |
+
return (priority, degree) # Priority first, then degree
|
315 |
+
|
316 |
+
# Sort and keep the top MAX_GRAPH_NODES nodes
|
317 |
+
top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
|
318 |
:MAX_GRAPH_NODES
|
319 |
]
|
320 |
top_node_ids = [node[0] for node in top_nodes]
|