yangdx committed on
Commit
c71b14f
·
1 Parent(s): 9b765c7

Fix linting

Browse files
lightrag/kg/neo4j_impl.py CHANGED
@@ -181,10 +181,10 @@ class Neo4JStorage(BaseGraphStorage):
181
 
182
  Args:
183
  label: The label to validate
184
-
185
  Returns:
186
  str: The cleaned label
187
-
188
  Raises:
189
  ValueError: If label is empty after cleaning
190
  """
@@ -283,7 +283,9 @@ class Neo4JStorage(BaseGraphStorage):
283
  query = f"MATCH (n:`{entity_name_label}` {{entity_id: $entity_id}}) RETURN n"
284
  result = await session.run(query, entity_id=entity_name_label)
285
  try:
286
- records = await result.fetch(2) # Get 2 records for duplication check
 
 
287
 
288
  if len(records) > 1:
289
  logger.warning(
@@ -552,6 +554,7 @@ class Neo4JStorage(BaseGraphStorage):
552
 
553
  try:
554
  async with self._driver.session(database=self._DATABASE) as session:
 
555
  async def execute_upsert(tx: AsyncManagedTransaction):
556
  query = f"""
557
  MERGE (n:`{label}` {{entity_id: $properties.entity_id}})
@@ -562,7 +565,7 @@ class Neo4JStorage(BaseGraphStorage):
562
  f"Upserted node with label '{label}' and properties: {properties}"
563
  )
564
  await result.consume() # Ensure result is fully consumed
565
-
566
  await session.execute_write(execute_upsert)
567
  except Exception as e:
568
  logger.error(f"Error during upsert: {str(e)}")
@@ -602,18 +605,26 @@ class Neo4JStorage(BaseGraphStorage):
602
  """
603
  result = await session.run(query)
604
  try:
605
- records = await result.fetch(2) # We only need to know if there are 0, 1, or >1 nodes
606
-
 
 
607
  if not records or records[0]["node_count"] == 0:
608
- raise ValueError(f"Neo4j: node with label '{node_label}' does not exist")
609
-
 
 
610
  if records[0]["node_count"] > 1:
611
- raise ValueError(f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node")
612
-
 
 
613
  node = records[0]["n"]
614
  if "entity_id" not in node:
615
- raise ValueError(f"Neo4j: node with label '{node_label}' does not have an entity_id property")
616
-
 
 
617
  return node["entity_id"]
618
  finally:
619
  await result.consume() # Ensure result is fully consumed
@@ -656,6 +667,7 @@ class Neo4JStorage(BaseGraphStorage):
656
 
657
  try:
658
  async with self._driver.session(database=self._DATABASE) as session:
 
659
  async def execute_upsert(tx: AsyncManagedTransaction):
660
  query = f"""
661
  MATCH (source:`{source_label}` {{entity_id: $source_entity_id}})
@@ -666,10 +678,10 @@ class Neo4JStorage(BaseGraphStorage):
666
  RETURN r, source, target
667
  """
668
  result = await tx.run(
669
- query,
670
  source_entity_id=source_entity_id,
671
  target_entity_id=target_entity_id,
672
- properties=edge_properties
673
  )
674
  try:
675
  records = await result.fetch(100)
@@ -681,7 +693,7 @@ class Neo4JStorage(BaseGraphStorage):
681
  )
682
  finally:
683
  await result.consume() # Ensure result is consumed
684
-
685
  await session.execute_write(execute_upsert)
686
  except Exception as e:
687
  logger.error(f"Error during edge upsert: {str(e)}")
@@ -891,7 +903,9 @@ class Neo4JStorage(BaseGraphStorage):
891
  results = await session.run(query, {"node_id": node.id})
892
 
893
  # Get all records and release database connection
894
- records = await results.fetch(1000) # Max neighbour nodes we can handled
 
 
895
  await results.consume() # Ensure results are consumed
896
 
897
  # Nodes not connected to start node need to check degree
 
181
 
182
  Args:
183
  label: The label to validate
184
+
185
  Returns:
186
  str: The cleaned label
187
+
188
  Raises:
189
  ValueError: If label is empty after cleaning
190
  """
 
283
  query = f"MATCH (n:`{entity_name_label}` {{entity_id: $entity_id}}) RETURN n"
284
  result = await session.run(query, entity_id=entity_name_label)
285
  try:
286
+ records = await result.fetch(
287
+ 2
288
+ ) # Get 2 records for duplication check
289
 
290
  if len(records) > 1:
291
  logger.warning(
 
554
 
555
  try:
556
  async with self._driver.session(database=self._DATABASE) as session:
557
+
558
  async def execute_upsert(tx: AsyncManagedTransaction):
559
  query = f"""
560
  MERGE (n:`{label}` {{entity_id: $properties.entity_id}})
 
565
  f"Upserted node with label '{label}' and properties: {properties}"
566
  )
567
  await result.consume() # Ensure result is fully consumed
568
+
569
  await session.execute_write(execute_upsert)
570
  except Exception as e:
571
  logger.error(f"Error during upsert: {str(e)}")
 
605
  """
606
  result = await session.run(query)
607
  try:
608
+ records = await result.fetch(
609
+ 2
610
+ ) # We only need to know if there are 0, 1, or >1 nodes
611
+
612
  if not records or records[0]["node_count"] == 0:
613
+ raise ValueError(
614
+ f"Neo4j: node with label '{node_label}' does not exist"
615
+ )
616
+
617
  if records[0]["node_count"] > 1:
618
+ raise ValueError(
619
+ f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node"
620
+ )
621
+
622
  node = records[0]["n"]
623
  if "entity_id" not in node:
624
+ raise ValueError(
625
+ f"Neo4j: node with label '{node_label}' does not have an entity_id property"
626
+ )
627
+
628
  return node["entity_id"]
629
  finally:
630
  await result.consume() # Ensure result is fully consumed
 
667
 
668
  try:
669
  async with self._driver.session(database=self._DATABASE) as session:
670
+
671
  async def execute_upsert(tx: AsyncManagedTransaction):
672
  query = f"""
673
  MATCH (source:`{source_label}` {{entity_id: $source_entity_id}})
 
678
  RETURN r, source, target
679
  """
680
  result = await tx.run(
681
+ query,
682
  source_entity_id=source_entity_id,
683
  target_entity_id=target_entity_id,
684
+ properties=edge_properties,
685
  )
686
  try:
687
  records = await result.fetch(100)
 
693
  )
694
  finally:
695
  await result.consume() # Ensure result is consumed
696
+
697
  await session.execute_write(execute_upsert)
698
  except Exception as e:
699
  logger.error(f"Error during edge upsert: {str(e)}")
 
903
  results = await session.run(query, {"node_id": node.id})
904
 
905
  # Get all records and release database connection
906
+ records = await results.fetch(
907
+ 1000
908
+ ) # Max neighbour nodes we can handled
909
  await results.consume() # Ensure results are consumed
910
 
911
  # Nodes not connected to start node need to check degree
lightrag/kg/shared_storage.py CHANGED
@@ -11,7 +11,7 @@ def direct_log(message, level="INFO", enable_output: bool = True):
11
  """
12
  Log a message directly to stderr to ensure visibility in all processes,
13
  including the Gunicorn master process.
14
-
15
  Args:
16
  message: The message to log
17
  level: Log level (default: "INFO")
@@ -44,7 +44,13 @@ _graph_db_lock: Optional[LockType] = None
44
  class UnifiedLock(Generic[T]):
45
  """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
46
 
47
- def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool, name: str = "unnamed", enable_logging: bool = True):
 
 
 
 
 
 
48
  self._lock = lock
49
  self._is_async = is_async
50
  self._pid = os.getpid() # for debug only
@@ -53,27 +59,47 @@ class UnifiedLock(Generic[T]):
53
 
54
  async def __aenter__(self) -> "UnifiedLock[T]":
55
  try:
56
- direct_log(f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})", enable_output=self._enable_logging)
 
 
 
57
  if self._is_async:
58
  await self._lock.acquire()
59
  else:
60
  self._lock.acquire()
61
- direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})", enable_output=self._enable_logging)
 
 
 
62
  return self
63
  except Exception as e:
64
- direct_log(f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}", level="ERROR", enable_output=self._enable_logging)
 
 
 
 
65
  raise
66
 
67
  async def __aexit__(self, exc_type, exc_val, exc_tb):
68
  try:
69
- direct_log(f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})", enable_output=self._enable_logging)
 
 
 
70
  if self._is_async:
71
  self._lock.release()
72
  else:
73
  self._lock.release()
74
- direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})", enable_output=self._enable_logging)
 
 
 
75
  except Exception as e:
76
- direct_log(f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}", level="ERROR", enable_output=self._enable_logging)
 
 
 
 
77
  raise
78
 
79
  def __enter__(self) -> "UnifiedLock[T]":
@@ -81,12 +107,22 @@ class UnifiedLock(Generic[T]):
81
  try:
82
  if self._is_async:
83
  raise RuntimeError("Use 'async with' for shared_storage lock")
84
- direct_log(f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)", enable_output=self._enable_logging)
 
 
 
85
  self._lock.acquire()
86
- direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)", enable_output=self._enable_logging)
 
 
 
87
  return self
88
  except Exception as e:
89
- direct_log(f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}", level="ERROR", enable_output=self._enable_logging)
 
 
 
 
90
  raise
91
 
92
  def __exit__(self, exc_type, exc_val, exc_tb):
@@ -94,32 +130,62 @@ class UnifiedLock(Generic[T]):
94
  try:
95
  if self._is_async:
96
  raise RuntimeError("Use 'async with' for shared_storage lock")
97
- direct_log(f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)", enable_output=self._enable_logging)
 
 
 
98
  self._lock.release()
99
- direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)", enable_output=self._enable_logging)
 
 
 
100
  except Exception as e:
101
- direct_log(f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}", level="ERROR", enable_output=self._enable_logging)
 
 
 
 
102
  raise
103
 
104
 
105
  def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
106
  """return unified storage lock for data consistency"""
107
- return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess, name="internal_lock", enable_logging=enable_logging)
 
 
 
 
 
108
 
109
 
110
  def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
111
  """return unified storage lock for data consistency"""
112
- return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess, name="storage_lock", enable_logging=enable_logging)
 
 
 
 
 
113
 
114
 
115
  def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
116
  """return unified storage lock for data consistency"""
117
- return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess, name="pipeline_status_lock", enable_logging=enable_logging)
 
 
 
 
 
118
 
119
 
120
  def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
121
  """return unified graph database lock for ensuring atomic operations"""
122
- return UnifiedLock(lock=_graph_db_lock, is_async=not is_multiprocess, name="graph_db_lock", enable_logging=enable_logging)
 
 
 
 
 
123
 
124
 
125
  def initialize_share_data(workers: int = 1):
 
11
  """
12
  Log a message directly to stderr to ensure visibility in all processes,
13
  including the Gunicorn master process.
14
+
15
  Args:
16
  message: The message to log
17
  level: Log level (default: "INFO")
 
44
  class UnifiedLock(Generic[T]):
45
  """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
46
 
47
+ def __init__(
48
+ self,
49
+ lock: Union[ProcessLock, asyncio.Lock],
50
+ is_async: bool,
51
+ name: str = "unnamed",
52
+ enable_logging: bool = True,
53
+ ):
54
  self._lock = lock
55
  self._is_async = is_async
56
  self._pid = os.getpid() # for debug only
 
59
 
60
  async def __aenter__(self) -> "UnifiedLock[T]":
61
  try:
62
+ direct_log(
63
+ f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
64
+ enable_output=self._enable_logging,
65
+ )
66
  if self._is_async:
67
  await self._lock.acquire()
68
  else:
69
  self._lock.acquire()
70
+ direct_log(
71
+ f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})",
72
+ enable_output=self._enable_logging,
73
+ )
74
  return self
75
  except Exception as e:
76
+ direct_log(
77
+ f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}",
78
+ level="ERROR",
79
+ enable_output=self._enable_logging,
80
+ )
81
  raise
82
 
83
  async def __aexit__(self, exc_type, exc_val, exc_tb):
84
  try:
85
+ direct_log(
86
+ f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})",
87
+ enable_output=self._enable_logging,
88
+ )
89
  if self._is_async:
90
  self._lock.release()
91
  else:
92
  self._lock.release()
93
+ direct_log(
94
+ f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})",
95
+ enable_output=self._enable_logging,
96
+ )
97
  except Exception as e:
98
+ direct_log(
99
+ f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}",
100
+ level="ERROR",
101
+ enable_output=self._enable_logging,
102
+ )
103
  raise
104
 
105
  def __enter__(self) -> "UnifiedLock[T]":
 
107
  try:
108
  if self._is_async:
109
  raise RuntimeError("Use 'async with' for shared_storage lock")
110
+ direct_log(
111
+ f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)",
112
+ enable_output=self._enable_logging,
113
+ )
114
  self._lock.acquire()
115
+ direct_log(
116
+ f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)",
117
+ enable_output=self._enable_logging,
118
+ )
119
  return self
120
  except Exception as e:
121
+ direct_log(
122
+ f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}",
123
+ level="ERROR",
124
+ enable_output=self._enable_logging,
125
+ )
126
  raise
127
 
128
  def __exit__(self, exc_type, exc_val, exc_tb):
 
130
  try:
131
  if self._is_async:
132
  raise RuntimeError("Use 'async with' for shared_storage lock")
133
+ direct_log(
134
+ f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)",
135
+ enable_output=self._enable_logging,
136
+ )
137
  self._lock.release()
138
+ direct_log(
139
+ f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)",
140
+ enable_output=self._enable_logging,
141
+ )
142
  except Exception as e:
143
+ direct_log(
144
+ f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}",
145
+ level="ERROR",
146
+ enable_output=self._enable_logging,
147
+ )
148
  raise
149
 
150
 
151
  def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
152
  """return unified storage lock for data consistency"""
153
+ return UnifiedLock(
154
+ lock=_internal_lock,
155
+ is_async=not is_multiprocess,
156
+ name="internal_lock",
157
+ enable_logging=enable_logging,
158
+ )
159
 
160
 
161
  def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
162
  """return unified storage lock for data consistency"""
163
+ return UnifiedLock(
164
+ lock=_storage_lock,
165
+ is_async=not is_multiprocess,
166
+ name="storage_lock",
167
+ enable_logging=enable_logging,
168
+ )
169
 
170
 
171
  def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
172
  """return unified storage lock for data consistency"""
173
+ return UnifiedLock(
174
+ lock=_pipeline_status_lock,
175
+ is_async=not is_multiprocess,
176
+ name="pipeline_status_lock",
177
+ enable_logging=enable_logging,
178
+ )
179
 
180
 
181
  def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
182
  """return unified graph database lock for ensuring atomic operations"""
183
+ return UnifiedLock(
184
+ lock=_graph_db_lock,
185
+ is_async=not is_multiprocess,
186
+ name="graph_db_lock",
187
+ enable_logging=enable_logging,
188
+ )
189
 
190
 
191
  def initialize_share_data(workers: int = 1):
lightrag/operate.py CHANGED
@@ -522,8 +522,9 @@ async def extract_entities(
522
  maybe_edges[tuple(sorted(k))].extend(v)
523
 
524
  from .kg.shared_storage import get_graph_db_lock
525
- graph_db_lock = get_graph_db_lock(enable_logging = True)
526
-
 
527
  # Ensure that nodes and edges are merged and upserted atomically
528
  async with graph_db_lock:
529
  all_entities_data = await asyncio.gather(
@@ -535,7 +536,9 @@ async def extract_entities(
535
 
536
  all_relationships_data = await asyncio.gather(
537
  *[
538
- _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
 
 
539
  for k, v in maybe_edges.items()
540
  ]
541
  )
 
522
  maybe_edges[tuple(sorted(k))].extend(v)
523
 
524
  from .kg.shared_storage import get_graph_db_lock
525
+
526
+ graph_db_lock = get_graph_db_lock(enable_logging=True)
527
+
528
  # Ensure that nodes and edges are merged and upserted atomically
529
  async with graph_db_lock:
530
  all_entities_data = await asyncio.gather(
 
536
 
537
  all_relationships_data = await asyncio.gather(
538
  *[
539
+ _merge_edges_then_upsert(
540
+ k[0], k[1], v, knowledge_graph_inst, global_config
541
+ )
542
  for k, v in maybe_edges.items()
543
  ]
544
  )