yangdx committed on
Commit
f27d730
·
1 Parent(s): 5fb46b4

Add async support and update flag mechanism for shared storage

Browse files

• Use asyncio.Lock instead of thread lock for single process mode
• Add storage update notification system

Files changed (1) hide show
  1. lightrag/kg/shared_storage.py +73 -13
lightrag/kg/shared_storage.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import sys
 
3
  from multiprocessing.synchronize import Lock as ProcessLock
4
- from threading import Lock as ThreadLock
5
  from multiprocessing import Manager
6
  from typing import Any, Dict, Optional, Union
7
 
@@ -15,16 +15,18 @@ def direct_log(message, level="INFO"):
15
  print(f"{level}: {message}", file=sys.stderr, flush=True)
16
 
17
 
18
- LockType = Union[ProcessLock, ThreadLock]
19
 
 
 
20
  _manager = None
21
  _initialized = None
22
- is_multiprocess = None
23
  _global_lock: Optional[LockType] = None
24
 
25
  # shared data for storage across processes
26
  _shared_dicts: Optional[Dict[str, Any]] = None
27
  _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
 
28
 
29
 
30
  def initialize_share_data(workers: int = 1):
@@ -47,12 +49,14 @@ def initialize_share_data(workers: int = 1):
47
  """
48
  global \
49
  _manager, \
 
50
  is_multiprocess, \
51
  is_multiprocess, \
52
  _global_lock, \
53
  _shared_dicts, \
54
  _init_flags, \
55
- _initialized
 
56
 
57
  # Check if already initialized
58
  if _initialized:
@@ -62,20 +66,23 @@ def initialize_share_data(workers: int = 1):
62
  return
63
 
64
  _manager = Manager()
 
65
 
66
  if workers > 1:
67
  is_multiprocess = True
68
  _global_lock = _manager.Lock()
69
  _shared_dicts = _manager.dict()
70
  _init_flags = _manager.dict()
 
71
  direct_log(
72
  f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})"
73
  )
74
  else:
75
  is_multiprocess = False
76
- _global_lock = ThreadLock()
77
  _shared_dicts = {}
78
  _init_flags = {}
 
79
  direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
80
 
81
  # Mark as initialized
@@ -86,7 +93,6 @@ def initialize_share_data(workers: int = 1):
86
 
87
  # Create a shared list object for history_messages
88
  history_messages = _manager.list() if is_multiprocess else []
89
-
90
  pipeline_namespace.update(
91
  {
92
  "busy": False, # Control concurrent processes
@@ -102,6 +108,58 @@ def initialize_share_data(workers: int = 1):
102
  )
103
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def try_initialize_namespace(namespace: str) -> bool:
106
  """
107
  Try to initialize a namespace. Returns True if the current process gets initialization permission.
@@ -129,7 +187,7 @@ def get_storage_lock() -> LockType:
129
  return _global_lock
130
 
131
 
132
- def get_namespace_data(namespace: str) -> Dict[str, Any]:
133
  """get storage space for specific storage type(namespace)"""
134
  if _shared_dicts is None:
135
  direct_log(
@@ -138,12 +196,14 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]:
138
  )
139
  raise ValueError("Shared dictionaries not initialized")
140
 
141
- lock = get_storage_lock()
142
- with lock:
143
- if namespace not in _shared_dicts:
144
- if is_multiprocess and _manager is not None:
145
- _shared_dicts[namespace] = _manager.dict()
146
- else:
 
 
147
  _shared_dicts[namespace] = {}
148
 
149
  return _shared_dicts[namespace]
 
1
  import os
2
  import sys
3
+ import asyncio
4
  from multiprocessing.synchronize import Lock as ProcessLock
 
5
  from multiprocessing import Manager
6
  from typing import Any, Dict, Optional, Union
7
 
 
15
  print(f"{level}: {message}", file=sys.stderr, flush=True)
16
 
17
 
18
+ LockType = Union[ProcessLock, asyncio.Lock]
19
 
20
+ is_multiprocess = None
21
+ _workers = None
22
  _manager = None
23
  _initialized = None
 
24
  _global_lock: Optional[LockType] = None
25
 
26
  # shared data for storage across processes
27
  _shared_dicts: Optional[Dict[str, Any]] = None
28
  _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
29
+ _update_flags: Optional[Dict[str, bool]] = None # namespace -> updated
30
 
31
 
32
  def initialize_share_data(workers: int = 1):
 
49
  """
50
  global \
51
  _manager, \
52
+ _workers, \
53
  is_multiprocess, \
54
  is_multiprocess, \
55
  _global_lock, \
56
  _shared_dicts, \
57
  _init_flags, \
58
+ _initialized, \
59
+ _update_flags
60
 
61
  # Check if already initialized
62
  if _initialized:
 
66
  return
67
 
68
  _manager = Manager()
69
+ _workers = workers
70
 
71
  if workers > 1:
72
  is_multiprocess = True
73
  _global_lock = _manager.Lock()
74
  _shared_dicts = _manager.dict()
75
  _init_flags = _manager.dict()
76
+ _update_flags = _manager.dict()
77
  direct_log(
78
  f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})"
79
  )
80
  else:
81
  is_multiprocess = False
82
+ _global_lock = asyncio.Lock()
83
  _shared_dicts = {}
84
  _init_flags = {}
85
+ _update_flags = {}
86
  direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
87
 
88
  # Mark as initialized
 
93
 
94
  # Create a shared list object for history_messages
95
  history_messages = _manager.list() if is_multiprocess else []
 
96
  pipeline_namespace.update(
97
  {
98
  "busy": False, # Control concurrent processes
 
108
  )
109
 
110
 
111
class _UpdateFlag:
    """Mutable boolean holder used as the update flag in single-process mode.

    A bare ``False`` cannot serve as a shared flag: bools are immutable, so
    replacing the entry stored in the registry list never changes the value
    the caller is holding. This object is shared by reference instead, and
    mirrors the ``.value`` attribute of the manager ``Value`` used in
    multiprocess mode. ``__bool__`` is defined so plain truth-testing of the
    flag (``if flag:``) reflects the current state as well.
    """

    __slots__ = ("value",)

    def __init__(self, value: bool = False):
        self.value = value

    def __bool__(self) -> bool:
        return bool(self.value)

    def __repr__(self) -> str:
        return f"_UpdateFlag({self.value!r})"


async def get_update_flags(namespace: str):
    """
    Create and register a new update flag for a specific namespace.

    The caller must keep the returned flag object and read its ``.value``
    to determine when its local copy of the storage needs to be reloaded.
    Every call registers one additional flag for the namespace;
    ``set_update_flag`` raises all flags registered for that namespace.

    Args:
        namespace: Name of the storage namespace the flag belongs to.

    Returns:
        In multiprocess mode, a manager ``Value('b', False)`` proxy (or
        ``None`` if no manager is available). In single-process mode, a
        mutable flag object whose ``.value`` starts as ``False``.

    Raises:
        ValueError: If the shared data has not been initialized yet.
    """
    if _update_flags is None:
        raise ValueError("Try to create namespace before Shared-Data is initialized")

    if is_multiprocess:
        # In multiprocess mode the global lock is a manager lock, which is a
        # synchronous context manager.
        with _global_lock:
            if namespace not in _update_flags:
                if _manager is not None:
                    _update_flags[namespace] = _manager.list()
                    direct_log(
                        f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]"
                    )

            if _manager is not None:
                new_update_flag = _manager.Value('b', False)
                _update_flags[namespace].append(new_update_flag)
                return new_update_flag
        return None
    else:
        async with _global_lock:
            if namespace not in _update_flags:
                _update_flags[namespace] = []
                direct_log(
                    f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]"
                )

            # Hand out a mutable holder so that set_update_flag() can flip
            # the very object the caller keeps a reference to.
            new_update_flag = _UpdateFlag(False)
            _update_flags[namespace].append(new_update_flag)
            return new_update_flag


async def set_update_flag(namespace: str):
    """Set all update flags of a namespace to indicate storage needs updating.

    Args:
        namespace: Name of the storage namespace whose flags are raised.

    Raises:
        ValueError: If the shared data has not been initialized, or if no
            flags were ever registered for ``namespace``.
    """
    if _update_flags is None:
        raise ValueError("Try to create namespace before Shared-Data is initialized")

    if is_multiprocess:
        with _global_lock:
            if namespace not in _update_flags:
                raise ValueError(f"Namespace {namespace} not found in update flags")
            # Update flags for multiprocess mode. Fetch the list proxy once
            # instead of looking it up again on every iteration.
            flags = _update_flags[namespace]
            for i in range(len(flags)):
                flags[i].value = True
    else:
        async with _global_lock:
            if namespace not in _update_flags:
                raise ValueError(f"Namespace {namespace} not found in update flags")
            # Update flags for single process mode: mutate each flag object
            # in place so holders of the flag observe the change.
            for flag in _update_flags[namespace]:
                flag.value = True
161
+
162
+
163
  def try_initialize_namespace(namespace: str) -> bool:
164
  """
165
  Try to initialize a namespace. Returns True if the current process gets initialization permission.
 
187
  return _global_lock
188
 
189
 
190
+ async def get_namespace_data(namespace: str) -> Dict[str, Any]:
191
  """get storage space for specific storage type(namespace)"""
192
  if _shared_dicts is None:
193
  direct_log(
 
196
  )
197
  raise ValueError("Shared dictionaries not initialized")
198
 
199
+ if is_multiprocess:
200
+ with _global_lock:
201
+ if namespace not in _shared_dicts:
202
+ if _manager is not None:
203
+ _shared_dicts[namespace] = _manager.dict()
204
+ else:
205
+ async with _global_lock:
206
+ if namespace not in _shared_dicts:
207
  _shared_dicts[namespace] = {}
208
 
209
  return _shared_dicts[namespace]