yangdx committed on
Commit 8a706cb · 1 Parent(s): 911c794

revert vector and graph storage to local data (single process)

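The revert removes the shared-storage indirection (get_namespace_object, get_namespace_data, is_multiprocess) from the vector and graph backends and returns to plain instance attributes guarded by a threading.Lock. A minimal sketch of the resulting single-process pattern; the class and attribute names here are illustrative, not the actual LightRAG classes:

from threading import Lock as ThreadLock

class LocalStorage:
    def __init__(self):
        # Plain local object instead of a manager-backed namespace object
        self._data = {}
        # A thread lock suffices because a single process owns the data
        self._storage_lock = ThreadLock()

    def upsert(self, key, value):
        with self._storage_lock:
            self._data[key] = value

    def get(self, key):
        with self._storage_lock:
            return self._data.get(key)
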
lightrag/kg/faiss_impl.py CHANGED
@@ -10,19 +10,12 @@ import pipmaster as pm
10
 
11
  from lightrag.utils import logger, compute_mdhash_id
12
  from lightrag.base import BaseVectorStorage
13
- from .shared_storage import (
14
- get_namespace_data,
15
- get_storage_lock,
16
- get_namespace_object,
17
- is_multiprocess,
18
- try_initialize_namespace,
19
- )
20
 
21
  if not pm.is_installed("faiss"):
22
  pm.install("faiss")
23
 
24
  import faiss # type: ignore
25
-
26
 
27
  @final
28
  @dataclass
@@ -51,35 +44,29 @@ class FaissVectorDBStorage(BaseVectorStorage):
51
  self._max_batch_size = self.global_config["embedding_batch_num"]
52
  # Embedding dimension (e.g. 768) must match your embedding function
53
  self._dim = self.embedding_func.embedding_dim
54
- self._storage_lock = get_storage_lock()
55
-
56
- # check need_init must before get_namespace_object/get_namespace_data
57
- need_init = try_initialize_namespace("faiss_indices")
58
- self._index = get_namespace_object("faiss_indices")
59
- self._id_to_meta = get_namespace_data("faiss_meta")
60
-
61
- if need_init:
62
- if is_multiprocess:
63
- # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
64
- # If you have a large number of vectors, you might want IVF or other indexes.
65
- # For demonstration, we use a simple IndexFlatIP.
66
- self._index.value = faiss.IndexFlatIP(self._dim)
67
- # Keep a local store for metadata, IDs, etc.
68
- # Maps <int faiss_id> → metadata (including your original ID).
69
- self._id_to_meta.update({})
70
- # Attempt to load an existing index + metadata from disk
71
- self._load_faiss_index()
72
- else:
73
- self._index = faiss.IndexFlatIP(self._dim)
74
- self._id_to_meta.update({})
75
- self._load_faiss_index()
76
 
77
  def _get_index(self):
78
- """
79
- Helper method to get the correct index object based on multiprocess mode.
80
- Returns the actual index object that can be used for operations.
81
- """
82
- return self._index.value if is_multiprocess else self._index
 
83
 
84
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
85
  """
@@ -134,34 +121,33 @@ class FaissVectorDBStorage(BaseVectorStorage):
134
  # Normalize embeddings for cosine similarity (in-place)
135
  faiss.normalize_L2(embeddings)
136
 
137
- with self._storage_lock:
138
- # Upsert logic:
139
- # 1. Identify which vectors to remove if they exist
140
- # 2. Remove them
141
- # 3. Add the new vectors
142
- existing_ids_to_remove = []
143
- for meta, emb in zip(list_data, embeddings):
144
- faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
145
- if faiss_internal_id is not None:
146
- existing_ids_to_remove.append(faiss_internal_id)
147
-
148
- if existing_ids_to_remove:
149
- self._remove_faiss_ids(existing_ids_to_remove)
150
-
151
- # Step 2: Add new vectors
152
- index = self._get_index()
153
- start_idx = index.ntotal
154
- index.add(embeddings)
155
-
156
- # Step 3: Store metadata + vector for each new ID
157
- for i, meta in enumerate(list_data):
158
- fid = start_idx + i
159
- # Store the raw vector so we can rebuild if something is removed
160
- meta["__vector__"] = embeddings[i].tolist()
161
- self._id_to_meta.update({fid: meta})
162
-
163
- logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
164
- return [m["__id__"] for m in list_data]
165
 
166
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
167
  """
@@ -177,57 +163,54 @@ class FaissVectorDBStorage(BaseVectorStorage):
177
  )
178
 
179
  # Perform the similarity search
180
- with self._storage_lock:
181
- distances, indices = self._get_index().search(embedding, top_k)
182
-
183
- distances = distances[0]
184
- indices = indices[0]
185
-
186
- results = []
187
- for dist, idx in zip(distances, indices):
188
- if idx == -1:
189
- # Faiss returns -1 if no neighbor
190
- continue
191
-
192
- # Cosine similarity threshold
193
- if dist < self.cosine_better_than_threshold:
194
- continue
195
-
196
- meta = self._id_to_meta.get(idx, {})
197
- results.append(
198
- {
199
- **meta,
200
- "id": meta.get("__id__"),
201
- "distance": float(dist),
202
- "created_at": meta.get("__created_at__"),
203
- }
204
- )
205
 
206
- return results
207
 
208
  @property
209
  def client_storage(self):
210
  # Return whatever structure LightRAG might need for debugging
211
- with self._storage_lock:
212
- return {"data": list(self._id_to_meta.values())}
213
 
214
  async def delete(self, ids: list[str]):
215
  """
216
  Delete vectors for the provided custom IDs.
217
  """
218
  logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
219
- with self._storage_lock:
220
- to_remove = []
221
- for cid in ids:
222
- fid = self._find_faiss_id_by_custom_id(cid)
223
- if fid is not None:
224
- to_remove.append(fid)
225
-
226
- if to_remove:
227
- self._remove_faiss_ids(to_remove)
228
- logger.debug(
229
- f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
230
- )
231
 
232
  async def delete_entity(self, entity_name: str) -> None:
233
  entity_id = compute_mdhash_id(entity_name, prefix="ent-")
@@ -239,23 +222,18 @@ class FaissVectorDBStorage(BaseVectorStorage):
239
  Delete relations for a given entity by scanning metadata.
240
  """
241
  logger.debug(f"Searching relations for entity {entity_name}")
242
- with self._storage_lock:
243
- relations = []
244
- for fid, meta in self._id_to_meta.items():
245
- if (
246
- meta.get("src_id") == entity_name
247
- or meta.get("tgt_id") == entity_name
248
- ):
249
- relations.append(fid)
250
-
251
- logger.debug(f"Found {len(relations)} relations for {entity_name}")
252
- if relations:
253
- self._remove_faiss_ids(relations)
254
- logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
255
-
256
- async def index_done_callback(self) -> None:
257
- with self._storage_lock:
258
- self._save_faiss_index()
259
 
260
  # --------------------------------------------------------------------------------
261
  # Internal helper methods
@@ -265,11 +243,10 @@ class FaissVectorDBStorage(BaseVectorStorage):
265
  """
266
  Return the Faiss internal ID for a given custom ID, or None if not found.
267
  """
268
- with self._storage_lock:
269
- for fid, meta in self._id_to_meta.items():
270
- if meta.get("__id__") == custom_id:
271
- return fid
272
- return None
273
 
274
  def _remove_faiss_ids(self, fid_list):
275
  """
@@ -277,48 +254,42 @@ class FaissVectorDBStorage(BaseVectorStorage):
277
  Because IndexFlatIP doesn't support 'removals',
278
  we rebuild the index excluding those vectors.
279
  """
280
  with self._storage_lock:
281
- keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
282
-
283
- # Rebuild the index
284
- vectors_to_keep = []
285
- new_id_to_meta = {}
286
- for new_fid, old_fid in enumerate(keep_fids):
287
- vec_meta = self._id_to_meta[old_fid]
288
- vectors_to_keep.append(vec_meta["__vector__"]) # stored as list
289
- new_id_to_meta[new_fid] = vec_meta
290
-
291
- # Re-init index
292
- new_index = faiss.IndexFlatIP(self._dim)
293
  if vectors_to_keep:
294
  arr = np.array(vectors_to_keep, dtype=np.float32)
295
- new_index.add(arr)
296
- if is_multiprocess:
297
- self._index.value = new_index
298
- else:
299
- self._index = new_index
300
 
301
- self._id_to_meta.update(new_id_to_meta)
302
 
303
  def _save_faiss_index(self):
304
  """
305
  Save the current Faiss index + metadata to disk so it can persist across runs.
306
  """
307
- with self._storage_lock:
308
- faiss.write_index(
309
- self._get_index(),
310
- self._faiss_index_file,
311
- )
312
 
313
- # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
314
- # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
315
- # We'll keep the int -> dict, but JSON requires string keys.
316
- serializable_dict = {}
317
- for fid, meta in self._id_to_meta.items():
318
- serializable_dict[str(fid)] = meta
 
 
 
319
 
320
- with open(self._meta_file, "w", encoding="utf-8") as f:
321
- json.dump(serializable_dict, f)
322
 
323
  def _load_faiss_index(self):
324
  """
@@ -331,31 +302,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
331
 
332
  try:
333
  # Load the Faiss index
334
- loaded_index = faiss.read_index(self._faiss_index_file)
335
- if is_multiprocess:
336
- self._index.value = loaded_index
337
- else:
338
- self._index = loaded_index
339
-
340
  # Load metadata
341
  with open(self._meta_file, "r", encoding="utf-8") as f:
342
  stored_dict = json.load(f)
343
 
344
  # Convert string keys back to int
345
- self._id_to_meta.update({})
346
  for fid_str, meta in stored_dict.items():
347
  fid = int(fid_str)
348
  self._id_to_meta[fid] = meta
349
 
350
  logger.info(
351
- f"Faiss index loaded with {loaded_index.ntotal} vectors from {self._faiss_index_file}"
352
  )
353
  except Exception as e:
354
  logger.error(f"Failed to load Faiss index or metadata: {e}")
355
  logger.warning("Starting with an empty Faiss index.")
356
- new_index = faiss.IndexFlatIP(self._dim)
357
- if is_multiprocess:
358
- self._index.value = new_index
359
- else:
360
- self._index = new_index
361
- self._id_to_meta.update({})
 
10
 
11
  from lightrag.utils import logger, compute_mdhash_id
12
  from lightrag.base import BaseVectorStorage
13
 
14
  if not pm.is_installed("faiss"):
15
  pm.install("faiss")
16
 
17
  import faiss # type: ignore
18
+ from threading import Lock as ThreadLock
19
 
20
  @final
21
  @dataclass
 
44
  self._max_batch_size = self.global_config["embedding_batch_num"]
45
  # Embedding dimension (e.g. 768) must match your embedding function
46
  self._dim = self.embedding_func.embedding_dim
47
+ self._storage_lock = ThreadLock()
48
+
49
+ # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
50
+ # If you have a large number of vectors, you might want IVF or other indexes.
51
+ # For demonstration, we use a simple IndexFlatIP.
52
+ self._index = faiss.IndexFlatIP(self._dim)
53
+
54
+ # Keep a local store for metadata, IDs, etc.
55
+ # Maps <int faiss_id> → metadata (including your original ID).
56
+ self._id_to_meta = {}
57
+
58
+ # Attempt to load an existing index + metadata from disk
59
+ with self._storage_lock:
60
+ self._load_faiss_index()
61
+
62
 
63
  def _get_index(self):
64
+ """Check if the shtorage should be reloaded"""
65
+ return self._index
66
+
67
+ async def index_done_callback(self) -> None:
68
+ with self._storage_lock:
69
+ self._save_faiss_index()
70
 
71
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
72
  """
 
121
  # Normalize embeddings for cosine similarity (in-place)
122
  faiss.normalize_L2(embeddings)
123
 
124
+ # Upsert logic:
125
+ # 1. Identify which vectors to remove if they exist
126
+ # 2. Remove them
127
+ # 3. Add the new vectors
128
+ existing_ids_to_remove = []
129
+ for meta, emb in zip(list_data, embeddings):
130
+ faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
131
+ if faiss_internal_id is not None:
132
+ existing_ids_to_remove.append(faiss_internal_id)
133
+
134
+ if existing_ids_to_remove:
135
+ self._remove_faiss_ids(existing_ids_to_remove)
136
+
137
+ # Step 2: Add new vectors
138
+ index = self._get_index()
139
+ start_idx = index.ntotal
140
+ index.add(embeddings)
141
+
142
+ # Step 3: Store metadata + vector for each new ID
143
+ for i, meta in enumerate(list_data):
144
+ fid = start_idx + i
145
+ # Store the raw vector so we can rebuild if something is removed
146
+ meta["__vector__"] = embeddings[i].tolist()
147
+ self._id_to_meta.update({fid: meta})
148
+
149
+ logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
150
+ return [m["__id__"] for m in list_data]
 
151
 
152
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
153
  """
 
163
  )
164
 
165
  # Perform the similarity search
166
+ distances, indices = self._get_index().search(embedding, top_k)
167
+
168
+ distances = distances[0]
169
+ indices = indices[0]
170
+
171
+ results = []
172
+ for dist, idx in zip(distances, indices):
173
+ if idx == -1:
174
+ # Faiss returns -1 if no neighbor
175
+ continue
176
+
177
+ # Cosine similarity threshold
178
+ if dist < self.cosine_better_than_threshold:
179
+ continue
180
+
181
+ meta = self._id_to_meta.get(idx, {})
182
+ results.append(
183
+ {
184
+ **meta,
185
+ "id": meta.get("__id__"),
186
+ "distance": float(dist),
187
+ "created_at": meta.get("__created_at__"),
188
+ }
189
+ )
 
190
 
191
+ return results
192
 
193
  @property
194
  def client_storage(self):
195
  # Return whatever structure LightRAG might need for debugging
196
+ return {"data": list(self._id_to_meta.values())}
 
197
 
198
  async def delete(self, ids: list[str]):
199
  """
200
  Delete vectors for the provided custom IDs.
201
  """
202
  logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
203
+ to_remove = []
204
+ for cid in ids:
205
+ fid = self._find_faiss_id_by_custom_id(cid)
206
+ if fid is not None:
207
+ to_remove.append(fid)
208
+
209
+ if to_remove:
210
+ self._remove_faiss_ids(to_remove)
211
+ logger.debug(
212
+ f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
213
+ )
 
214
 
215
  async def delete_entity(self, entity_name: str) -> None:
216
  entity_id = compute_mdhash_id(entity_name, prefix="ent-")
 
222
  Delete relations for a given entity by scanning metadata.
223
  """
224
  logger.debug(f"Searching relations for entity {entity_name}")
225
+ relations = []
226
+ for fid, meta in self._id_to_meta.items():
227
+ if (
228
+ meta.get("src_id") == entity_name
229
+ or meta.get("tgt_id") == entity_name
230
+ ):
231
+ relations.append(fid)
232
+
233
+ logger.debug(f"Found {len(relations)} relations for {entity_name}")
234
+ if relations:
235
+ self._remove_faiss_ids(relations)
236
+ logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
 
 
 
 
 
237
 
238
  # --------------------------------------------------------------------------------
239
  # Internal helper methods
 
243
  """
244
  Return the Faiss internal ID for a given custom ID, or None if not found.
245
  """
246
+ for fid, meta in self._id_to_meta.items():
247
+ if meta.get("__id__") == custom_id:
248
+ return fid
249
+ return None
 
250
 
251
  def _remove_faiss_ids(self, fid_list):
252
  """
 
254
  Because IndexFlatIP doesn't support 'removals',
255
  we rebuild the index excluding those vectors.
256
  """
257
+ keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
258
+
259
+ # Rebuild the index
260
+ vectors_to_keep = []
261
+ new_id_to_meta = {}
262
+ for new_fid, old_fid in enumerate(keep_fids):
263
+ vec_meta = self._id_to_meta[old_fid]
264
+ vectors_to_keep.append(vec_meta["__vector__"]) # stored as list
265
+ new_id_to_meta[new_fid] = vec_meta
266
+
267
  with self._storage_lock:
268
+ # Re-init index
269
+ self._index = faiss.IndexFlatIP(self._dim)
270
  if vectors_to_keep:
271
  arr = np.array(vectors_to_keep, dtype=np.float32)
272
+ self._index.add(arr)
273
+
274
+ self._id_to_meta = new_id_to_meta
 
 
275
 
 
276
 
277
  def _save_faiss_index(self):
278
  """
279
  Save the current Faiss index + metadata to disk so it can persist across runs.
280
  """
281
+ faiss.write_index(self._index, self._faiss_index_file)
 
 
 
 
282
 
283
+ # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
284
+ # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
285
+ # We'll keep the int -> dict, but JSON requires string keys.
286
+ serializable_dict = {}
287
+ for fid, meta in self._id_to_meta.items():
288
+ serializable_dict[str(fid)] = meta
289
+
290
+ with open(self._meta_file, "w", encoding="utf-8") as f:
291
+ json.dump(serializable_dict, f)
292
 
 
 
293
 
294
  def _load_faiss_index(self):
295
  """
 
302
 
303
  try:
304
  # Load the Faiss index
305
+ self._index = faiss.read_index(self._faiss_index_file)
306
  # Load metadata
307
  with open(self._meta_file, "r", encoding="utf-8") as f:
308
  stored_dict = json.load(f)
309
 
310
  # Convert string keys back to int
311
+ self._id_to_meta = {}
312
  for fid_str, meta in stored_dict.items():
313
  fid = int(fid_str)
314
  self._id_to_meta[fid] = meta
315
 
316
  logger.info(
317
+ f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
318
  )
319
  except Exception as e:
320
  logger.error(f"Failed to load Faiss index or metadata: {e}")
321
  logger.warning("Starting with an empty Faiss index.")
322
+ self._index = faiss.IndexFlatIP(self._dim)
323
+ self._id_to_meta = {}
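
The reverted Faiss backend keeps everything in-process: L2-normalized vectors in an IndexFlatIP (inner product on unit vectors equals cosine similarity), a local id-to-metadata dict, and a full index rebuild on delete because the code treats IndexFlatIP as not supporting removals. A standalone sketch of that pattern, independent of the LightRAG class:

import faiss
import numpy as np

dim = 8
index = faiss.IndexFlatIP(dim)   # inner product over normalized vectors = cosine similarity
id_to_meta = {}                  # faiss internal id -> metadata (including the raw vector)

# Insert: normalize in place, add, and remember the raw vectors for later rebuilds
vecs = np.random.rand(3, dim).astype(np.float32)
faiss.normalize_L2(vecs)
start = index.ntotal
index.add(vecs)
for i in range(len(vecs)):
    id_to_meta[start + i] = {"__id__": f"doc-{i}", "__vector__": vecs[i].tolist()}

# Delete: rebuild the index without the dropped ids, then renumber the metadata
drop = {1}
keep = [fid for fid in id_to_meta if fid not in drop]
new_index = faiss.IndexFlatIP(dim)
if keep:
    new_index.add(np.array([id_to_meta[fid]["__vector__"] for fid in keep], dtype=np.float32))
index = new_index
id_to_meta = {new_fid: id_to_meta[old_fid] for new_fid, old_fid in enumerate(keep)}

# Query: search returns (distances, ids); drop -1 ids and apply a cosine threshold
query = np.random.rand(1, dim).astype(np.float32)
faiss.normalize_L2(query)
distances, ids = index.search(query, 2)
hits = [(int(i), float(d)) for d, i in zip(distances[0], ids[0]) if i != -1 and d >= 0.2]
print(hits)
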
 
 
 
 
lightrag/kg/nano_vector_db_impl.py CHANGED
@@ -11,25 +11,19 @@ from lightrag.utils import (
11
  )
12
  import pipmaster as pm
13
  from lightrag.base import BaseVectorStorage
14
- from .shared_storage import (
15
- get_storage_lock,
16
- get_namespace_object,
17
- is_multiprocess,
18
- try_initialize_namespace,
19
- )
20
 
21
  if not pm.is_installed("nano-vectordb"):
22
  pm.install("nano-vectordb")
23
 
24
  from nano_vectordb import NanoVectorDB
25
-
26
 
27
  @final
28
  @dataclass
29
  class NanoVectorDBStorage(BaseVectorStorage):
30
  def __post_init__(self):
31
  # Initialize lock only for file operations
32
- self._storage_lock = get_storage_lock()
33
 
34
  # Use global config value if specified, otherwise use default
35
  kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -45,32 +39,14 @@ class NanoVectorDBStorage(BaseVectorStorage):
45
  )
46
  self._max_batch_size = self.global_config["embedding_batch_num"]
47
 
48
- # check need_init must before get_namespace_object
49
- need_init = try_initialize_namespace(self.namespace)
50
- self._client = get_namespace_object(self.namespace)
51
-
52
- if need_init:
53
- if is_multiprocess:
54
- self._client.value = NanoVectorDB(
55
- self.embedding_func.embedding_dim,
56
- storage_file=self._client_file_name,
57
- )
58
- logger.info(
59
- f"Initialized vector DB client for namespace {self.namespace}"
60
- )
61
- else:
62
- self._client = NanoVectorDB(
63
- self.embedding_func.embedding_dim,
64
- storage_file=self._client_file_name,
65
- )
66
- logger.info(
67
- f"Initialized vector DB client for namespace {self.namespace}"
68
- )
69
 
70
  def _get_client(self):
71
- """Get the appropriate client instance based on multiprocess mode"""
72
- if is_multiprocess:
73
- return self._client.value
74
  return self._client
75
 
76
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
@@ -101,8 +77,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
101
  if len(embeddings) == len(list_data):
102
  for i, d in enumerate(list_data):
103
  d["__vector__"] = embeddings[i]
104
- with self._storage_lock:
105
- results = self._get_client().upsert(datas=list_data)
106
  return results
107
  else:
108
  # sometimes the embedding is not returned correctly. just log it.
@@ -115,21 +90,20 @@ class NanoVectorDBStorage(BaseVectorStorage):
115
  embedding = await self.embedding_func([query])
116
  embedding = embedding[0]
117
 
118
- with self._storage_lock:
119
- results = self._get_client().query(
120
- query=embedding,
121
- top_k=top_k,
122
- better_than_threshold=self.cosine_better_than_threshold,
123
- )
124
- results = [
125
- {
126
- **dp,
127
- "id": dp["__id__"],
128
- "distance": dp["__metrics__"],
129
- "created_at": dp.get("__created_at__"),
130
- }
131
- for dp in results
132
- ]
133
  return results
134
 
135
  @property
@@ -143,8 +117,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
143
  ids: List of vector IDs to be deleted
144
  """
145
  try:
146
- with self._storage_lock:
147
- self._get_client().delete(ids)
148
  logger.debug(
149
  f"Successfully deleted {len(ids)} vectors from {self.namespace}"
150
  )
@@ -158,37 +131,35 @@ class NanoVectorDBStorage(BaseVectorStorage):
158
  f"Attempting to delete entity {entity_name} with ID {entity_id}"
159
  )
160
 
161
- with self._storage_lock:
162
- # Check if the entity exists
163
- if self._get_client().get([entity_id]):
164
- self._get_client().delete([entity_id])
165
- logger.debug(f"Successfully deleted entity {entity_name}")
166
- else:
167
- logger.debug(f"Entity {entity_name} not found in storage")
168
  except Exception as e:
169
  logger.error(f"Error deleting entity {entity_name}: {e}")
170
 
171
  async def delete_entity_relation(self, entity_name: str) -> None:
172
  try:
173
- with self._storage_lock:
174
- storage = getattr(self._get_client(), "_NanoVectorDB__storage")
175
- relations = [
176
- dp
177
- for dp in storage["data"]
178
- if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
179
- ]
180
  logger.debug(
181
- f"Found {len(relations)} relations for entity {entity_name}"
182
  )
183
- ids_to_delete = [relation["__id__"] for relation in relations]
184
-
185
- if ids_to_delete:
186
- self._get_client().delete(ids_to_delete)
187
- logger.debug(
188
- f"Deleted {len(ids_to_delete)} relations for {entity_name}"
189
- )
190
- else:
191
- logger.debug(f"No relations found for entity {entity_name}")
192
  except Exception as e:
193
  logger.error(f"Error deleting relations for {entity_name}: {e}")
194
 
 
11
  )
12
  import pipmaster as pm
13
  from lightrag.base import BaseVectorStorage
14
 
15
  if not pm.is_installed("nano-vectordb"):
16
  pm.install("nano-vectordb")
17
 
18
  from nano_vectordb import NanoVectorDB
19
+ from threading import Lock as ThreadLock
20
 
21
  @final
22
  @dataclass
23
  class NanoVectorDBStorage(BaseVectorStorage):
24
  def __post_init__(self):
25
  # Initialize lock only for file operations
26
+ self._storage_lock = ThreadLock()
27
 
28
  # Use global config value if specified, otherwise use default
29
  kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
 
39
  )
40
  self._max_batch_size = self.global_config["embedding_batch_num"]
41
 
42
+ with self._storage_lock:
43
+ self._client = NanoVectorDB(
44
+ self.embedding_func.embedding_dim,
45
+ storage_file=self._client_file_name,
46
+ )
47
 
48
  def _get_client(self):
49
+ """Check if the shtorage should be reloaded"""
 
 
50
  return self._client
51
 
52
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
 
77
  if len(embeddings) == len(list_data):
78
  for i, d in enumerate(list_data):
79
  d["__vector__"] = embeddings[i]
80
+ results = self._get_client().upsert(datas=list_data)
 
81
  return results
82
  else:
83
  # sometimes the embedding is not returned correctly. just log it.
 
90
  embedding = await self.embedding_func([query])
91
  embedding = embedding[0]
92
 
93
+ results = self._get_client().query(
94
+ query=embedding,
95
+ top_k=top_k,
96
+ better_than_threshold=self.cosine_better_than_threshold,
97
+ )
98
+ results = [
99
+ {
100
+ **dp,
101
+ "id": dp["__id__"],
102
+ "distance": dp["__metrics__"],
103
+ "created_at": dp.get("__created_at__"),
104
+ }
105
+ for dp in results
106
+ ]
 
107
  return results
108
 
109
  @property
 
117
  ids: List of vector IDs to be deleted
118
  """
119
  try:
120
+ self._get_client().delete(ids)
 
121
  logger.debug(
122
  f"Successfully deleted {len(ids)} vectors from {self.namespace}"
123
  )
 
131
  f"Attempting to delete entity {entity_name} with ID {entity_id}"
132
  )
133
 
134
+ # Check if the entity exists
135
+ if self._get_client().get([entity_id]):
136
+ self._get_client().delete([entity_id])
137
+ logger.debug(f"Successfully deleted entity {entity_name}")
138
+ else:
139
+ logger.debug(f"Entity {entity_name} not found in storage")
 
140
  except Exception as e:
141
  logger.error(f"Error deleting entity {entity_name}: {e}")
142
 
143
  async def delete_entity_relation(self, entity_name: str) -> None:
144
  try:
145
+ storage = getattr(self._get_client(), "_NanoVectorDB__storage")
146
+ relations = [
147
+ dp
148
+ for dp in storage["data"]
149
+ if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
150
+ ]
151
+ logger.debug(
152
+ f"Found {len(relations)} relations for entity {entity_name}"
153
+ )
154
+ ids_to_delete = [relation["__id__"] for relation in relations]
155
+
156
+ if ids_to_delete:
157
+ self._get_client().delete(ids_to_delete)
158
  logger.debug(
159
+ f"Deleted {len(ids_to_delete)} relations for {entity_name}"
160
  )
161
+ else:
162
+ logger.debug(f"No relations found for entity {entity_name}")
163
  except Exception as e:
164
  logger.error(f"Error deleting relations for {entity_name}: {e}")
165
 
lightrag/kg/networkx_impl.py CHANGED
@@ -6,12 +6,6 @@ import numpy as np
6
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
7
  from lightrag.utils import logger
8
  from lightrag.base import BaseGraphStorage
9
- from .shared_storage import (
10
- get_storage_lock,
11
- get_namespace_object,
12
- is_multiprocess,
13
- try_initialize_namespace,
14
- )
15
 
16
  import pipmaster as pm
17
 
@@ -23,7 +17,7 @@ if not pm.is_installed("graspologic"):
23
 
24
  import networkx as nx
25
  from graspologic import embed
26
-
27
 
28
  @final
29
  @dataclass
@@ -78,38 +72,23 @@ class NetworkXStorage(BaseGraphStorage):
78
  self._graphml_xml_file = os.path.join(
79
  self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
80
  )
81
- self._storage_lock = get_storage_lock()
82
-
83
- # check need_init must before get_namespace_object
84
- need_init = try_initialize_namespace(self.namespace)
85
- self._graph = get_namespace_object(self.namespace)
86
-
87
- if need_init:
88
- if is_multiprocess:
89
- preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
90
- self._graph.value = preloaded_graph or nx.Graph()
91
- if preloaded_graph:
92
- logger.info(
93
- f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
94
- )
95
- else:
96
- preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
97
- self._graph = preloaded_graph or nx.Graph()
98
- if preloaded_graph:
99
- logger.info(
100
- f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
101
- )
102
 
103
  logger.info("Created new empty graph")
104
-
105
  self._node_embed_algorithms = {
106
  "node2vec": self._node2vec_embed,
107
  }
108
 
109
  def _get_graph(self):
110
- """Get the appropriate graph instance based on multiprocess mode"""
111
- if is_multiprocess:
112
- return self._graph.value
113
  return self._graph
114
 
115
  async def index_done_callback(self) -> None:
@@ -117,54 +96,44 @@ class NetworkXStorage(BaseGraphStorage):
117
  NetworkXStorage.write_nx_graph(self._get_graph(), self._graphml_xml_file)
118
 
119
  async def has_node(self, node_id: str) -> bool:
120
- with self._storage_lock:
121
- return self._get_graph().has_node(node_id)
122
 
123
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
124
- with self._storage_lock:
125
- return self._get_graph().has_edge(source_node_id, target_node_id)
126
 
127
  async def get_node(self, node_id: str) -> dict[str, str] | None:
128
- with self._storage_lock:
129
- return self._get_graph().nodes.get(node_id)
130
 
131
  async def node_degree(self, node_id: str) -> int:
132
- with self._storage_lock:
133
- return self._get_graph().degree(node_id)
134
 
135
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
136
- with self._storage_lock:
137
- return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id)
138
 
139
  async def get_edge(
140
  self, source_node_id: str, target_node_id: str
141
  ) -> dict[str, str] | None:
142
- with self._storage_lock:
143
- return self._get_graph().edges.get((source_node_id, target_node_id))
144
 
145
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
146
- with self._storage_lock:
147
- if self._get_graph().has_node(source_node_id):
148
- return list(self._get_graph().edges(source_node_id))
149
- return None
150
 
151
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
152
- with self._storage_lock:
153
- self._get_graph().add_node(node_id, **node_data)
154
 
155
  async def upsert_edge(
156
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
157
  ) -> None:
158
- with self._storage_lock:
159
- self._get_graph().add_edge(source_node_id, target_node_id, **edge_data)
160
 
161
  async def delete_node(self, node_id: str) -> None:
162
- with self._storage_lock:
163
- if self._get_graph().has_node(node_id):
164
- self._get_graph().remove_node(node_id)
165
- logger.debug(f"Node {node_id} deleted from the graph.")
166
- else:
167
- logger.warning(f"Node {node_id} not found in the graph for deletion.")
168
 
169
  async def embed_nodes(
170
  self, algorithm: str
@@ -175,13 +144,12 @@ class NetworkXStorage(BaseGraphStorage):
175
 
176
  # TODO: NOT USED
177
  async def _node2vec_embed(self):
178
- with self._storage_lock:
179
- graph = self._get_graph()
180
- embeddings, nodes = embed.node2vec_embed(
181
- graph,
182
- **self.global_config["node2vec_params"],
183
- )
184
- nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
185
  return embeddings, nodes_ids
186
 
187
  def remove_nodes(self, nodes: list[str]):
@@ -190,11 +158,10 @@ class NetworkXStorage(BaseGraphStorage):
190
  Args:
191
  nodes: List of node IDs to be deleted
192
  """
193
- with self._storage_lock:
194
- graph = self._get_graph()
195
- for node in nodes:
196
- if graph.has_node(node):
197
- graph.remove_node(node)
198
 
199
  def remove_edges(self, edges: list[tuple[str, str]]):
200
  """Delete multiple edges
@@ -202,11 +169,10 @@ class NetworkXStorage(BaseGraphStorage):
202
  Args:
203
  edges: List of edges to be deleted, each edge is a (source, target) tuple
204
  """
205
- with self._storage_lock:
206
- graph = self._get_graph()
207
- for source, target in edges:
208
- if graph.has_edge(source, target):
209
- graph.remove_edge(source, target)
210
 
211
  async def get_all_labels(self) -> list[str]:
212
  """
@@ -214,10 +180,9 @@ class NetworkXStorage(BaseGraphStorage):
214
  Returns:
215
  [label1, label2, ...] # Alphabetically sorted label list
216
  """
217
- with self._storage_lock:
218
- labels = set()
219
- for node in self._get_graph().nodes():
220
- labels.add(str(node)) # Add node id as a label
221
 
222
  # Return sorted list
223
  return sorted(list(labels))
@@ -239,88 +204,87 @@ class NetworkXStorage(BaseGraphStorage):
239
  seen_nodes = set()
240
  seen_edges = set()
241
 
242
- with self._storage_lock:
243
- graph = self._get_graph()
244
-
245
- # Handle special case for "*" label
246
- if node_label == "*":
247
- # For "*", return the entire graph including all nodes and edges
248
- subgraph = (
249
- graph.copy()
250
- ) # Create a copy to avoid modifying the original graph
251
- else:
252
- # Find nodes with matching node id (partial match)
253
- nodes_to_explore = []
254
- for n, attr in graph.nodes(data=True):
255
- if node_label in str(n): # Use partial matching
256
- nodes_to_explore.append(n)
257
-
258
- if not nodes_to_explore:
259
- logger.warning(f"No nodes found with label {node_label}")
260
- return result
261
-
262
- # Get subgraph using ego_graph
263
- subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
264
-
265
- # Check if number of nodes exceeds max_graph_nodes
266
- max_graph_nodes = 500
267
- if len(subgraph.nodes()) > max_graph_nodes:
268
- origin_nodes = len(subgraph.nodes())
269
- node_degrees = dict(subgraph.degree())
270
- top_nodes = sorted(
271
- node_degrees.items(), key=lambda x: x[1], reverse=True
272
- )[:max_graph_nodes]
273
- top_node_ids = [node[0] for node in top_nodes]
274
- # Create new subgraph with only top nodes
275
- subgraph = subgraph.subgraph(top_node_ids)
276
- logger.info(
277
- f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
278
- )
279
 
280
- # Add nodes to result
281
- for node in subgraph.nodes():
282
- if str(node) in seen_nodes:
283
- continue
284
-
285
- node_data = dict(subgraph.nodes[node])
286
- # Get entity_type as labels
287
- labels = []
288
- if "entity_type" in node_data:
289
- if isinstance(node_data["entity_type"], list):
290
- labels.extend(node_data["entity_type"])
291
- else:
292
- labels.append(node_data["entity_type"])
293
-
294
- # Create node with properties
295
- node_properties = {k: v for k, v in node_data.items()}
296
-
297
- result.nodes.append(
298
- KnowledgeGraphNode(
299
- id=str(node), labels=[str(node)], properties=node_properties
300
- )
301
  )
302
- seen_nodes.add(str(node))
303
-
304
- # Add edges to result
305
- for edge in subgraph.edges():
306
- source, target = edge
307
- edge_id = f"{source}-{target}"
308
- if edge_id in seen_edges:
309
- continue
310
-
311
- edge_data = dict(subgraph.edges[edge])
312
-
313
- # Create edge with complete information
314
- result.edges.append(
315
- KnowledgeGraphEdge(
316
- id=edge_id,
317
- type="DIRECTED",
318
- source=str(source),
319
- target=str(target),
320
- properties=edge_data,
321
- )
322
  )
323
- seen_edges.add(edge_id)
 
324
 
325
  logger.info(
326
  f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
 
6
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
7
  from lightrag.utils import logger
8
  from lightrag.base import BaseGraphStorage
9
 
10
  import pipmaster as pm
11
 
 
17
 
18
  import networkx as nx
19
  from graspologic import embed
20
+ from threading import Lock as ThreadLock
21
 
22
  @final
23
  @dataclass
 
72
  self._graphml_xml_file = os.path.join(
73
  self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
74
  )
75
+ self._storage_lock = ThreadLock()
76
 
77
+ with self._storage_lock:
78
+ preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
79
+ if preloaded_graph is not None:
80
+ logger.info(
81
+ f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
82
+ )
83
+ else:
84
  logger.info("Created new empty graph")
85
+ self._graph = preloaded_graph or nx.Graph()
86
  self._node_embed_algorithms = {
87
  "node2vec": self._node2vec_embed,
88
  }
89
 
90
  def _get_graph(self):
91
+ """Check if the shtorage should be reloaded"""
 
 
92
  return self._graph
93
 
94
  async def index_done_callback(self) -> None:
 
96
  NetworkXStorage.write_nx_graph(self._get_graph(), self._graphml_xml_file)
97
 
98
  async def has_node(self, node_id: str) -> bool:
99
+ return self._get_graph().has_node(node_id)
 
100
 
101
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
102
+ return self._get_graph().has_edge(source_node_id, target_node_id)
 
103
 
104
  async def get_node(self, node_id: str) -> dict[str, str] | None:
105
+ return self._get_graph().nodes.get(node_id)
 
106
 
107
  async def node_degree(self, node_id: str) -> int:
108
+ return self._get_graph().degree(node_id)
 
109
 
110
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
111
+ return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id)
 
112
 
113
  async def get_edge(
114
  self, source_node_id: str, target_node_id: str
115
  ) -> dict[str, str] | None:
116
+ return self._get_graph().edges.get((source_node_id, target_node_id))
 
117
 
118
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
119
+ if self._get_graph().has_node(source_node_id):
120
+ return list(self._get_graph().edges(source_node_id))
121
+ return None
 
122
 
123
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
124
+ self._get_graph().add_node(node_id, **node_data)
 
125
 
126
  async def upsert_edge(
127
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
128
  ) -> None:
129
+ self._get_graph().add_edge(source_node_id, target_node_id, **edge_data)
 
130
 
131
  async def delete_node(self, node_id: str) -> None:
132
+ if self._get_graph().has_node(node_id):
133
+ self._get_graph().remove_node(node_id)
134
+ logger.debug(f"Node {node_id} deleted from the graph.")
135
+ else:
136
+ logger.warning(f"Node {node_id} not found in the graph for deletion.")
 
137
 
138
  async def embed_nodes(
139
  self, algorithm: str
 
144
 
145
  # TODO: NOT USED
146
  async def _node2vec_embed(self):
147
+ graph = self._get_graph()
148
+ embeddings, nodes = embed.node2vec_embed(
149
+ graph,
150
+ **self.global_config["node2vec_params"],
151
+ )
152
+ nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
 
153
  return embeddings, nodes_ids
154
 
155
  def remove_nodes(self, nodes: list[str]):
 
158
  Args:
159
  nodes: List of node IDs to be deleted
160
  """
161
+ graph = self._get_graph()
162
+ for node in nodes:
163
+ if graph.has_node(node):
164
+ graph.remove_node(node)
 
165
 
166
  def remove_edges(self, edges: list[tuple[str, str]]):
167
  """Delete multiple edges
 
169
  Args:
170
  edges: List of edges to be deleted, each edge is a (source, target) tuple
171
  """
172
+ graph = self._get_graph()
173
+ for source, target in edges:
174
+ if graph.has_edge(source, target):
175
+ graph.remove_edge(source, target)
 
176
 
177
  async def get_all_labels(self) -> list[str]:
178
  """
 
180
  Returns:
181
  [label1, label2, ...] # Alphabetically sorted label list
182
  """
183
+ labels = set()
184
+ for node in self._get_graph().nodes():
185
+ labels.add(str(node)) # Add node id as a label
 
186
 
187
  # Return sorted list
188
  return sorted(list(labels))
 
204
  seen_nodes = set()
205
  seen_edges = set()
206
 
207
+ graph = self._get_graph()
208
+
209
+ # Handle special case for "*" label
210
+ if node_label == "*":
211
+ # For "*", return the entire graph including all nodes and edges
212
+ subgraph = (
213
+ graph.copy()
214
+ ) # Create a copy to avoid modifying the original graph
215
+ else:
216
+ # Find nodes with matching node id (partial match)
217
+ nodes_to_explore = []
218
+ for n, attr in graph.nodes(data=True):
219
+ if node_label in str(n): # Use partial matching
220
+ nodes_to_explore.append(n)
221
+
222
+ if not nodes_to_explore:
223
+ logger.warning(f"No nodes found with label {node_label}")
224
+ return result
225
+
226
+ # Get subgraph using ego_graph
227
+ subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
228
+
229
+ # Check if number of nodes exceeds max_graph_nodes
230
+ max_graph_nodes = 500
231
+ if len(subgraph.nodes()) > max_graph_nodes:
232
+ origin_nodes = len(subgraph.nodes())
233
+ node_degrees = dict(subgraph.degree())
234
+ top_nodes = sorted(
235
+ node_degrees.items(), key=lambda x: x[1], reverse=True
236
+ )[:max_graph_nodes]
237
+ top_node_ids = [node[0] for node in top_nodes]
238
+ # Create new subgraph with only top nodes
239
+ subgraph = subgraph.subgraph(top_node_ids)
240
+ logger.info(
241
+ f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
242
+ )
 
243
 
244
+ # Add nodes to result
245
+ for node in subgraph.nodes():
246
+ if str(node) in seen_nodes:
247
+ continue
248
+
249
+ node_data = dict(subgraph.nodes[node])
250
+ # Get entity_type as labels
251
+ labels = []
252
+ if "entity_type" in node_data:
253
+ if isinstance(node_data["entity_type"], list):
254
+ labels.extend(node_data["entity_type"])
255
+ else:
256
+ labels.append(node_data["entity_type"])
257
+
258
+ # Create node with properties
259
+ node_properties = {k: v for k, v in node_data.items()}
260
+
261
+ result.nodes.append(
262
+ KnowledgeGraphNode(
263
+ id=str(node), labels=[str(node)], properties=node_properties
 
264
  )
265
+ )
266
+ seen_nodes.add(str(node))
267
+
268
+ # Add edges to result
269
+ for edge in subgraph.edges():
270
+ source, target = edge
271
+ edge_id = f"{source}-{target}"
272
+ if edge_id in seen_edges:
273
+ continue
274
+
275
+ edge_data = dict(subgraph.edges[edge])
276
+
277
+ # Create edge with complete information
278
+ result.edges.append(
279
+ KnowledgeGraphEdge(
280
+ id=edge_id,
281
+ type="DIRECTED",
282
+ source=str(source),
283
+ target=str(target),
284
+ properties=edge_data,
285
  )
286
+ )
287
+ seen_edges.add(edge_id)
288
 
289
  logger.info(
290
  f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
lightrag/kg/shared_storage.py CHANGED
@@ -20,15 +20,12 @@ LockType = Union[ProcessLock, ThreadLock]
20
  _manager = None
21
  _initialized = None
22
  is_multiprocess = None
 
23
 
24
  # shared data for storage across processes
25
  _shared_dicts: Optional[Dict[str, Any]] = None
26
- _share_objects: Optional[Dict[str, Any]] = None
27
  _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
28
 
29
- _global_lock: Optional[LockType] = None
30
-
31
-
32
  def initialize_share_data(workers: int = 1):
33
  """
34
  Initialize shared storage data for single or multi-process mode.
@@ -53,7 +50,6 @@ def initialize_share_data(workers: int = 1):
53
  is_multiprocess, \
54
  _global_lock, \
55
  _shared_dicts, \
56
- _share_objects, \
57
  _init_flags, \
58
  _initialized
59
 
@@ -72,7 +68,6 @@ def initialize_share_data(workers: int = 1):
72
  _global_lock = _manager.Lock()
73
  # Create shared dictionaries with manager
74
  _shared_dicts = _manager.dict()
75
- _share_objects = _manager.dict()
76
  _init_flags = (
77
  _manager.dict()
78
  ) # Use shared dictionary to store initialization flags
@@ -83,7 +78,6 @@ def initialize_share_data(workers: int = 1):
83
  is_multiprocess = False
84
  _global_lock = ThreadLock()
85
  _shared_dicts = {}
86
- _share_objects = {}
87
  _init_flags = {}
88
  direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
89
 
@@ -99,11 +93,7 @@ def try_initialize_namespace(namespace: str) -> bool:
99
  global _init_flags, _manager
100
 
101
  if _init_flags is None:
102
- direct_log(
103
- f"Error: try to create nanmespace before Shared-Data is initialized, pid={os.getpid()}",
104
- level="ERROR",
105
- )
106
- raise ValueError("Shared dictionaries not initialized")
107
 
108
  if namespace not in _init_flags:
109
  _init_flags[namespace] = True
@@ -113,43 +103,9 @@ def try_initialize_namespace(namespace: str) -> bool:
113
  return False
114
 
115
 
116
- def _get_global_lock() -> LockType:
117
- return _global_lock
118
-
119
-
120
  def get_storage_lock() -> LockType:
121
  """return storage lock for data consistency"""
122
- return _get_global_lock()
123
-
124
-
125
- def get_scan_lock() -> LockType:
126
- """return scan_progress lock for data consistency"""
127
- return get_storage_lock()
128
-
129
-
130
- def get_namespace_object(namespace: str) -> Any:
131
- """Get an object for specific namespace"""
132
-
133
- if _share_objects is None:
134
- direct_log(
135
- f"Error: try to getnanmespace before Shared-Data is initialized, pid={os.getpid()}",
136
- level="ERROR",
137
- )
138
- raise ValueError("Shared dictionaries not initialized")
139
-
140
- lock = _get_global_lock()
141
- with lock:
142
- if namespace not in _share_objects:
143
- if namespace not in _share_objects:
144
- if is_multiprocess:
145
- _share_objects[namespace] = _manager.Value("O", None)
146
- else:
147
- _share_objects[namespace] = None
148
- direct_log(
149
- f"Created namespace: {namespace}(type={type(_share_objects[namespace])})"
150
- )
151
-
152
- return _share_objects[namespace]
153
 
154
 
155
  def get_namespace_data(namespace: str) -> Dict[str, Any]:
@@ -161,7 +117,7 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]:
161
  )
162
  raise ValueError("Shared dictionaries not initialized")
163
 
164
- lock = _get_global_lock()
165
  with lock:
166
  if namespace not in _shared_dicts:
167
  if is_multiprocess and _manager is not None:
@@ -175,11 +131,6 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]:
175
  return _shared_dicts[namespace]
176
 
177
 
178
- def get_scan_progress() -> Dict[str, Any]:
179
- """get storage space for document scanning progress data"""
180
- return get_namespace_data("scan_progress")
181
-
182
-
183
  def finalize_share_data():
184
  """
185
  Release shared resources and clean up.
@@ -195,7 +146,6 @@ def finalize_share_data():
195
  is_multiprocess, \
196
  _global_lock, \
197
  _shared_dicts, \
198
- _share_objects, \
199
  _init_flags, \
200
  _initialized
201
 
@@ -216,8 +166,6 @@ def finalize_share_data():
216
  # Clear shared dictionaries first
217
  if _shared_dicts is not None:
218
  _shared_dicts.clear()
219
- if _share_objects is not None:
220
- _share_objects.clear()
221
  if _init_flags is not None:
222
  _init_flags.clear()
223
 
@@ -234,7 +182,6 @@ def finalize_share_data():
234
  _initialized = None
235
  is_multiprocess = None
236
  _shared_dicts = None
237
- _share_objects = None
238
  _init_flags = None
239
  _global_lock = None
240
 
 
20
  _manager = None
21
  _initialized = None
22
  is_multiprocess = None
23
+ _global_lock: Optional[LockType] = None
24
 
25
  # shared data for storage across processes
26
  _shared_dicts: Optional[Dict[str, Any]] = None
 
27
  _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
28
 
 
 
 
29
  def initialize_share_data(workers: int = 1):
30
  """
31
  Initialize shared storage data for single or multi-process mode.
 
50
  is_multiprocess, \
51
  _global_lock, \
52
  _shared_dicts, \
 
53
  _init_flags, \
54
  _initialized
55
 
 
68
  _global_lock = _manager.Lock()
69
  # Create shared dictionaries with manager
70
  _shared_dicts = _manager.dict()
 
71
  _init_flags = (
72
  _manager.dict()
73
  ) # Use shared dictionary to store initialization flags
 
78
  is_multiprocess = False
79
  _global_lock = ThreadLock()
80
  _shared_dicts = {}
 
81
  _init_flags = {}
82
  direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
83
 
 
93
  global _init_flags, _manager
94
 
95
  if _init_flags is None:
96
+ raise ValueError("Try to create nanmespace before Shared-Data is initialized")
 
 
 
 
97
 
98
  if namespace not in _init_flags:
99
  _init_flags[namespace] = True
 
103
  return False
104
 
105
 
 
 
 
 
106
  def get_storage_lock() -> LockType:
107
  """return storage lock for data consistency"""
108
+ return _global_lock
 
109
 
110
 
111
  def get_namespace_data(namespace: str) -> Dict[str, Any]:
 
117
  )
118
  raise ValueError("Shared dictionaries not initialized")
119
 
120
+ lock = get_storage_lock()
121
  with lock:
122
  if namespace not in _shared_dicts:
123
  if is_multiprocess and _manager is not None:
 
131
  return _shared_dicts[namespace]
132
 
133
 
134
  def finalize_share_data():
135
  """
136
  Release shared resources and clean up.
 
146
  is_multiprocess, \
147
  _global_lock, \
148
  _shared_dicts, \
 
149
  _init_flags, \
150
  _initialized
151
 
 
166
  # Clear shared dictionaries first
167
  if _shared_dicts is not None:
168
  _shared_dicts.clear()
 
 
169
  if _init_flags is not None:
170
  _init_flags.clear()
171
 
 
182
  _initialized = None
183
  is_multiprocess = None
184
  _shared_dicts = None
 
185
  _init_flags = None
186
  _global_lock = None
187
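
After this change shared_storage keeps only the dict-based namespaces and one global lock; the per-namespace object store (_share_objects, get_namespace_object) and the scan-progress helpers are gone. A short sketch of the remaining API in single-process mode; the "demo" namespace is illustrative:

from lightrag.kg.shared_storage import (
    get_namespace_data,
    get_storage_lock,
    initialize_share_data,
    try_initialize_namespace,
)

initialize_share_data(workers=1)              # single process: ThreadLock + plain dicts

first = try_initialize_namespace("demo")      # True only for the first initializer
data = get_namespace_data("demo")             # per-namespace dict

with get_storage_lock():                      # the single global lock
    data["count"] = data.get("count", 0) + 1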