yangdx committed on
Commit
effc66c
·
1 Parent(s): cd3a8eb

Fix subgraph filtering bugs

Browse files
Files changed (1) hide show
  1. lightrag/kg/networkx_impl.py +16 -18
lightrag/kg/networkx_impl.py CHANGED
@@ -263,12 +263,14 @@ class NetworkXStorage(BaseGraphStorage):
263
 
264
  graph = await self._get_graph()
265
 
 
 
 
 
266
  # Handle special case for "*" label
267
  if node_label == "*":
268
  # For "*", return the entire graph including all nodes and edges
269
- subgraph = (
270
- graph.copy()
271
- ) # Create a copy to avoid modifying the original graph
272
  else:
273
  # Find nodes with matching node id based on search_mode
274
  nodes_to_explore = []
@@ -292,10 +294,7 @@ class NetworkXStorage(BaseGraphStorage):
292
  combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
293
 
294
  # Get start nodes and direct connected nodes
295
- start_nodes = set()
296
- direct_connected_nodes = set()
297
-
298
- if node_label != "*" and nodes_to_explore:
299
  start_nodes = set(nodes_to_explore)
300
  # Get nodes directly connected to all start nodes
301
  for start_node in start_nodes:
@@ -306,19 +305,18 @@ class NetworkXStorage(BaseGraphStorage):
306
  # Remove start nodes from directly connected nodes (avoid duplicates)
307
  direct_connected_nodes -= start_nodes
308
 
309
- # Filter nodes based on min_degree, but keep start nodes and direct connected nodes
310
- if min_degree > 0:
311
- nodes_to_keep = [
312
- node
313
- for node, degree in combined_subgraph.degree()
314
- if node in start_nodes
315
- or node in direct_connected_nodes
316
- or degree >= min_degree
317
- ]
318
- combined_subgraph = combined_subgraph.subgraph(nodes_to_keep)
319
-
320
  subgraph = combined_subgraph
321
 
 
 
 
 
 
 
 
 
 
 
322
  # Check if number of nodes exceeds max_graph_nodes
323
  if len(subgraph.nodes()) > MAX_GRAPH_NODES:
324
  origin_nodes = len(subgraph.nodes())
 
263
 
264
  graph = await self._get_graph()
265
 
266
+ # Initialize sets for start nodes and direct connected nodes
267
+ start_nodes = set()
268
+ direct_connected_nodes = set()
269
+
270
  # Handle special case for "*" label
271
  if node_label == "*":
272
  # For "*", return the entire graph including all nodes and edges
273
+ subgraph = graph.copy() # Create a copy to avoid modifying the original graph
 
 
274
  else:
275
  # Find nodes with matching node id based on search_mode
276
  nodes_to_explore = []
 
294
  combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
295
 
296
  # Get start nodes and direct connected nodes
297
+ if nodes_to_explore:
 
 
 
298
  start_nodes = set(nodes_to_explore)
299
  # Get nodes directly connected to all start nodes
300
  for start_node in start_nodes:
 
305
  # Remove start nodes from directly connected nodes (avoid duplicates)
306
  direct_connected_nodes -= start_nodes
307
 
 
 
 
 
 
 
 
 
 
 
 
308
  subgraph = combined_subgraph
309
 
310
+ # Filter nodes based on min_degree, but keep start nodes and direct connected nodes
311
+ if min_degree > 0:
312
+ nodes_to_keep = [
313
+ node
314
+ for node, degree in subgraph.degree()
315
+ if (node_label != "*" and (node in start_nodes or node in direct_connected_nodes))
316
+ or degree >= min_degree
317
+ ]
318
+ subgraph = subgraph.subgraph(nodes_to_keep)
319
+
320
  # Check if number of nodes exceeds max_graph_nodes
321
  if len(subgraph.nodes()) > MAX_GRAPH_NODES:
322
  origin_nodes = len(subgraph.nodes())