yangdx commited on
Commit
706f457
·
1 Parent(s): 1fbf326

Refactor Neo4JStorage to use entity_id for node identification, use entity_type for node label

Browse files
Files changed (1) hide show
  1. lightrag/kg/neo4j_impl.py +95 -192
lightrag/kg/neo4j_impl.py CHANGED
@@ -176,23 +176,6 @@ class Neo4JStorage(BaseGraphStorage):
176
  # Noe4J handles persistence automatically
177
  pass
178
 
179
- def _ensure_label(self, label: str) -> str:
180
- """Ensure a label is valid
181
-
182
- Args:
183
- label: The label to validate
184
-
185
- Returns:
186
- str: The cleaned label
187
-
188
- Raises:
189
- ValueError: If label is empty after cleaning
190
- """
191
- clean_label = label.strip('"')
192
- if not clean_label:
193
- raise ValueError("Neo4j: Label cannot be empty")
194
- return clean_label
195
-
196
  async def has_node(self, node_id: str) -> bool:
197
  """
198
  Check if a node with the given label exists in the database
@@ -207,19 +190,18 @@ class Neo4JStorage(BaseGraphStorage):
207
  ValueError: If node_id is invalid
208
  Exception: If there is an error executing the query
209
  """
210
- entity_name_label = self._ensure_label(node_id)
211
  async with self._driver.session(
212
  database=self._DATABASE, default_access_mode="READ"
213
  ) as session:
214
  try:
215
- query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
216
- result = await session.run(query)
217
  single_result = await result.single()
218
  await result.consume() # Ensure result is fully consumed
219
  return single_result["node_exists"]
220
  except Exception as e:
221
  logger.error(
222
- f"Error checking node existence for {entity_name_label}: {str(e)}"
223
  )
224
  await result.consume() # Ensure results are consumed even on error
225
  raise
@@ -239,24 +221,21 @@ class Neo4JStorage(BaseGraphStorage):
239
  ValueError: If either node_id is invalid
240
  Exception: If there is an error executing the query
241
  """
242
- entity_name_label_source = self._ensure_label(source_node_id)
243
- entity_name_label_target = self._ensure_label(target_node_id)
244
-
245
  async with self._driver.session(
246
  database=self._DATABASE, default_access_mode="READ"
247
  ) as session:
248
  try:
249
  query = (
250
- f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
251
  "RETURN COUNT(r) > 0 AS edgeExists"
252
  )
253
- result = await session.run(query)
254
  single_result = await result.single()
255
  await result.consume() # Ensure result is fully consumed
256
  return single_result["edgeExists"]
257
  except Exception as e:
258
  logger.error(
259
- f"Error checking edge existence between {entity_name_label_source} and {entity_name_label_target}: {str(e)}"
260
  )
261
  await result.consume() # Ensure results are consumed even on error
262
  raise
@@ -275,13 +254,12 @@ class Neo4JStorage(BaseGraphStorage):
275
  ValueError: If node_id is invalid
276
  Exception: If there is an error executing the query
277
  """
278
- entity_name_label = self._ensure_label(node_id)
279
  async with self._driver.session(
280
  database=self._DATABASE, default_access_mode="READ"
281
  ) as session:
282
  try:
283
- query = f"MATCH (n:`{entity_name_label}` {{entity_id: $entity_id}}) RETURN n"
284
- result = await session.run(query, entity_id=entity_name_label)
285
  try:
286
  records = await result.fetch(
287
  2
@@ -289,20 +267,21 @@ class Neo4JStorage(BaseGraphStorage):
289
 
290
  if len(records) > 1:
291
  logger.warning(
292
- f"Multiple nodes found with label '{entity_name_label}'. Using first node."
293
  )
294
  if records:
295
  node = records[0]["n"]
296
  node_dict = dict(node)
297
- logger.debug(
298
- f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
299
- )
 
300
  return node_dict
301
  return None
302
  finally:
303
  await result.consume() # Ensure result is fully consumed
304
  except Exception as e:
305
- logger.error(f"Error getting node for {entity_name_label}: {str(e)}")
306
  raise
307
 
308
  async def node_degree(self, node_id: str) -> int:
@@ -320,42 +299,33 @@ class Neo4JStorage(BaseGraphStorage):
320
  ValueError: If node_id is invalid
321
  Exception: If there is an error executing the query
322
  """
323
- entity_name_label = self._ensure_label(node_id)
324
-
325
  async with self._driver.session(
326
  database=self._DATABASE, default_access_mode="READ"
327
  ) as session:
328
  try:
329
- query = f"""
330
- MATCH (n:`{entity_name_label}`)
331
  OPTIONAL MATCH (n)-[r]-()
332
- RETURN n, COUNT(r) AS degree
333
  """
334
- result = await session.run(query)
335
  try:
336
- records = await result.fetch(100)
337
 
338
- if not records:
339
  logger.warning(
340
- f"No node found with label '{entity_name_label}'"
341
  )
342
  return 0
343
 
344
- if len(records) > 1:
345
- logger.warning(
346
- f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree"
347
- )
348
-
349
- degree = records[0]["degree"]
350
- logger.debug(
351
- f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}"
352
- )
353
  return degree
354
  finally:
355
  await result.consume() # Ensure result is fully consumed
356
  except Exception as e:
357
  logger.error(
358
- f"Error getting node degree for {entity_name_label}: {str(e)}"
359
  )
360
  raise
361
 
@@ -369,11 +339,8 @@ class Neo4JStorage(BaseGraphStorage):
369
  Returns:
370
  int: Sum of the degrees of both nodes
371
  """
372
- entity_name_label_source = self._ensure_label(src_id)
373
- entity_name_label_target = self._ensure_label(tgt_id)
374
-
375
- src_degree = await self.node_degree(entity_name_label_source)
376
- trg_degree = await self.node_degree(entity_name_label_target)
377
 
378
  # Convert None to 0 for addition
379
  src_degree = 0 if src_degree is None else src_degree
@@ -399,24 +366,20 @@ class Neo4JStorage(BaseGraphStorage):
399
  Exception: If there is an error executing the query
400
  """
401
  try:
402
- entity_name_label_source = self._ensure_label(source_node_id)
403
- entity_name_label_target = self._ensure_label(target_node_id)
404
-
405
  async with self._driver.session(
406
  database=self._DATABASE, default_access_mode="READ"
407
  ) as session:
408
- query = f"""
409
- MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`)
410
  RETURN properties(r) as edge_properties
411
  """
412
-
413
- result = await session.run(query)
414
  try:
415
  records = await result.fetch(2)
416
 
417
  if len(records) > 1:
418
  logger.warning(
419
- f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
420
  )
421
  if records:
422
  try:
@@ -433,7 +396,7 @@ class Neo4JStorage(BaseGraphStorage):
433
  if key not in edge_result:
434
  edge_result[key] = default_value
435
  logger.warning(
436
- f"Edge between {entity_name_label_source} and {entity_name_label_target} "
437
  f"missing {key}, using default: {default_value}"
438
  )
439
 
@@ -443,8 +406,8 @@ class Neo4JStorage(BaseGraphStorage):
443
  return edge_result
444
  except (KeyError, TypeError, ValueError) as e:
445
  logger.error(
446
- f"Error processing edge properties between {entity_name_label_source} "
447
- f"and {entity_name_label_target}: {str(e)}"
448
  )
449
  # Return default edge properties on error
450
  return {
@@ -455,7 +418,7 @@ class Neo4JStorage(BaseGraphStorage):
455
  }
456
 
457
  logger.debug(
458
- f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
459
  )
460
  # Return default edge properties when no edge found
461
  return {
@@ -488,30 +451,30 @@ class Neo4JStorage(BaseGraphStorage):
488
  Exception: If there is an error executing the query
489
  """
490
  try:
491
- node_label = self._ensure_label(source_node_id)
492
-
493
- query = f"""MATCH (n:`{node_label}`)
494
- OPTIONAL MATCH (n)-[r]-(connected)
495
- RETURN n, r, connected"""
496
-
497
  async with self._driver.session(
498
  database=self._DATABASE, default_access_mode="READ"
499
  ) as session:
500
  try:
501
- results = await session.run(query)
502
- edges = []
 
 
 
503
 
 
504
  async for record in results:
505
  source_node = record["n"]
506
  connected_node = record["connected"]
507
 
 
 
 
 
508
  source_label = (
509
- list(source_node.labels)[0] if source_node.labels else None
510
  )
511
  target_label = (
512
- list(connected_node.labels)[0]
513
- if connected_node and connected_node.labels
514
- else None
515
  )
516
 
517
  if source_label and target_label:
@@ -520,7 +483,7 @@ class Neo4JStorage(BaseGraphStorage):
520
  await results.consume() # Ensure results are consumed
521
  return edges
522
  except Exception as e:
523
- logger.error(f"Error getting edges for node {node_label}: {str(e)}")
524
  await results.consume() # Ensure results are consumed even on error
525
  raise
526
  except Exception as e:
@@ -547,8 +510,9 @@ class Neo4JStorage(BaseGraphStorage):
547
  node_id: The unique identifier for the node (used as label)
548
  node_data: Dictionary of node properties
549
  """
550
- label = self._ensure_label(node_id)
551
  properties = node_data
 
 
552
  if "entity_id" not in properties:
553
  raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
554
 
@@ -556,13 +520,14 @@ class Neo4JStorage(BaseGraphStorage):
556
  async with self._driver.session(database=self._DATABASE) as session:
557
 
558
  async def execute_upsert(tx: AsyncManagedTransaction):
559
- query = f"""
560
- MERGE (n:`{label}` {{entity_id: $properties.entity_id}})
561
  SET n += $properties
562
- """
 
563
  result = await tx.run(query, properties=properties)
564
  logger.debug(
565
- f"Upserted node with label '{label}' and properties: {properties}"
566
  )
567
  await result.consume() # Ensure result is fully consumed
568
 
@@ -583,51 +548,6 @@ class Neo4JStorage(BaseGraphStorage):
583
  )
584
  ),
585
  )
586
- async def _get_unique_node_entity_id(self, node_label: str) -> str:
587
- """
588
- Get the entity_id of a node with the given label, ensuring the node is unique.
589
-
590
- Args:
591
- node_label (str): Label of the node to check
592
-
593
- Returns:
594
- str: The entity_id of the unique node
595
-
596
- Raises:
597
- ValueError: If no node with the given label exists or if multiple nodes have the same label
598
- """
599
- async with self._driver.session(
600
- database=self._DATABASE, default_access_mode="READ"
601
- ) as session:
602
- query = f"""
603
- MATCH (n:`{node_label}`)
604
- RETURN n, count(n) as node_count
605
- """
606
- result = await session.run(query)
607
- try:
608
- records = await result.fetch(
609
- 2
610
- ) # We only need to know if there are 0, 1, or >1 nodes
611
-
612
- if not records or records[0]["node_count"] == 0:
613
- raise ValueError(
614
- f"Neo4j: node with label '{node_label}' does not exist"
615
- )
616
-
617
- if records[0]["node_count"] > 1:
618
- raise ValueError(
619
- f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node"
620
- )
621
-
622
- node = records[0]["n"]
623
- if "entity_id" not in node:
624
- raise ValueError(
625
- f"Neo4j: node with label '{node_label}' does not have an entity_id property"
626
- )
627
-
628
- return node["entity_id"]
629
- finally:
630
- await result.consume() # Ensure result is fully consumed
631
 
632
  @retry(
633
  stop=stop_after_attempt(3),
@@ -657,38 +577,30 @@ class Neo4JStorage(BaseGraphStorage):
657
  Raises:
658
  ValueError: If either source or target node does not exist or is not unique
659
  """
660
- source_label = self._ensure_label(source_node_id)
661
- target_label = self._ensure_label(target_node_id)
662
- edge_properties = edge_data
663
-
664
- # Get entity_ids for source and target nodes, ensuring they are unique
665
- source_entity_id = await self._get_unique_node_entity_id(source_label)
666
- target_entity_id = await self._get_unique_node_entity_id(target_label)
667
-
668
  try:
 
669
  async with self._driver.session(database=self._DATABASE) as session:
670
 
671
  async def execute_upsert(tx: AsyncManagedTransaction):
672
- query = f"""
673
- MATCH (source:`{source_label}` {{entity_id: $source_entity_id}})
674
  WITH source
675
- MATCH (target:`{target_label}` {{entity_id: $target_entity_id}})
676
  MERGE (source)-[r:DIRECTED]-(target)
677
  SET r += $properties
678
  RETURN r, source, target
679
  """
680
  result = await tx.run(
681
  query,
682
- source_entity_id=source_entity_id,
683
- target_entity_id=target_entity_id,
684
  properties=edge_properties,
685
  )
686
  try:
687
- records = await result.fetch(100)
688
  if records:
689
  logger.debug(
690
- f"Upserted edge from '{source_label}' (entity_id: {source_entity_id}) "
691
- f"to '{target_label}' (entity_id: {target_entity_id}) "
692
  f"with properties: {edge_properties}"
693
  )
694
  finally:
@@ -726,7 +638,6 @@ class Neo4JStorage(BaseGraphStorage):
726
  Returns:
727
  KnowledgeGraph: Complete connected subgraph for specified node
728
  """
729
- label = node_label.strip('"')
730
  result = KnowledgeGraph()
731
  seen_nodes = set()
732
  seen_edges = set()
@@ -735,7 +646,7 @@ class Neo4JStorage(BaseGraphStorage):
735
  database=self._DATABASE, default_access_mode="READ"
736
  ) as session:
737
  try:
738
- if label == "*":
739
  main_query = """
740
  MATCH (n)
741
  OPTIONAL MATCH (n)-[r]-()
@@ -760,12 +671,11 @@ class Neo4JStorage(BaseGraphStorage):
760
  # Main query uses partial matching
761
  main_query = """
762
  MATCH (start)
763
- WHERE any(label IN labels(start) WHERE
764
  CASE
765
- WHEN $inclusive THEN label CONTAINS $label
766
- ELSE label = $label
767
  END
768
- )
769
  WITH start
770
  CALL apoc.path.subgraphAll(start, {
771
  relationshipFilter: '',
@@ -799,7 +709,7 @@ class Neo4JStorage(BaseGraphStorage):
799
  main_query,
800
  {
801
  "max_nodes": MAX_GRAPH_NODES,
802
- "label": label,
803
  "inclusive": inclusive,
804
  "max_depth": max_depth,
805
  "min_degree": min_degree,
@@ -818,7 +728,7 @@ class Neo4JStorage(BaseGraphStorage):
818
  result.nodes.append(
819
  KnowledgeGraphNode(
820
  id=f"{node_id}",
821
- labels=list(node.labels),
822
  properties=dict(node),
823
  )
824
  )
@@ -849,7 +759,7 @@ class Neo4JStorage(BaseGraphStorage):
849
 
850
  except neo4jExceptions.ClientError as e:
851
  logger.warning(f"APOC plugin error: {str(e)}")
852
- if label != "*":
853
  logger.warning(
854
  "Neo4j: falling back to basic Cypher recursive search..."
855
  )
@@ -857,12 +767,12 @@ class Neo4JStorage(BaseGraphStorage):
857
  logger.warning(
858
  "Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
859
  )
860
- return await self._robust_fallback(label, max_depth, min_degree)
861
 
862
  return result
863
 
864
  async def _robust_fallback(
865
- self, label: str, max_depth: int, min_degree: int = 0
866
  ) -> KnowledgeGraph:
867
  """
868
  Fallback implementation when APOC plugin is not available or incompatible.
@@ -895,12 +805,11 @@ class Neo4JStorage(BaseGraphStorage):
895
  database=self._DATABASE, default_access_mode="READ"
896
  ) as session:
897
  query = """
898
- MATCH (a)-[r]-(b)
899
- WHERE id(a) = toInteger($node_id)
900
  WITH r, b, id(r) as edge_id, id(b) as target_id
901
  RETURN r, b, edge_id, target_id
902
  """
903
- results = await session.run(query, {"node_id": node.id})
904
 
905
  # Get all records and release database connection
906
  records = await results.fetch(
@@ -928,14 +837,14 @@ class Neo4JStorage(BaseGraphStorage):
928
  edge_id = str(record["edge_id"])
929
  if edge_id not in visited_edges:
930
  b_node = record["b"]
931
- target_id = str(record["target_id"])
932
 
933
- if b_node.labels: # Only process if target node has labels
934
  # Create KnowledgeGraphNode for target
935
  target_node = KnowledgeGraphNode(
936
  id=f"{target_id}",
937
- labels=list(b_node.labels),
938
- properties=dict(b_node),
939
  )
940
 
941
  # Create KnowledgeGraphEdge
@@ -961,11 +870,11 @@ class Neo4JStorage(BaseGraphStorage):
961
  async with self._driver.session(
962
  database=self._DATABASE, default_access_mode="READ"
963
  ) as session:
964
- query = f"""
965
- MATCH (n:`{label}`)
966
  RETURN id(n) as node_id, n
967
  """
968
- node_result = await session.run(query)
969
  try:
970
  node_record = await node_result.single()
971
  if not node_record:
@@ -973,9 +882,9 @@ class Neo4JStorage(BaseGraphStorage):
973
 
974
  # Create initial KnowledgeGraphNode
975
  start_node = KnowledgeGraphNode(
976
- id=f"{node_record['node_id']}",
977
- labels=list(node_record["n"].labels),
978
- properties=dict(node_record["n"]),
979
  )
980
  finally:
981
  await node_result.consume() # Ensure results are consumed
@@ -999,11 +908,10 @@ class Neo4JStorage(BaseGraphStorage):
999
 
1000
  # Method 2: Query compatible with older versions
1001
  query = """
1002
- MATCH (n)
1003
- WITH DISTINCT labels(n) AS node_labels
1004
- UNWIND node_labels AS label
1005
- RETURN DISTINCT label
1006
- ORDER BY label
1007
  """
1008
  result = await session.run(query)
1009
  labels = []
@@ -1034,15 +942,13 @@ class Neo4JStorage(BaseGraphStorage):
1034
  Args:
1035
  node_id: The label of the node to delete
1036
  """
1037
- label = self._ensure_label(node_id)
1038
-
1039
  async def _do_delete(tx: AsyncManagedTransaction):
1040
- query = f"""
1041
- MATCH (n:`{label}`)
1042
  DETACH DELETE n
1043
  """
1044
- result = await tx.run(query)
1045
- logger.debug(f"Deleted node with label '{label}'")
1046
  await result.consume() # Ensure result is fully consumed
1047
 
1048
  try:
@@ -1092,16 +998,13 @@ class Neo4JStorage(BaseGraphStorage):
1092
  edges: List of edges to be deleted, each edge is a (source, target) tuple
1093
  """
1094
  for source, target in edges:
1095
- source_label = self._ensure_label(source)
1096
- target_label = self._ensure_label(target)
1097
-
1098
  async def _do_delete_edge(tx: AsyncManagedTransaction):
1099
- query = f"""
1100
- MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`)
1101
  DELETE r
1102
  """
1103
- result = await tx.run(query)
1104
- logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'")
1105
  await result.consume() # Ensure result is fully consumed
1106
 
1107
  try:
 
176
  # Noe4J handles persistence automatically
177
  pass
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  async def has_node(self, node_id: str) -> bool:
180
  """
181
  Check if a node with the given label exists in the database
 
190
  ValueError: If node_id is invalid
191
  Exception: If there is an error executing the query
192
  """
 
193
  async with self._driver.session(
194
  database=self._DATABASE, default_access_mode="READ"
195
  ) as session:
196
  try:
197
+ query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
198
+ result = await session.run(query, entity_id = node_id)
199
  single_result = await result.single()
200
  await result.consume() # Ensure result is fully consumed
201
  return single_result["node_exists"]
202
  except Exception as e:
203
  logger.error(
204
+ f"Error checking node existence for {node_id}: {str(e)}"
205
  )
206
  await result.consume() # Ensure results are consumed even on error
207
  raise
 
221
  ValueError: If either node_id is invalid
222
  Exception: If there is an error executing the query
223
  """
 
 
 
224
  async with self._driver.session(
225
  database=self._DATABASE, default_access_mode="READ"
226
  ) as session:
227
  try:
228
  query = (
229
+ "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
230
  "RETURN COUNT(r) > 0 AS edgeExists"
231
  )
232
+ result = await session.run(query, source_entity_id = source_node_id, target_entity_id = target_node_id)
233
  single_result = await result.single()
234
  await result.consume() # Ensure result is fully consumed
235
  return single_result["edgeExists"]
236
  except Exception as e:
237
  logger.error(
238
+ f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
239
  )
240
  await result.consume() # Ensure results are consumed even on error
241
  raise
 
254
  ValueError: If node_id is invalid
255
  Exception: If there is an error executing the query
256
  """
 
257
  async with self._driver.session(
258
  database=self._DATABASE, default_access_mode="READ"
259
  ) as session:
260
  try:
261
+ query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
262
+ result = await session.run(query, entity_id=node_id)
263
  try:
264
  records = await result.fetch(
265
  2
 
267
 
268
  if len(records) > 1:
269
  logger.warning(
270
+ f"Multiple nodes found with label '{node_id}'. Using first node."
271
  )
272
  if records:
273
  node = records[0]["n"]
274
  node_dict = dict(node)
275
+ # Remove base label from labels list if it exists
276
+ if "labels" in node_dict:
277
+ node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
278
+ logger.debug(f"Neo4j query node {query} return: {node_dict}")
279
  return node_dict
280
  return None
281
  finally:
282
  await result.consume() # Ensure result is fully consumed
283
  except Exception as e:
284
+ logger.error(f"Error getting node for {node_id}: {str(e)}")
285
  raise
286
 
287
  async def node_degree(self, node_id: str) -> int:
 
299
  ValueError: If node_id is invalid
300
  Exception: If there is an error executing the query
301
  """
 
 
302
  async with self._driver.session(
303
  database=self._DATABASE, default_access_mode="READ"
304
  ) as session:
305
  try:
306
+ query = """
307
+ MATCH (n:base {entity_id: $entity_id})
308
  OPTIONAL MATCH (n)-[r]-()
309
+ RETURN COUNT(r) AS degree
310
  """
311
+ result = await session.run(query, entity_id = node_id)
312
  try:
313
+ record = await result.single()
314
 
315
+ if not record:
316
  logger.warning(
317
+ f"No node found with label '{node_id}'"
318
  )
319
  return 0
320
 
321
+ degree = record["degree"]
322
+ logger.debug("Neo4j query node degree for {node_id} return: {degree}")
 
 
 
 
 
 
 
323
  return degree
324
  finally:
325
  await result.consume() # Ensure result is fully consumed
326
  except Exception as e:
327
  logger.error(
328
+ f"Error getting node degree for {node_id}: {str(e)}"
329
  )
330
  raise
331
 
 
339
  Returns:
340
  int: Sum of the degrees of both nodes
341
  """
342
+ src_degree = await self.node_degree(src_id)
343
+ trg_degree = await self.node_degree(tgt_id)
 
 
 
344
 
345
  # Convert None to 0 for addition
346
  src_degree = 0 if src_degree is None else src_degree
 
366
  Exception: If there is an error executing the query
367
  """
368
  try:
 
 
 
369
  async with self._driver.session(
370
  database=self._DATABASE, default_access_mode="READ"
371
  ) as session:
372
+ query = """
373
+ MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
374
  RETURN properties(r) as edge_properties
375
  """
376
+ result = await session.run(query, source_entity_id=source_node_id, target_entity_id=target_node_id)
 
377
  try:
378
  records = await result.fetch(2)
379
 
380
  if len(records) > 1:
381
  logger.warning(
382
+ f"Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge."
383
  )
384
  if records:
385
  try:
 
396
  if key not in edge_result:
397
  edge_result[key] = default_value
398
  logger.warning(
399
+ f"Edge between {source_node_id} and {target_node_id} "
400
  f"missing {key}, using default: {default_value}"
401
  )
402
 
 
406
  return edge_result
407
  except (KeyError, TypeError, ValueError) as e:
408
  logger.error(
409
+ f"Error processing edge properties between {source_node_id} "
410
+ f"and {target_node_id}: {str(e)}"
411
  )
412
  # Return default edge properties on error
413
  return {
 
418
  }
419
 
420
  logger.debug(
421
+ f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
422
  )
423
  # Return default edge properties when no edge found
424
  return {
 
451
  Exception: If there is an error executing the query
452
  """
453
  try:
 
 
 
 
 
 
454
  async with self._driver.session(
455
  database=self._DATABASE, default_access_mode="READ"
456
  ) as session:
457
  try:
458
+ query = """MATCH (n:base {entity_id: $entity_id})
459
+ OPTIONAL MATCH (n)-[r]-(connected:base)
460
+ WHERE connected.entity_id IS NOT NULL
461
+ RETURN n, r, connected"""
462
+ results = await session.run(query, entity_id=source_node_id)
463
 
464
+ edges = []
465
  async for record in results:
466
  source_node = record["n"]
467
  connected_node = record["connected"]
468
 
469
+ # Skip if either node is None
470
+ if not source_node or not connected_node:
471
+ continue
472
+
473
  source_label = (
474
+ source_node.get("entity_id") if source_node.get("entity_id") else None
475
  )
476
  target_label = (
477
+ connected_node.get("entity_id") if connected_node.get("entity_id") else None
 
 
478
  )
479
 
480
  if source_label and target_label:
 
483
  await results.consume() # Ensure results are consumed
484
  return edges
485
  except Exception as e:
486
+ logger.error(f"Error getting edges for node {source_node_id}: {str(e)}")
487
  await results.consume() # Ensure results are consumed even on error
488
  raise
489
  except Exception as e:
 
510
  node_id: The unique identifier for the node (used as label)
511
  node_data: Dictionary of node properties
512
  """
 
513
  properties = node_data
514
+ entity_type = properties["entity_type"]
515
+ entity_id = properties["entity_id"]
516
  if "entity_id" not in properties:
517
  raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
518
 
 
520
  async with self._driver.session(database=self._DATABASE) as session:
521
 
522
  async def execute_upsert(tx: AsyncManagedTransaction):
523
+ query = """
524
+ MERGE (n:base {entity_id: $properties.entity_id})
525
  SET n += $properties
526
+ SET n:`%s`
527
+ """ % entity_type
528
  result = await tx.run(query, properties=properties)
529
  logger.debug(
530
+ f"Upserted node with entity_id '{entity_id}' and properties: {properties}"
531
  )
532
  await result.consume() # Ensure result is fully consumed
533
 
 
548
  )
549
  ),
550
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
551
 
552
  @retry(
553
  stop=stop_after_attempt(3),
 
577
  Raises:
578
  ValueError: If either source or target node does not exist or is not unique
579
  """
 
 
 
 
 
 
 
 
580
  try:
581
+ edge_properties = edge_data
582
  async with self._driver.session(database=self._DATABASE) as session:
583
 
584
  async def execute_upsert(tx: AsyncManagedTransaction):
585
+ query = """
586
+ MATCH (source:base {entity_id: $source_entity_id})
587
  WITH source
588
+ MATCH (target:base {entity_id: $target_entity_id})
589
  MERGE (source)-[r:DIRECTED]-(target)
590
  SET r += $properties
591
  RETURN r, source, target
592
  """
593
  result = await tx.run(
594
  query,
595
+ source_entity_id=source_node_id,
596
+ target_entity_id=target_node_id,
597
  properties=edge_properties,
598
  )
599
  try:
600
+ records = await result.fetch(2)
601
  if records:
602
  logger.debug(
603
+ f"Upserted edge from '{source_node_id}' to '{target_node_id}'"
 
604
  f"with properties: {edge_properties}"
605
  )
606
  finally:
 
638
  Returns:
639
  KnowledgeGraph: Complete connected subgraph for specified node
640
  """
 
641
  result = KnowledgeGraph()
642
  seen_nodes = set()
643
  seen_edges = set()
 
646
  database=self._DATABASE, default_access_mode="READ"
647
  ) as session:
648
  try:
649
+ if node_label == "*":
650
  main_query = """
651
  MATCH (n)
652
  OPTIONAL MATCH (n)-[r]-()
 
671
  # Main query uses partial matching
672
  main_query = """
673
  MATCH (start)
674
+ WHERE
675
  CASE
676
+ WHEN $inclusive THEN start.entity_id CONTAINS $entity_id
677
+ ELSE start.entity_id = $entity_id
678
  END
 
679
  WITH start
680
  CALL apoc.path.subgraphAll(start, {
681
  relationshipFilter: '',
 
709
  main_query,
710
  {
711
  "max_nodes": MAX_GRAPH_NODES,
712
+ "entity_id": node_label,
713
  "inclusive": inclusive,
714
  "max_depth": max_depth,
715
  "min_degree": min_degree,
 
728
  result.nodes.append(
729
  KnowledgeGraphNode(
730
  id=f"{node_id}",
731
+ labels=[label for label in node.labels if label != "base"],
732
  properties=dict(node),
733
  )
734
  )
 
759
 
760
  except neo4jExceptions.ClientError as e:
761
  logger.warning(f"APOC plugin error: {str(e)}")
762
+ if node_label != "*":
763
  logger.warning(
764
  "Neo4j: falling back to basic Cypher recursive search..."
765
  )
 
767
  logger.warning(
768
  "Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
769
  )
770
+ return await self._robust_fallback(node_label, max_depth, min_degree)
771
 
772
  return result
773
 
774
  async def _robust_fallback(
775
+ self, node_label: str, max_depth: int, min_degree: int = 0
776
  ) -> KnowledgeGraph:
777
  """
778
  Fallback implementation when APOC plugin is not available or incompatible.
 
805
  database=self._DATABASE, default_access_mode="READ"
806
  ) as session:
807
  query = """
808
+ MATCH (a:base {entity_id: $entity_id})-[r]-(b)
 
809
  WITH r, b, id(r) as edge_id, id(b) as target_id
810
  RETURN r, b, edge_id, target_id
811
  """
812
+ results = await session.run(query, entity_id=node.id)
813
 
814
  # Get all records and release database connection
815
  records = await results.fetch(
 
837
  edge_id = str(record["edge_id"])
838
  if edge_id not in visited_edges:
839
  b_node = record["b"]
840
+ target_id = b_node.get("entity_id")
841
 
842
+ if target_id: # Only process if target node has entity_id
843
  # Create KnowledgeGraphNode for target
844
  target_node = KnowledgeGraphNode(
845
  id=f"{target_id}",
846
+ labels=[label for label in b_node.labels if label != "base"],
847
+ properties=dict(b_node.properties),
848
  )
849
 
850
  # Create KnowledgeGraphEdge
 
870
  async with self._driver.session(
871
  database=self._DATABASE, default_access_mode="READ"
872
  ) as session:
873
+ query = """
874
+ MATCH (n:base {entity_id: $entity_id})
875
  RETURN id(n) as node_id, n
876
  """
877
+ node_result = await session.run(query, entity_id=node_label)
878
  try:
879
  node_record = await node_result.single()
880
  if not node_record:
 
882
 
883
  # Create initial KnowledgeGraphNode
884
  start_node = KnowledgeGraphNode(
885
+ id=f"{node_record['n'].get('entity_id')}",
886
+ labels=[label for label in node_record["n"].labels if label != "base"],
887
+ properties=dict(node_record["n"].properties),
888
  )
889
  finally:
890
  await node_result.consume() # Ensure results are consumed
 
908
 
909
  # Method 2: Query compatible with older versions
910
  query = """
911
+ MATCH (n)
912
+ WHERE n.entity_id IS NOT NULL
913
+ RETURN DISTINCT n.entity_id AS label
914
+ ORDER BY label
 
915
  """
916
  result = await session.run(query)
917
  labels = []
 
942
  Args:
943
  node_id: The label of the node to delete
944
  """
 
 
945
  async def _do_delete(tx: AsyncManagedTransaction):
946
+ query = """
947
+ MATCH (n:base {entity_id: $entity_id})
948
  DETACH DELETE n
949
  """
950
+ result = await tx.run(query, entity_id=node_id)
951
+ logger.debug(f"Deleted node with label '{node_id}'")
952
  await result.consume() # Ensure result is fully consumed
953
 
954
  try:
 
998
  edges: List of edges to be deleted, each edge is a (source, target) tuple
999
  """
1000
  for source, target in edges:
 
 
 
1001
  async def _do_delete_edge(tx: AsyncManagedTransaction):
1002
+ query = """
1003
+ MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
1004
  DELETE r
1005
  """
1006
+ result = await tx.run(query, source_entity_id=source, target_entity_id=target)
1007
+ logger.debug(f"Deleted edge from '{source}' to '{target}'")
1008
  await result.consume() # Ensure result is fully consumed
1009
 
1010
  try: