yangdx committed on
Commit
016d00c
·
1 Parent(s): 61ad9a7

Refactor Neo4J graph query with min_degree and inclusive match support

Browse files
Files changed (1) hide show
  1. lightrag/kg/neo4j_impl.py +278 -162
lightrag/kg/neo4j_impl.py CHANGED
@@ -41,6 +41,7 @@ MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
41
  # Set neo4j logger level to ERROR to suppress warning logs
42
  logging.getLogger("neo4j").setLevel(logging.ERROR)
43
 
 
44
  @final
45
  @dataclass
46
  class Neo4JStorage(BaseGraphStorage):
@@ -63,19 +64,25 @@ class Neo4JStorage(BaseGraphStorage):
63
  MAX_CONNECTION_POOL_SIZE = int(
64
  os.environ.get(
65
  "NEO4J_MAX_CONNECTION_POOL_SIZE",
66
- config.get("neo4j", "connection_pool_size", fallback=800),
67
  )
68
  )
69
  CONNECTION_TIMEOUT = float(
70
  os.environ.get(
71
  "NEO4J_CONNECTION_TIMEOUT",
72
- config.get("neo4j", "connection_timeout", fallback=60.0),
73
  ),
74
  )
75
  CONNECTION_ACQUISITION_TIMEOUT = float(
76
  os.environ.get(
77
  "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
78
- config.get("neo4j", "connection_acquisition_timeout", fallback=60.0),
 
 
 
 
 
 
79
  ),
80
  )
81
  DATABASE = os.environ.get(
@@ -88,6 +95,7 @@ class Neo4JStorage(BaseGraphStorage):
88
  max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
89
  connection_timeout=CONNECTION_TIMEOUT,
90
  connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
 
91
  )
92
 
93
  # Try to connect to the database
@@ -169,21 +177,24 @@ class Neo4JStorage(BaseGraphStorage):
169
 
170
  async def _ensure_label(self, label: str) -> str:
171
  """Ensure a label is valid
172
-
173
  Args:
174
  label: The label to validate
175
  """
176
  clean_label = label.strip('"')
 
 
177
  return clean_label
178
 
179
  async def has_node(self, node_id: str) -> bool:
180
  entity_name_label = await self._ensure_label(node_id)
181
- async with self._driver.session(database=self._DATABASE) as session:
182
  query = (
183
  f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
184
  )
185
  result = await session.run(query)
186
  single_result = await result.single()
 
187
  logger.debug(
188
  f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
189
  )
@@ -193,13 +204,14 @@ class Neo4JStorage(BaseGraphStorage):
193
  entity_name_label_source = source_node_id.strip('"')
194
  entity_name_label_target = target_node_id.strip('"')
195
 
196
- async with self._driver.session(database=self._DATABASE) as session:
197
  query = (
198
  f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
199
  "RETURN COUNT(r) > 0 AS edgeExists"
200
  )
201
  result = await session.run(query)
202
  single_result = await result.single()
 
203
  logger.debug(
204
  f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}"
205
  )
@@ -215,13 +227,16 @@ class Neo4JStorage(BaseGraphStorage):
215
  dict: Node properties if found
216
  None: If node not found
217
  """
218
- async with self._driver.session(database=self._DATABASE) as session:
219
  entity_name_label = await self._ensure_label(node_id)
220
  query = f"MATCH (n:`{entity_name_label}`) RETURN n"
221
  result = await session.run(query)
222
- record = await result.single()
223
- if record:
224
- node = record["n"]
 
 
 
225
  node_dict = dict(node)
226
  logger.debug(
227
  f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
@@ -230,23 +245,40 @@ class Neo4JStorage(BaseGraphStorage):
230
  return None
231
 
232
  async def node_degree(self, node_id: str) -> int:
 
 
 
 
 
 
 
 
 
 
233
  entity_name_label = node_id.strip('"')
234
 
235
- async with self._driver.session(database=self._DATABASE) as session:
236
  query = f"""
237
  MATCH (n:`{entity_name_label}`)
238
- RETURN COUNT{{ (n)--() }} AS totalEdgeCount
 
239
  """
240
  result = await session.run(query)
241
- record = await result.single()
242
- if record:
243
- edge_count = record["totalEdgeCount"]
244
- logger.debug(
245
- f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
246
- )
247
- return edge_count
248
- else:
249
- return None
 
 
 
 
 
 
250
 
251
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
252
  entity_name_label_source = src_id.strip('"')
@@ -264,6 +296,31 @@ class Neo4JStorage(BaseGraphStorage):
264
  )
265
  return degrees
266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  async def get_edge(
268
  self, source_node_id: str, target_node_id: str
269
  ) -> dict[str, str] | None:
@@ -271,18 +328,21 @@ class Neo4JStorage(BaseGraphStorage):
271
  entity_name_label_source = source_node_id.strip('"')
272
  entity_name_label_target = target_node_id.strip('"')
273
 
274
- async with self._driver.session(database=self._DATABASE) as session:
275
  query = f"""
276
- MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
277
  RETURN properties(r) as edge_properties
278
- LIMIT 1
279
  """
280
 
281
  result = await session.run(query)
282
- record = await result.single()
283
- if record:
 
 
 
 
284
  try:
285
- result = dict(record["edge_properties"])
286
  logger.debug(f"Result: {result}")
287
  # Ensure required keys exist with defaults
288
  required_keys = {
@@ -349,24 +409,27 @@ class Neo4JStorage(BaseGraphStorage):
349
  query = f"""MATCH (n:`{node_label}`)
350
  OPTIONAL MATCH (n)-[r]-(connected)
351
  RETURN n, r, connected"""
352
- async with self._driver.session(database=self._DATABASE) as session:
353
  results = await session.run(query)
354
  edges = []
355
- async for record in results:
356
- source_node = record["n"]
357
- connected_node = record["connected"]
 
358
 
359
- source_label = (
360
- list(source_node.labels)[0] if source_node.labels else None
361
- )
362
- target_label = (
363
- list(connected_node.labels)[0]
364
- if connected_node and connected_node.labels
365
- else None
366
- )
367
 
368
- if source_label and target_label:
369
- edges.append((source_label, target_label))
 
 
370
 
371
  return edges
372
 
@@ -427,30 +490,46 @@ class Neo4JStorage(BaseGraphStorage):
427
  ) -> None:
428
  """
429
  Upsert an edge and its properties between two nodes identified by their labels.
 
430
 
431
  Args:
432
  source_node_id (str): Label of the source node (used as identifier)
433
  target_node_id (str): Label of the target node (used as identifier)
434
  edge_data (dict): Dictionary of properties to set on the edge
 
 
 
435
  """
436
  source_label = await self._ensure_label(source_node_id)
437
  target_label = await self._ensure_label(target_node_id)
438
  edge_properties = edge_data
439
 
 
 
 
 
 
 
 
 
 
440
  async def _do_upsert_edge(tx: AsyncManagedTransaction):
441
  query = f"""
442
  MATCH (source:`{source_label}`)
443
  WITH source
444
  MATCH (target:`{target_label}`)
445
- MERGE (source)-[r:DIRECTED]->(target)
446
  SET r += $properties
447
  RETURN r
448
  """
449
  result = await tx.run(query, properties=edge_properties)
450
- record = await result.single()
451
- logger.debug(
452
- f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
453
- )
 
 
 
454
 
455
  try:
456
  async with self._driver.session(database=self._DATABASE) as session:
@@ -463,145 +542,179 @@ class Neo4JStorage(BaseGraphStorage):
463
  print("Implemented but never called.")
464
 
465
  async def get_knowledge_graph(
466
- self, node_label: str, max_depth: int = 5
 
 
 
 
467
  ) -> KnowledgeGraph:
468
  """
469
  Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
470
  Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
471
  When reducing the number of nodes, the prioritization criteria are as follows:
472
- 1. Label matching nodes take precedence (nodes containing the specified label string)
473
- 2. Followed by nodes directly connected to the matching nodes
474
- 3. Finally, the degree of the nodes
 
475
 
476
  Args:
477
- node_label (str): String to match in node labels (will match any node containing this string in its label)
478
- max_depth (int, optional): Maximum depth of the graph. Defaults to 5.
 
 
479
  Returns:
480
  KnowledgeGraph: Complete connected subgraph for specified node
481
  """
482
  label = node_label.strip('"')
483
- # Escape single quotes to prevent injection attacks
484
- escaped_label = label.replace("'", "\\'")
485
  result = KnowledgeGraph()
486
  seen_nodes = set()
487
  seen_edges = set()
488
 
489
- async with self._driver.session(database=self._DATABASE) as session:
490
  try:
491
  if label == "*":
492
  main_query = """
493
  MATCH (n)
494
  OPTIONAL MATCH (n)-[r]-()
495
  WITH n, count(r) AS degree
 
496
  ORDER BY degree DESC
497
  LIMIT $max_nodes
498
- WITH collect(n) AS nodes
499
- MATCH (a)-[r]->(b)
500
- WHERE a IN nodes AND b IN nodes
501
- RETURN nodes, collect(DISTINCT r) AS relationships
 
 
 
502
  """
503
  result_set = await session.run(
504
- main_query, {"max_nodes": MAX_GRAPH_NODES}
 
505
  )
506
 
507
  else:
508
- validate_query = f"""
509
- MATCH (n)
510
- WHERE any(label IN labels(n) WHERE label CONTAINS '{escaped_label}')
511
- RETURN n LIMIT 1
512
- """
513
- validate_result = await session.run(validate_query)
514
- if not await validate_result.single():
515
- logger.warning(
516
- f"No nodes containing '{label}' in their labels found!"
517
- )
518
- return result
519
-
520
  # Main query uses partial matching
521
- main_query = f"""
522
  MATCH (start)
523
- WHERE any(label IN labels(start) WHERE label CONTAINS '{escaped_label}')
 
 
 
 
 
524
  WITH start
525
- CALL apoc.path.subgraphAll(start, {{
526
- relationshipFilter: '>',
527
  minLevel: 0,
528
- maxLevel: {max_depth},
529
  bfs: true
530
- }})
531
  YIELD nodes, relationships
532
  WITH start, nodes, relationships
533
  UNWIND nodes AS node
534
  OPTIONAL MATCH (node)-[r]-()
535
- WITH node, count(r) AS degree, start, nodes, relationships,
536
- CASE
537
- WHEN id(node) = id(start) THEN 2
538
- WHEN EXISTS((start)-->(node)) OR EXISTS((node)-->(start)) THEN 1
539
- ELSE 0
540
- END AS priority
541
- ORDER BY priority DESC, degree DESC
 
 
542
  LIMIT $max_nodes
543
- WITH collect(node) AS filtered_nodes, nodes, relationships
544
- RETURN filtered_nodes AS nodes,
545
- [rel IN relationships WHERE startNode(rel) IN filtered_nodes AND endNode(rel) IN filtered_nodes] AS relationships
 
 
 
 
546
  """
547
  result_set = await session.run(
548
- main_query, {"max_nodes": MAX_GRAPH_NODES}
 
 
 
 
 
 
 
549
  )
550
 
551
- record = await result_set.single()
552
-
553
- if record:
554
- # Handle nodes (compatible with multi-label cases)
555
- for node in record["nodes"]:
556
- # Use node ID + label combination as unique identifier
557
- node_id = node.id
558
- if node_id not in seen_nodes:
559
- result.nodes.append(
560
- KnowledgeGraphNode(
561
- id=f"{node_id}",
562
- labels=list(node.labels),
563
- properties=dict(node),
 
 
564
  )
565
- )
566
- seen_nodes.add(node_id)
567
-
568
- # Handle relationships (including direction information)
569
- for rel in record["relationships"]:
570
- edge_id = rel.id
571
- if edge_id not in seen_edges:
572
- start = rel.start_node
573
- end = rel.end_node
574
- result.edges.append(
575
- KnowledgeGraphEdge(
576
- id=f"{edge_id}",
577
- type=rel.type,
578
- source=f"{start.id}",
579
- target=f"{end.id}",
580
- properties=dict(rel),
581
  )
582
- )
583
- seen_edges.add(edge_id)
584
 
585
- logger.info(
586
- f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
587
- )
 
 
588
 
589
  except neo4jExceptions.ClientError as e:
590
- logger.error(f"APOC query failed: {str(e)}")
591
- return await self._robust_fallback(label, max_depth)
 
 
 
 
 
 
592
 
593
  return result
594
 
595
  async def _robust_fallback(
596
- self, label: str, max_depth: int
597
  ) -> Dict[str, List[Dict]]:
598
- """Enhanced fallback query solution"""
 
 
 
 
599
  result = {"nodes": [], "edges": []}
600
  visited_nodes = set()
601
  visited_edges = set()
602
 
603
  async def traverse(current_label: str, current_depth: int):
 
604
  if current_depth > max_depth:
 
 
 
 
605
  return
606
 
607
  # Get current node details
@@ -614,46 +727,46 @@ class Neo4JStorage(BaseGraphStorage):
614
  return
615
  visited_nodes.add(node_id)
616
 
617
- # Add node data (with complete labels)
618
- node_data = {k: v for k, v in node.items()}
619
- node_data["labels"] = [
620
- current_label
621
- ] # Assume get_node method returns label information
622
- result["nodes"].append(node_data)
623
 
624
- # Get all outgoing and incoming edges
 
 
625
  query = f"""
626
- MATCH (a)-[r]-(b)
627
- WHERE a:`{current_label}` OR b:`{current_label}`
628
- RETURN a, r, b,
629
- CASE WHEN startNode(r) = a THEN 'OUTGOING' ELSE 'INCOMING' END AS direction
 
630
  """
631
- async with self._driver.session(database=self._DATABASE) as session:
632
- results = await session.run(query)
633
  async for record in results:
634
  # Handle edges
635
  rel = record["r"]
636
  edge_id = f"{rel.id}_{rel.type}"
637
  if edge_id not in visited_edges:
638
- edge_data = dict(rel)
639
- edge_data.update(
640
- {
641
- "source": list(record["a"].labels)[0],
642
- "target": list(record["b"].labels)[0],
643
  "type": rel.type,
644
- "direction": record["direction"],
645
- }
646
- )
647
- result["edges"].append(edge_data)
648
- visited_edges.add(edge_id)
649
-
650
- # Recursively traverse adjacent nodes
651
- next_label = (
652
- list(record["b"].labels)[0]
653
- if record["direction"] == "OUTGOING"
654
- else list(record["a"].labels)[0]
655
- )
656
- await traverse(next_label, current_depth + 1)
657
 
658
  await traverse(label, 0)
659
  return result
@@ -664,7 +777,7 @@ class Neo4JStorage(BaseGraphStorage):
664
  Returns:
665
  ["Person", "Company", ...] # Alphabetically sorted label list
666
  """
667
- async with self._driver.session(database=self._DATABASE) as session:
668
  # Method 1: Direct metadata query (Available for Neo4j 4.3+)
669
  # query = "CALL db.labels() YIELD label RETURN label"
670
 
@@ -679,8 +792,11 @@ class Neo4JStorage(BaseGraphStorage):
679
 
680
  result = await session.run(query)
681
  labels = []
682
- async for record in result:
683
- labels.append(record["label"])
 
 
 
684
  return labels
685
 
686
  @retry(
@@ -763,7 +879,7 @@ class Neo4JStorage(BaseGraphStorage):
763
 
764
  async def _do_delete_edge(tx: AsyncManagedTransaction):
765
  query = f"""
766
- MATCH (source:`{source_label}`)-[r]->(target:`{target_label}`)
767
  DELETE r
768
  """
769
  await tx.run(query)
 
41
  # Set neo4j logger level to ERROR to suppress warning logs
42
  logging.getLogger("neo4j").setLevel(logging.ERROR)
43
 
44
+
45
  @final
46
  @dataclass
47
  class Neo4JStorage(BaseGraphStorage):
 
64
  MAX_CONNECTION_POOL_SIZE = int(
65
  os.environ.get(
66
  "NEO4J_MAX_CONNECTION_POOL_SIZE",
67
+ config.get("neo4j", "connection_pool_size", fallback=50), # Reduced from 800
68
  )
69
  )
70
  CONNECTION_TIMEOUT = float(
71
  os.environ.get(
72
  "NEO4J_CONNECTION_TIMEOUT",
73
+ config.get("neo4j", "connection_timeout", fallback=30.0), # Reduced from 60.0
74
  ),
75
  )
76
  CONNECTION_ACQUISITION_TIMEOUT = float(
77
  os.environ.get(
78
  "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
79
+ config.get("neo4j", "connection_acquisition_timeout", fallback=30.0), # Reduced from 60.0
80
+ ),
81
+ )
82
+ MAX_TRANSACTION_RETRY_TIME = float(
83
+ os.environ.get(
84
+ "NEO4J_MAX_TRANSACTION_RETRY_TIME",
85
+ config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
86
  ),
87
  )
88
  DATABASE = os.environ.get(
 
95
  max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
96
  connection_timeout=CONNECTION_TIMEOUT,
97
  connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
98
+ max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
99
  )
100
 
101
  # Try to connect to the database
 
177
 
178
  async def _ensure_label(self, label: str) -> str:
179
  """Ensure a label is valid
180
+
181
  Args:
182
  label: The label to validate
183
  """
184
  clean_label = label.strip('"')
185
+ if not clean_label:
186
+ raise ValueError("Neo4j: Label cannot be empty")
187
  return clean_label
188
 
189
  async def has_node(self, node_id: str) -> bool:
190
  entity_name_label = await self._ensure_label(node_id)
191
+ async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
192
  query = (
193
  f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
194
  )
195
  result = await session.run(query)
196
  single_result = await result.single()
197
+ await result.consume() # Ensure result is fully consumed
198
  logger.debug(
199
  f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
200
  )
 
204
  entity_name_label_source = source_node_id.strip('"')
205
  entity_name_label_target = target_node_id.strip('"')
206
 
207
+ async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
208
  query = (
209
  f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
210
  "RETURN COUNT(r) > 0 AS edgeExists"
211
  )
212
  result = await session.run(query)
213
  single_result = await result.single()
214
+ await result.consume() # Ensure result is fully consumed
215
  logger.debug(
216
  f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}"
217
  )
 
227
  dict: Node properties if found
228
  None: If node not found
229
  """
230
+ async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
231
  entity_name_label = await self._ensure_label(node_id)
232
  query = f"MATCH (n:`{entity_name_label}`) RETURN n"
233
  result = await session.run(query)
234
+ records = await result.fetch(2) # Get up to 2 records to check for duplicates
235
+ await result.consume() # Ensure result is fully consumed
236
+ if len(records) > 1:
237
+ logger.warning(f"Multiple nodes found with label '{entity_name_label}'. Using first node.")
238
+ if records:
239
+ node = records[0]["n"]
240
  node_dict = dict(node)
241
  logger.debug(
242
  f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
 
245
  return None
246
 
247
  async def node_degree(self, node_id: str) -> int:
248
+ """Get the degree (number of relationships) of a node with the given label.
249
+ If multiple nodes have the same label, returns the degree of the first node.
250
+ If no node is found, returns 0.
251
+
252
+ Args:
253
+ node_id: The label of the node
254
+
255
+ Returns:
256
+ int: The number of relationships the node has, or 0 if no node found
257
+ """
258
  entity_name_label = node_id.strip('"')
259
 
260
+ async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
261
  query = f"""
262
  MATCH (n:`{entity_name_label}`)
263
+ OPTIONAL MATCH (n)-[r]-()
264
+ RETURN n, COUNT(r) AS degree
265
  """
266
  result = await session.run(query)
267
+ records = await result.fetch(100)
268
+ await result.consume() # Ensure result is fully consumed
269
+
270
+ if not records:
271
+ logger.warning(f"No node found with label '{entity_name_label}'")
272
+ return 0
273
+
274
+ if len(records) > 1:
275
+ logger.warning(f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree")
276
+
277
+ degree = records[0]["degree"]
278
+ logger.debug(
279
+ f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}"
280
+ )
281
+ return degree
282
 
283
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
284
  entity_name_label_source = src_id.strip('"')
 
296
  )
297
  return degrees
298
 
299
+ async def check_duplicate_nodes(self) -> list[tuple[str, int]]:
300
+ """Find all labels that have multiple nodes
301
+
302
+ Returns:
303
+ list[tuple[str, int]]: List of tuples containing (label, node_count) for labels with multiple nodes
304
+ """
305
+ async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
306
+ query = """
307
+ MATCH (n)
308
+ WITH labels(n) as nodeLabels
309
+ UNWIND nodeLabels as label
310
+ WITH label, count(*) as node_count
311
+ WHERE node_count > 1
312
+ RETURN label, node_count
313
+ ORDER BY node_count DESC
314
+ """
315
+ result = await session.run(query)
316
+ duplicates = []
317
+ async for record in result:
318
+ label = record["label"]
319
+ count = record["node_count"]
320
+ logger.info(f"Found {count} nodes with label: {label}")
321
+ duplicates.append((label, count))
322
+ return duplicates
323
+
324
  async def get_edge(
325
  self, source_node_id: str, target_node_id: str
326
  ) -> dict[str, str] | None:
 
328
  entity_name_label_source = source_node_id.strip('"')
329
  entity_name_label_target = target_node_id.strip('"')
330
 
331
+ async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
332
  query = f"""
333
+ MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`)
334
  RETURN properties(r) as edge_properties
 
335
  """
336
 
337
  result = await session.run(query)
338
+ records = await result.fetch(2) # Get up to 2 records to check for duplicates
339
+ if len(records) > 1:
340
+ logger.warning(
341
+ f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
342
+ )
343
+ if records:
344
  try:
345
+ result = dict(records[0]["edge_properties"])
346
  logger.debug(f"Result: {result}")
347
  # Ensure required keys exist with defaults
348
  required_keys = {
 
409
  query = f"""MATCH (n:`{node_label}`)
410
  OPTIONAL MATCH (n)-[r]-(connected)
411
  RETURN n, r, connected"""
412
+ async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
413
  results = await session.run(query)
414
  edges = []
415
+ try:
416
+ async for record in results:
417
+ source_node = record["n"]
418
+ connected_node = record["connected"]
419
 
420
+ source_label = (
421
+ list(source_node.labels)[0] if source_node.labels else None
422
+ )
423
+ target_label = (
424
+ list(connected_node.labels)[0]
425
+ if connected_node and connected_node.labels
426
+ else None
427
+ )
428
 
429
+ if source_label and target_label:
430
+ edges.append((source_label, target_label))
431
+ finally:
432
+ await results.consume() # Ensure results are consumed even if processing fails
433
 
434
  return edges
435
 
 
490
  ) -> None:
491
  """
492
  Upsert an edge and its properties between two nodes identified by their labels.
493
+ Checks if both source and target nodes exist before creating the edge.
494
 
495
  Args:
496
  source_node_id (str): Label of the source node (used as identifier)
497
  target_node_id (str): Label of the target node (used as identifier)
498
  edge_data (dict): Dictionary of properties to set on the edge
499
+
500
+ Raises:
501
+ ValueError: If either source or target node does not exist
502
  """
503
  source_label = await self._ensure_label(source_node_id)
504
  target_label = await self._ensure_label(target_node_id)
505
  edge_properties = edge_data
506
 
507
+ # Check if both nodes exist
508
+ source_exists = await self.has_node(source_label)
509
+ target_exists = await self.has_node(target_label)
510
+
511
+ if not source_exists:
512
+ raise ValueError(f"Neo4j: source node with label '{source_label}' does not exist")
513
+ if not target_exists:
514
+ raise ValueError(f"Neo4j: target node with label '{target_label}' does not exist")
515
+
516
  async def _do_upsert_edge(tx: AsyncManagedTransaction):
517
  query = f"""
518
  MATCH (source:`{source_label}`)
519
  WITH source
520
  MATCH (target:`{target_label}`)
521
+ MERGE (source)-[r:DIRECTED]-(target)
522
  SET r += $properties
523
  RETURN r
524
  """
525
  result = await tx.run(query, properties=edge_properties)
526
+ try:
527
+ record = await result.single()
528
+ logger.debug(
529
+ f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
530
+ )
531
+ finally:
532
+ await result.consume() # Ensure result is consumed
533
 
534
  try:
535
  async with self._driver.session(database=self._DATABASE) as session:
 
542
  print("Implemented but never called.")
543
 
544
  async def get_knowledge_graph(
545
+ self,
546
+ node_label: str,
547
+ max_depth: int = 3,
548
+ min_degree: int = 0,
549
+ inclusive: bool = False,
550
  ) -> KnowledgeGraph:
551
  """
552
  Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
553
  Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
554
  When reducing the number of nodes, the prioritization criteria are as follows:
555
+ 1. min_degree does not affect nodes directly connected to the matching nodes
556
+ 2. Label matching nodes take precedence
557
+ 3. Followed by nodes directly connected to the matching nodes
558
+ 4. Finally, the degree of the nodes
559
 
560
  Args:
561
+ node_label: Label of the starting node
562
+ max_depth: Maximum depth of the subgraph
563
+ min_degree: Minimum degree of nodes to include. Defaults to 0
564
+ inclusive: Do an inclusive search if true
565
  Returns:
566
  KnowledgeGraph: Complete connected subgraph for specified node
567
  """
568
  label = node_label.strip('"')
 
 
569
  result = KnowledgeGraph()
570
  seen_nodes = set()
571
  seen_edges = set()
572
 
573
+ async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
574
  try:
575
  if label == "*":
576
  main_query = """
577
  MATCH (n)
578
  OPTIONAL MATCH (n)-[r]-()
579
  WITH n, count(r) AS degree
580
+ WHERE degree >= $min_degree
581
  ORDER BY degree DESC
582
  LIMIT $max_nodes
583
+ WITH collect({node: n}) AS filtered_nodes
584
+ UNWIND filtered_nodes AS node_info
585
+ WITH collect(node_info.node) AS kept_nodes, filtered_nodes
586
+ MATCH (a)-[r]-(b)
587
+ WHERE a IN kept_nodes AND b IN kept_nodes
588
+ RETURN filtered_nodes AS node_info,
589
+ collect(DISTINCT r) AS relationships
590
  """
591
  result_set = await session.run(
592
+ main_query,
593
+ {"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree},
594
  )
595
 
596
  else:
 
 
 
 
 
 
 
 
 
 
 
 
597
  # Main query uses partial matching
598
+ main_query = """
599
  MATCH (start)
600
+ WHERE any(label IN labels(start) WHERE
601
+ CASE
602
+ WHEN $inclusive THEN label CONTAINS $label
603
+ ELSE label = $label
604
+ END
605
+ )
606
  WITH start
607
+ CALL apoc.path.subgraphAll(start, {
608
+ relationshipFilter: '',
609
  minLevel: 0,
610
+ maxLevel: $max_depth,
611
  bfs: true
612
+ })
613
  YIELD nodes, relationships
614
  WITH start, nodes, relationships
615
  UNWIND nodes AS node
616
  OPTIONAL MATCH (node)-[r]-()
617
+ WITH node, count(r) AS degree, start, nodes, relationships
618
+ WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
619
+ ORDER BY
620
+ CASE
621
+ WHEN node = start THEN 3
622
+ WHEN EXISTS((start)--(node)) THEN 2
623
+ ELSE 1
624
+ END DESC,
625
+ degree DESC
626
  LIMIT $max_nodes
627
+ WITH collect({node: node}) AS filtered_nodes
628
+ UNWIND filtered_nodes AS node_info
629
+ WITH collect(node_info.node) AS kept_nodes, filtered_nodes
630
+ MATCH (a)-[r]-(b)
631
+ WHERE a IN kept_nodes AND b IN kept_nodes
632
+ RETURN filtered_nodes AS node_info,
633
+ collect(DISTINCT r) AS relationships
634
  """
635
  result_set = await session.run(
636
+ main_query,
637
+ {
638
+ "max_nodes": MAX_GRAPH_NODES,
639
+ "label": label,
640
+ "inclusive": inclusive,
641
+ "max_depth": max_depth,
642
+ "min_degree": min_degree,
643
+ },
644
  )
645
 
646
+ try:
647
+ record = await result_set.single()
648
+
649
+ if record:
650
+ # Handle nodes (compatible with multi-label cases)
651
+ for node_info in record["node_info"]:
652
+ node = node_info["node"]
653
+ node_id = node.id
654
+ if node_id not in seen_nodes:
655
+ result.nodes.append(
656
+ KnowledgeGraphNode(
657
+ id=f"{node_id}",
658
+ labels=list(node.labels),
659
+ properties=dict(node),
660
+ )
661
  )
662
+ seen_nodes.add(node_id)
663
+
664
+ # Handle relationships (including direction information)
665
+ for rel in record["relationships"]:
666
+ edge_id = rel.id
667
+ if edge_id not in seen_edges:
668
+ start = rel.start_node
669
+ end = rel.end_node
670
+ result.edges.append(
671
+ KnowledgeGraphEdge(
672
+ id=f"{edge_id}",
673
+ type=rel.type,
674
+ source=f"{start.id}",
675
+ target=f"{end.id}",
676
+ properties=dict(rel),
677
+ )
678
  )
679
+ seen_edges.add(edge_id)
 
680
 
681
+ logger.info(
682
+ f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
683
+ )
684
+ finally:
685
+ await result_set.consume() # Ensure result set is consumed
686
 
687
  except neo4jExceptions.ClientError as e:
688
+ logger.warning(
689
+ f"APOC plugin error: {str(e)}, falling back to basic Cypher implementation"
690
+ )
691
+ if inclusive:
692
+ logger.warning(
693
+ "Inclusive search mode is not supported in recursive query, using exact matching"
694
+ )
695
+ return await self._robust_fallback(label, max_depth, min_degree)
696
 
697
  return result
698
 
699
  async def _robust_fallback(
700
+ self, label: str, max_depth: int, min_degree: int = 0
701
  ) -> Dict[str, List[Dict]]:
702
+ """
703
+ Fallback implementation when APOC plugin is not available or incompatible.
704
+ This method implements the same functionality as get_knowledge_graph but uses
705
+ only basic Cypher queries and recursive traversal instead of APOC procedures.
706
+ """
707
  result = {"nodes": [], "edges": []}
708
  visited_nodes = set()
709
  visited_edges = set()
710
 
711
  async def traverse(current_label: str, current_depth: int):
712
+ # Check traversal limits
713
  if current_depth > max_depth:
714
+ logger.debug(f"Reached max depth: {max_depth}")
715
+ return
716
+ if len(visited_nodes) >= MAX_GRAPH_NODES:
717
+ logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
718
  return
719
 
720
  # Get current node details
 
727
  return
728
  visited_nodes.add(node_id)
729
 
730
+ # Add node data with label as ID
731
+ result["nodes"].append({
732
+ "id": current_label,
733
+ "labels": current_label,
734
+ "properties": node
735
+ })
736
 
737
+ # Get connected nodes that meet the degree requirement
738
+ # Note: We don't need to check a's degree since it's the current node
739
+ # and was already validated in the previous iteration
740
  query = f"""
741
+ MATCH (a:`{current_label}`)-[r]-(b)
742
+ WITH r, b,
743
+ COUNT((b)--()) AS b_degree
744
+ WHERE b_degree >= $min_degree OR EXISTS((a)--(b))
745
+ RETURN r, b
746
  """
747
+ async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
748
+ results = await session.run(query, {"min_degree": min_degree})
749
  async for record in results:
750
  # Handle edges
751
  rel = record["r"]
752
  edge_id = f"{rel.id}_{rel.type}"
753
  if edge_id not in visited_edges:
754
+ b_node = record["b"]
755
+ if b_node.labels: # Only process if target node has labels
756
+ target_label = list(b_node.labels)[0]
757
+ result["edges"].append({
758
+ "id": f"{current_label}_{target_label}",
759
  "type": rel.type,
760
+ "source": current_label,
761
+ "target": target_label,
762
+ "properties": dict(rel)
763
+ })
764
+ visited_edges.add(edge_id)
765
+
766
+ # Continue traversal
767
+ await traverse(target_label, current_depth + 1)
768
+ else:
769
+ logger.warning(f"Skipping edge {edge_id} due to missing labels on target node")
 
 
 
770
 
771
  await traverse(label, 0)
772
  return result
 
777
  Returns:
778
  ["Person", "Company", ...] # Alphabetically sorted label list
779
  """
780
+ async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
781
  # Method 1: Direct metadata query (Available for Neo4j 4.3+)
782
  # query = "CALL db.labels() YIELD label RETURN label"
783
 
 
792
 
793
  result = await session.run(query)
794
  labels = []
795
+ try:
796
+ async for record in result:
797
+ labels.append(record["label"])
798
+ finally:
799
+ await result.consume() # Ensure results are consumed even if processing fails
800
  return labels
801
 
802
  @retry(
 
879
 
880
  async def _do_delete_edge(tx: AsyncManagedTransaction):
881
  query = f"""
882
+ MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`)
883
  DELETE r
884
  """
885
  await tx.run(query)