yangdx commited on
Commit
a02a230
·
1 Parent(s): 82c47ab

Ensure thread safety in storage update callbacks

Browse files

- Added storage lock in index_done_callback
- Fixed potential race conditions

lightrag/api/utils_api.py CHANGED
@@ -365,7 +365,7 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
365
  args.vector_storage = get_env_value(
366
  "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
367
  )
368
-
369
  # Get MAX_PARALLEL_INSERT from environment
370
  args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
371
 
@@ -397,7 +397,7 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
397
  args.enable_llm_cache_for_extract = get_env_value(
398
  "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
399
  )
400
-
401
  # Inject LLM temperature configuration
402
  args.temperature = get_env_value("TEMPERATURE", 0.5, float)
403
 
 
365
  args.vector_storage = get_env_value(
366
  "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
367
  )
368
+
369
  # Get MAX_PARALLEL_INSERT from environment
370
  args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
371
 
 
397
  args.enable_llm_cache_for_extract = get_env_value(
398
  "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
399
  )
400
+
401
  # Inject LLM temperature configuration
402
  args.temperature = get_env_value("TEMPERATURE", 0.5, float)
403
 
lightrag/kg/faiss_impl.py CHANGED
@@ -343,18 +343,19 @@ class FaissVectorDBStorage(BaseVectorStorage):
343
  self._id_to_meta = {}
344
 
345
  async def index_done_callback(self) -> None:
346
- # Check if storage was updated by another process
347
- if is_multiprocess and self.storage_updated.value:
348
- # Storage was updated by another process, reload data instead of saving
349
- logger.warning(
350
- f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
351
- )
352
- async with self._storage_lock:
353
- self._index = faiss.IndexFlatIP(self._dim)
354
- self._id_to_meta = {}
355
- self._load_faiss_index()
356
- self.storage_updated.value = False
357
- return False # Return error
 
358
 
359
  # Acquire lock and perform persistence
360
  async with self._storage_lock:
 
343
  self._id_to_meta = {}
344
 
345
  async def index_done_callback(self) -> None:
346
+ async with self._storage_lock:
347
+ # Check if storage was updated by another process
348
+ if is_multiprocess and self.storage_updated.value:
349
+ # Storage was updated by another process, reload data instead of saving
350
+ logger.warning(
351
+ f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
352
+ )
353
+ async with self._storage_lock:
354
+ self._index = faiss.IndexFlatIP(self._dim)
355
+ self._id_to_meta = {}
356
+ self._load_faiss_index()
357
+ self.storage_updated.value = False
358
+ return False # Return error
359
 
360
  # Acquire lock and perform persistence
361
  async with self._storage_lock:
lightrag/kg/nano_vector_db_impl.py CHANGED
@@ -206,19 +206,20 @@ class NanoVectorDBStorage(BaseVectorStorage):
206
 
207
  async def index_done_callback(self) -> bool:
208
  """Save data to disk"""
209
- # Check if storage was updated by another process
210
- if is_multiprocess and self.storage_updated.value:
211
- # Storage was updated by another process, reload data instead of saving
212
- logger.warning(
213
- f"Storage for {self.namespace} was updated by another process, reloading..."
214
- )
215
- self._client = NanoVectorDB(
216
- self.embedding_func.embedding_dim,
217
- storage_file=self._client_file_name,
218
- )
219
- # Reset update flag
220
- self.storage_updated.value = False
221
- return False # Return error
 
222
 
223
  # Acquire lock and perform persistence
224
  async with self._storage_lock:
 
206
 
207
  async def index_done_callback(self) -> bool:
208
  """Save data to disk"""
209
+ async with self._storage_lock:
210
+ # Check if storage was updated by another process
211
+ if is_multiprocess and self.storage_updated.value:
212
+ # Storage was updated by another process, reload data instead of saving
213
+ logger.warning(
214
+ f"Storage for {self.namespace} was updated by another process, reloading..."
215
+ )
216
+ self._client = NanoVectorDB(
217
+ self.embedding_func.embedding_dim,
218
+ storage_file=self._client_file_name,
219
+ )
220
+ # Reset update flag
221
+ self.storage_updated.value = False
222
+ return False # Return error
223
 
224
  # Acquire lock and perform persistence
225
  async with self._storage_lock:
lightrag/kg/networkx_impl.py CHANGED
@@ -401,18 +401,19 @@ class NetworkXStorage(BaseGraphStorage):
401
 
402
  async def index_done_callback(self) -> bool:
403
  """Save data to disk"""
404
- # Check if storage was updated by another process
405
- if is_multiprocess and self.storage_updated.value:
406
- # Storage was updated by another process, reload data instead of saving
407
- logger.warning(
408
- f"Graph for {self.namespace} was updated by another process, reloading..."
409
- )
410
- self._graph = (
411
- NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
412
- )
413
- # Reset update flag
414
- self.storage_updated.value = False
415
- return False # Return error
 
416
 
417
  # Acquire lock and perform persistence
418
  async with self._storage_lock:
 
401
 
402
  async def index_done_callback(self) -> bool:
403
  """Save data to disk"""
404
+ async with self._storage_lock:
405
+ # Check if storage was updated by another process
406
+ if is_multiprocess and self.storage_updated.value:
407
+ # Storage was updated by another process, reload data instead of saving
408
+ logger.warning(
409
+ f"Graph for {self.namespace} was updated by another process, reloading..."
410
+ )
411
+ self._graph = (
412
+ NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
413
+ )
414
+ # Reset update flag
415
+ self.storage_updated.value = False
416
+ return False # Return error
417
 
418
  # Acquire lock and perform persistence
419
  async with self._storage_lock: