yangdx commited on
Commit
c883d27
·
1 Parent(s): 336a783

Add node limit and prioritization for knowledge graph retrieval

Browse files

• Add MAX_GRAPH_NODES limit from env var
• Prioritize nodes by label match & connection

lightrag/kg/neo4j_impl.py CHANGED
@@ -23,7 +23,7 @@ import pipmaster as pm
23
  if not pm.is_installed("neo4j"):
24
  pm.install("neo4j")
25
 
26
- from neo4j import (
27
  AsyncGraphDatabase,
28
  exceptions as neo4jExceptions,
29
  AsyncDriver,
@@ -34,6 +34,9 @@ from neo4j import (
34
  config = configparser.ConfigParser()
35
  config.read("config.ini", "utf-8")
36
 
 
 
 
37
 
38
  @final
39
  @dataclass
@@ -471,12 +474,17 @@ class Neo4JStorage(BaseGraphStorage):
471
  ) -> KnowledgeGraph:
472
  """
473
  Get complete connected subgraph for specified node (including the starting node itself)
 
 
 
 
 
474
 
475
- Key fixes:
476
- 1. Include the starting node itself
477
- 2. Handle multi-label nodes
478
- 3. Clarify relationship directions
479
- 4. Add depth control
480
  """
481
  label = node_label.strip('"')
482
  result = KnowledgeGraph()
@@ -485,14 +493,22 @@ class Neo4JStorage(BaseGraphStorage):
485
 
486
  async with self._driver.session(database=self._DATABASE) as session:
487
  try:
488
- main_query = ""
489
  if label == "*":
490
  main_query = """
491
  MATCH (n)
492
- WITH collect(DISTINCT n) AS nodes
493
- MATCH ()-[r]-()
494
- RETURN nodes, collect(DISTINCT r) AS relationships;
 
 
 
 
 
495
  """
 
 
 
 
496
  else:
497
  # Critical debug step: first verify if starting node exists
498
  validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
@@ -512,9 +528,25 @@ class Neo4JStorage(BaseGraphStorage):
512
  bfs: true
513
  }})
514
  YIELD nodes, relationships
515
- RETURN nodes, relationships
 
 
 
 
 
 
 
 
 
 
 
 
 
516
  """
517
- result_set = await session.run(main_query)
 
 
 
518
  record = await result_set.single()
519
 
520
  if record:
 
23
  if not pm.is_installed("neo4j"):
24
  pm.install("neo4j")
25
 
26
+ from neo4j import ( # type: ignore
27
  AsyncGraphDatabase,
28
  exceptions as neo4jExceptions,
29
  AsyncDriver,
 
34
  config = configparser.ConfigParser()
35
  config.read("config.ini", "utf-8")
36
 
37
+ # 从环境变量获取最大图节点数,默认为1000
38
+ MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
39
+
40
 
41
  @final
42
  @dataclass
 
474
  ) -> KnowledgeGraph:
475
  """
476
  Get complete connected subgraph for specified node (including the starting node itself)
477
+ Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
478
+ When reducing the number of nodes, the prioritization criteria are as follows:
479
+ 1. Label matching nodes take precedence
480
+ 2. Followed by nodes directly connected to the matching nodes
481
+ 3. Finally, the degree of the nodes
482
 
483
+ Args:
484
+ node_label (str): Label of the starting node
485
+ max_depth (int, optional): Maximum depth of the graph. Defaults to 5.
486
+ Returns:
487
+ KnowledgeGraph: Complete connected subgraph for specified node
488
  """
489
  label = node_label.strip('"')
490
  result = KnowledgeGraph()
 
493
 
494
  async with self._driver.session(database=self._DATABASE) as session:
495
  try:
 
496
  if label == "*":
497
  main_query = """
498
  MATCH (n)
499
+ OPTIONAL MATCH (n)-[r]-()
500
+ WITH n, count(r) AS degree
501
+ ORDER BY degree DESC
502
+ LIMIT $max_nodes
503
+ WITH collect(n) AS nodes
504
+ MATCH (a)-[r]->(b)
505
+ WHERE a IN nodes AND b IN nodes
506
+ RETURN nodes, collect(DISTINCT r) AS relationships
507
  """
508
+ result_set = await session.run(
509
+ main_query, {"max_nodes": MAX_GRAPH_NODES}
510
+ )
511
+
512
  else:
513
  # Critical debug step: first verify if starting node exists
514
  validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
 
528
  bfs: true
529
  }})
530
  YIELD nodes, relationships
531
+ WITH start, nodes, relationships
532
+ UNWIND nodes AS node
533
+ OPTIONAL MATCH (node)-[r]-()
534
+ WITH node, count(r) AS degree, start, nodes, relationships,
535
+ CASE
536
+ WHEN id(node) = id(start) THEN 2
537
+ WHEN EXISTS((start)-->(node)) OR EXISTS((node)-->(start)) THEN 1
538
+ ELSE 0
539
+ END AS priority
540
+ ORDER BY priority DESC, degree DESC
541
+ LIMIT $max_nodes
542
+ WITH collect(node) AS filtered_nodes, nodes, relationships
543
+ RETURN filtered_nodes AS nodes,
544
+ [rel IN relationships WHERE startNode(rel) IN filtered_nodes AND endNode(rel) IN filtered_nodes] AS relationships
545
  """
546
+ result_set = await session.run(
547
+ main_query, {"max_nodes": MAX_GRAPH_NODES}
548
+ )
549
+
550
  record = await result_set.single()
551
 
552
  if record:
lightrag/kg/networkx_impl.py CHANGED
@@ -236,7 +236,11 @@ class NetworkXStorage(BaseGraphStorage):
236
  ) -> KnowledgeGraph:
237
  """
238
  Get complete connected subgraph for specified node (including the starting node itself)
239
- Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
 
 
 
 
240
 
241
  Args:
242
  node_label: Label of the starting node
@@ -268,14 +272,49 @@ class NetworkXStorage(BaseGraphStorage):
268
  logger.warning(f"No nodes found with label {node_label}")
269
  return result
270
 
271
- # Get subgraph using ego_graph
272
- subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
 
 
 
 
273
 
274
  # Check if number of nodes exceeds max_graph_nodes
275
  if len(subgraph.nodes()) > MAX_GRAPH_NODES:
276
  origin_nodes = len(subgraph.nodes())
 
 
277
  node_degrees = dict(subgraph.degree())
278
- top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  :MAX_GRAPH_NODES
280
  ]
281
  top_node_ids = [node[0] for node in top_nodes]
 
236
  ) -> KnowledgeGraph:
237
  """
238
  Get complete connected subgraph for specified node (including the starting node itself)
239
+ Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
240
+ When reducing the number of nodes, the prioritization criteria are as follows:
241
+ 1. Label matching nodes take precedence
242
+ 2. Followed by nodes directly connected to the matching nodes
243
+ 3. Finally, the degree of the nodes
244
 
245
  Args:
246
  node_label: Label of the starting node
 
272
  logger.warning(f"No nodes found with label {node_label}")
273
  return result
274
 
275
+ # Get subgraph using ego_graph from all matching nodes
276
+ combined_subgraph = nx.Graph()
277
+ for start_node in nodes_to_explore:
278
+ node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
279
+ combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
280
+ subgraph = combined_subgraph
281
 
282
  # Check if number of nodes exceeds max_graph_nodes
283
  if len(subgraph.nodes()) > MAX_GRAPH_NODES:
284
  origin_nodes = len(subgraph.nodes())
285
+
286
+ # 获取节点度数
287
  node_degrees = dict(subgraph.degree())
288
+
289
+ # 标记起点节点和直接连接的节点
290
+ start_nodes = set()
291
+ direct_connected_nodes = set()
292
+
293
+ if node_label != "*" and nodes_to_explore:
294
+ # 所有在 nodes_to_explore 中的节点都是起点节点
295
+ start_nodes = set(nodes_to_explore)
296
+
297
+ # 获取与所有起点直接连接的节点
298
+ for start_node in start_nodes:
299
+ direct_connected_nodes.update(subgraph.neighbors(start_node))
300
+
301
+ # 从直接连接节点中移除起点节点(避免重复)
302
+ direct_connected_nodes -= start_nodes
303
+
304
+ # 按优先级和度数排序
305
+ def priority_key(node_item):
306
+ node, degree = node_item
307
+ # 优先级排序:起点(2) > 直接连接(1) > 其他节点(0)
308
+ if node in start_nodes:
309
+ priority = 2
310
+ elif node in direct_connected_nodes:
311
+ priority = 1
312
+ else:
313
+ priority = 0
314
+ return (priority, degree) # 先按优先级,再按度数
315
+
316
+ # 排序并选择前MAX_GRAPH_NODES个节点
317
+ top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
318
  :MAX_GRAPH_NODES
319
  ]
320
  top_node_ids = [node[0] for node in top_nodes]