gzdaniel committed on
Commit
e97b5af
·
1 Parent(s): 2391b37

Replace tenacity retries with manual Memgraph transaction retries

Browse files

- Implement manual retry logic
- Add exponential backoff with jitter
- Improve error handling for transient errors

Files changed (2) hide show
  1. lightrag/kg/memgraph_impl.py +143 -85
  2. lightrag/operate.py +1 -1
lightrag/kg/memgraph_impl.py CHANGED
@@ -1,4 +1,6 @@
1
  import os
 
 
2
  from dataclasses import dataclass
3
  from typing import final
4
  import configparser
@@ -11,20 +13,11 @@ import pipmaster as pm
11
 
12
  if not pm.is_installed("neo4j"):
13
  pm.install("neo4j")
14
- if not pm.is_installed("tenacity"):
15
- pm.install("tenacity")
16
-
17
  from neo4j import (
18
  AsyncGraphDatabase,
19
  AsyncManagedTransaction,
20
  )
21
- from neo4j.exceptions import TransientError
22
- from tenacity import (
23
- retry,
24
- stop_after_attempt,
25
- wait_exponential,
26
- retry_if_exception_type,
27
- )
28
 
29
  from dotenv import load_dotenv
30
 
@@ -111,25 +104,6 @@ class MemgraphStorage(BaseGraphStorage):
111
  # Memgraph handles persistence automatically
112
  pass
113
 
114
- @retry(
115
- stop=stop_after_attempt(5),
116
- wait=wait_exponential(multiplier=1, min=1, max=10),
117
- retry=retry_if_exception_type(TransientError),
118
- reraise=True,
119
- )
120
- async def _execute_write_with_retry(self, session, operation_func):
121
- """
122
- Execute a write operation with retry logic for Memgraph transient errors.
123
-
124
- Args:
125
- session: Neo4j session
126
- operation_func: Async function that takes a transaction and executes the operation
127
-
128
- Raises:
129
- TransientError: If all retry attempts fail
130
- """
131
- return await session.execute_write(operation_func)
132
-
133
  async def has_node(self, node_id: str) -> bool:
134
  """
135
  Check if a node exists in the graph.
@@ -463,7 +437,7 @@ class MemgraphStorage(BaseGraphStorage):
463
 
464
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
465
  """
466
- Upsert a node in the Memgraph database with retry logic for transient errors.
467
 
468
  Args:
469
  node_id: The unique identifier for the node (used as label)
@@ -480,36 +454,77 @@ class MemgraphStorage(BaseGraphStorage):
480
  "Memgraph: node properties must contain an 'entity_id' field"
481
  )
482
 
483
- try:
484
- async with self._driver.session(database=self._DATABASE) as session:
485
- workspace_label = self._get_workspace_label()
 
 
486
 
487
- async def execute_upsert(tx: AsyncManagedTransaction):
488
- query = f"""
489
- MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
490
- SET n += $properties
491
- SET n:`{entity_type}`
492
- """
493
- result = await tx.run(
494
- query, entity_id=node_id, properties=properties
495
- )
496
- await result.consume() # Ensure result is fully consumed
497
 
498
- await self._execute_write_with_retry(session, execute_upsert)
499
- except TransientError as e:
500
- logger.error(
501
- f"Memgraph transient error during node upsert after retries: {str(e)}"
502
- )
503
- raise
504
- except Exception as e:
505
- logger.error(f"Error during node upsert: {str(e)}")
506
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
 
508
  async def upsert_edge(
509
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
510
  ) -> None:
511
  """
512
- Upsert an edge and its properties between two nodes identified by their labels with retry logic for transient errors.
513
  Ensures both source and target nodes exist and are unique before creating the edge.
514
  Uses entity_id property to uniquely identify nodes.
515
 
@@ -525,40 +540,83 @@ class MemgraphStorage(BaseGraphStorage):
525
  raise RuntimeError(
526
  "Memgraph driver is not initialized. Call 'await initialize()' first."
527
  )
528
- try:
529
- edge_properties = edge_data
530
- async with self._driver.session(database=self._DATABASE) as session:
531
 
532
- async def execute_upsert(tx: AsyncManagedTransaction):
533
- workspace_label = self._get_workspace_label()
534
- query = f"""
535
- MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
536
- WITH source
537
- MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
538
- MERGE (source)-[r:DIRECTED]-(target)
539
- SET r += $properties
540
- RETURN r, source, target
541
- """
542
- result = await tx.run(
543
- query,
544
- source_entity_id=source_node_id,
545
- target_entity_id=target_node_id,
546
- properties=edge_properties,
547
- )
548
- try:
549
- await result.fetch(2)
550
- finally:
551
- await result.consume() # Ensure result is consumed
552
 
553
- await self._execute_write_with_retry(session, execute_upsert)
554
- except TransientError as e:
555
- logger.error(
556
- f"Memgraph transient error during edge upsert after retries: {str(e)}"
557
- )
558
- raise
559
- except Exception as e:
560
- logger.error(f"Error during edge upsert: {str(e)}")
561
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
562
 
563
  async def delete_node(self, node_id: str) -> None:
564
  """Delete a node with the specified label
 
1
  import os
2
+ import asyncio
3
+ import random
4
  from dataclasses import dataclass
5
  from typing import final
6
  import configparser
 
13
 
14
  if not pm.is_installed("neo4j"):
15
  pm.install("neo4j")
 
 
 
16
  from neo4j import (
17
  AsyncGraphDatabase,
18
  AsyncManagedTransaction,
19
  )
20
+ from neo4j.exceptions import TransientError, ResultFailedError
 
 
 
 
 
 
21
 
22
  from dotenv import load_dotenv
23
 
 
104
  # Memgraph handles persistence automatically
105
  pass
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  async def has_node(self, node_id: str) -> bool:
108
  """
109
  Check if a node exists in the graph.
 
437
 
438
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
439
  """
440
+ Upsert a node in the Memgraph database with manual transaction-level retry logic for transient errors.
441
 
442
  Args:
443
  node_id: The unique identifier for the node (used as label)
 
454
  "Memgraph: node properties must contain an 'entity_id' field"
455
  )
456
 
457
+ # Manual transaction-level retry following official Memgraph documentation
458
+ max_retries = 100
459
+ initial_wait_time = 0.2
460
+ backoff_factor = 1.1
461
+ jitter_factor = 0.1
462
 
463
+ for attempt in range(max_retries):
464
+ try:
465
+ logger.debug(
466
+ f"Attempting node upsert, attempt {attempt + 1}/{max_retries}"
467
+ )
468
+ async with self._driver.session(database=self._DATABASE) as session:
469
+ workspace_label = self._get_workspace_label()
 
 
 
470
 
471
+ async def execute_upsert(tx: AsyncManagedTransaction):
472
+ query = f"""
473
+ MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
474
+ SET n += $properties
475
+ SET n:`{entity_type}`
476
+ """
477
+ result = await tx.run(
478
+ query, entity_id=node_id, properties=properties
479
+ )
480
+ await result.consume() # Ensure result is fully consumed
481
+
482
+ await session.execute_write(execute_upsert)
483
+ break # Success - exit retry loop
484
+
485
+ except (TransientError, ResultFailedError) as e:
486
+ # Check if the root cause is a TransientError
487
+ root_cause = e
488
+ while hasattr(root_cause, "__cause__") and root_cause.__cause__:
489
+ root_cause = root_cause.__cause__
490
+
491
+ # Check if this is a transient error that should be retried
492
+ is_transient = (
493
+ isinstance(root_cause, TransientError)
494
+ or isinstance(e, TransientError)
495
+ or "TransientError" in str(e)
496
+ or "Cannot resolve conflicting transactions" in str(e)
497
+ )
498
+
499
+ if is_transient:
500
+ if attempt < max_retries - 1:
501
+ # Calculate wait time with exponential backoff and jitter
502
+ jitter = random.uniform(0, jitter_factor) * initial_wait_time
503
+ wait_time = (
504
+ initial_wait_time * (backoff_factor**attempt) + jitter
505
+ )
506
+ logger.warning(
507
+ f"Node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
508
+ )
509
+ await asyncio.sleep(wait_time)
510
+ else:
511
+ logger.error(
512
+ f"Memgraph transient error during node upsert after {max_retries} retries: {str(e)}"
513
+ )
514
+ raise
515
+ else:
516
+ # Non-transient error, don't retry
517
+ logger.error(f"Non-transient error during node upsert: {str(e)}")
518
+ raise
519
+ except Exception as e:
520
+ logger.error(f"Unexpected error during node upsert: {str(e)}")
521
+ raise
522
 
523
  async def upsert_edge(
524
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
525
  ) -> None:
526
  """
527
+ Upsert an edge and its properties between two nodes identified by their labels with manual transaction-level retry logic for transient errors.
528
  Ensures both source and target nodes exist and are unique before creating the edge.
529
  Uses entity_id property to uniquely identify nodes.
530
 
 
540
  raise RuntimeError(
541
  "Memgraph driver is not initialized. Call 'await initialize()' first."
542
  )
 
 
 
543
 
544
+ edge_properties = edge_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
 
546
+ # Manual transaction-level retry following official Memgraph documentation
547
+ max_retries = 100
548
+ initial_wait_time = 0.2
549
+ backoff_factor = 1.1
550
+ jitter_factor = 0.1
551
+
552
+ for attempt in range(max_retries):
553
+ try:
554
+ logger.debug(
555
+ f"Attempting edge upsert, attempt {attempt + 1}/{max_retries}"
556
+ )
557
+ async with self._driver.session(database=self._DATABASE) as session:
558
+
559
+ async def execute_upsert(tx: AsyncManagedTransaction):
560
+ workspace_label = self._get_workspace_label()
561
+ query = f"""
562
+ MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
563
+ WITH source
564
+ MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
565
+ MERGE (source)-[r:DIRECTED]-(target)
566
+ SET r += $properties
567
+ RETURN r, source, target
568
+ """
569
+ result = await tx.run(
570
+ query,
571
+ source_entity_id=source_node_id,
572
+ target_entity_id=target_node_id,
573
+ properties=edge_properties,
574
+ )
575
+ try:
576
+ await result.fetch(2)
577
+ finally:
578
+ await result.consume() # Ensure result is consumed
579
+
580
+ await session.execute_write(execute_upsert)
581
+ break # Success - exit retry loop
582
+
583
+ except (TransientError, ResultFailedError) as e:
584
+ # Check if the root cause is a TransientError
585
+ root_cause = e
586
+ while hasattr(root_cause, "__cause__") and root_cause.__cause__:
587
+ root_cause = root_cause.__cause__
588
+
589
+ # Check if this is a transient error that should be retried
590
+ is_transient = (
591
+ isinstance(root_cause, TransientError)
592
+ or isinstance(e, TransientError)
593
+ or "TransientError" in str(e)
594
+ or "Cannot resolve conflicting transactions" in str(e)
595
+ )
596
+
597
+ if is_transient:
598
+ if attempt < max_retries - 1:
599
+ # Calculate wait time with exponential backoff and jitter
600
+ jitter = random.uniform(0, jitter_factor) * initial_wait_time
601
+ wait_time = (
602
+ initial_wait_time * (backoff_factor**attempt) + jitter
603
+ )
604
+ logger.warning(
605
+ f"Edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}"
606
+ )
607
+ await asyncio.sleep(wait_time)
608
+ else:
609
+ logger.error(
610
+ f"Memgraph transient error during edge upsert after {max_retries} retries: {str(e)}"
611
+ )
612
+ raise
613
+ else:
614
+ # Non-transient error, don't retry
615
+ logger.error(f"Non-transient error during edge upsert: {str(e)}")
616
+ raise
617
+ except Exception as e:
618
+ logger.error(f"Unexpected error during edge upsert: {str(e)}")
619
+ raise
620
 
621
  async def delete_node(self, node_id: str) -> None:
622
  """Delete a node with the specified label
lightrag/operate.py CHANGED
@@ -1285,7 +1285,7 @@ async def merge_nodes_and_edges(
1285
  namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
1286
  # Sort the edge_key components to ensure consistent lock key generation
1287
  sorted_edge_key = sorted([edge_key[0], edge_key[1]])
1288
- logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}")
1289
  async with get_storage_keyed_lock(
1290
  f"{sorted_edge_key[0]}-{sorted_edge_key[1]}",
1291
  namespace=namespace,
 
1285
  namespace = f"{workspace}:GraphDB" if workspace else "GraphDB"
1286
  # Sort the edge_key components to ensure consistent lock key generation
1287
  sorted_edge_key = sorted([edge_key[0], edge_key[1]])
1288
+ # logger.info(f"Processing edge: {sorted_edge_key[0]} - {sorted_edge_key[1]}")
1289
  async with get_storage_keyed_lock(
1290
  f"{sorted_edge_key[0]}-{sorted_edge_key[1]}",
1291
  namespace=namespace,