yangdx commited on
Commit
de70bba
·
1 Parent(s): e5f9f74

Refactor shared storage module to improve async handling and naming consistency

Browse files

• Add async support for get_namespace_data
• Rename get_update_flags to get_update_flag
• Rename set_update_flag to set_all_update_flags
• Update docstrings for clarity
• Fix typos in log messages

lightrag/api/routers/document_routes.py CHANGED
@@ -667,7 +667,7 @@ def create_document_routes(
667
  try:
668
  from lightrag.kg.shared_storage import get_namespace_data
669
 
670
- pipeline_status = get_namespace_data("pipeline_status")
671
 
672
  # Convert to regular dict if it's a Manager.dict
673
  status_dict = dict(pipeline_status)
 
667
  try:
668
  from lightrag.kg.shared_storage import get_namespace_data
669
 
670
+ pipeline_status = await get_namespace_data("pipeline_status")
671
 
672
  # Convert to regular dict if it's a Manager.dict
673
  status_dict = dict(pipeline_status)
lightrag/kg/shared_storage.py CHANGED
@@ -18,13 +18,12 @@ def direct_log(message, level="INFO"):
18
  T = TypeVar('T')
19
 
20
  class UnifiedLock(Generic[T]):
21
- """统一的锁包装类,提供同步和异步的统一接口"""
22
  def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
23
  self._lock = lock
24
  self._is_async = is_async
25
 
26
  async def __aenter__(self) -> 'UnifiedLock[T]':
27
- """异步上下文管理器入口"""
28
  if self._is_async:
29
  await self._lock.acquire()
30
  else:
@@ -32,21 +31,20 @@ class UnifiedLock(Generic[T]):
32
  return self
33
 
34
  async def __aexit__(self, exc_type, exc_val, exc_tb):
35
- """异步上下文管理器出口"""
36
  if self._is_async:
37
  self._lock.release()
38
  else:
39
  self._lock.release()
40
 
41
  def __enter__(self) -> 'UnifiedLock[T]':
42
- """同步上下文管理器入口(仅用于向后兼容)"""
43
  if self._is_async:
44
  raise RuntimeError("Use 'async with' for asyncio.Lock")
45
  self._lock.acquire()
46
  return self
47
 
48
  def __exit__(self, exc_type, exc_val, exc_tb):
49
- """同步上下文管理器出口(仅用于向后兼容)"""
50
  if self._is_async:
51
  raise RuntimeError("Use 'async with' for asyncio.Lock")
52
  self._lock.release()
@@ -153,10 +151,10 @@ async def initialize_pipeline_namespace():
153
  direct_log(f"Process {os.getpid()} Pipeline namespace initialized")
154
 
155
 
156
- async def get_update_flags(namespace: str):
157
  """
158
- Create a updated flags of a specific namespace.
159
- Caller must store the dict object locally and use it to determine when to update the storage.
160
  """
161
  global _update_flags
162
  if _update_flags is None:
@@ -178,8 +176,8 @@ async def get_update_flags(namespace: str):
178
  _update_flags[namespace].append(new_update_flag)
179
  return new_update_flag
180
 
181
- async def set_update_flag(namespace: str):
182
- """Set all update flag of namespace to indicate storage needs updating"""
183
  global _update_flags
184
  if _update_flags is None:
185
  raise ValueError("Try to create namespace before Shared-Data is initialized")
@@ -212,7 +210,7 @@ def try_initialize_namespace(namespace: str) -> bool:
212
  )
213
  return True
214
  direct_log(
215
- f"Process {os.getpid()} storage namespace already to initialized: [{namespace}]"
216
  )
217
  return False
218
 
 
18
  T = TypeVar('T')
19
 
20
  class UnifiedLock(Generic[T]):
21
+ """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
22
  def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
23
  self._lock = lock
24
  self._is_async = is_async
25
 
26
  async def __aenter__(self) -> 'UnifiedLock[T]':
 
27
  if self._is_async:
28
  await self._lock.acquire()
29
  else:
 
31
  return self
32
 
33
  async def __aexit__(self, exc_type, exc_val, exc_tb):
 
34
  if self._is_async:
35
  self._lock.release()
36
  else:
37
  self._lock.release()
38
 
39
  def __enter__(self) -> 'UnifiedLock[T]':
40
+ """For backward compatibility"""
41
  if self._is_async:
42
  raise RuntimeError("Use 'async with' for asyncio.Lock")
43
  self._lock.acquire()
44
  return self
45
 
46
  def __exit__(self, exc_type, exc_val, exc_tb):
47
+ """For backward compatibility"""
48
  if self._is_async:
49
  raise RuntimeError("Use 'async with' for asyncio.Lock")
50
  self._lock.release()
 
151
  direct_log(f"Process {os.getpid()} Pipeline namespace initialized")
152
 
153
 
154
+ async def get_update_flag(namespace: str):
155
  """
156
+ Create a namespace's update flag for a workers.
157
+ Returen the update flag to caller for referencing or reset.
158
  """
159
  global _update_flags
160
  if _update_flags is None:
 
176
  _update_flags[namespace].append(new_update_flag)
177
  return new_update_flag
178
 
179
+ async def set_all_update_flags(namespace: str):
180
+ """Set all update flag of namespace indicating all workers need to reload data from files"""
181
  global _update_flags
182
  if _update_flags is None:
183
  raise ValueError("Try to create namespace before Shared-Data is initialized")
 
210
  )
211
  return True
212
  direct_log(
213
+ f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
214
  )
215
  return False
216
 
lightrag/operate.py CHANGED
@@ -338,7 +338,7 @@ async def extract_entities(
338
  ) -> None:
339
  from lightrag.kg.shared_storage import get_namespace_data
340
 
341
- pipeline_status = get_namespace_data("pipeline_status")
342
  use_llm_func: callable = global_config["llm_model_func"]
343
  entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
344
  enable_llm_cache_for_entity_extract: bool = global_config[
 
338
  ) -> None:
339
  from lightrag.kg.shared_storage import get_namespace_data
340
 
341
+ pipeline_status = await get_namespace_data("pipeline_status")
342
  use_llm_func: callable = global_config["llm_model_func"]
343
  entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
344
  enable_llm_cache_for_entity_extract: bool = global_config[