Commit 5d78930 by yangdx (1 parent: d2b7a97)

Refactor storage implementations to support both single and multi-process modes

• Add shared storage management module
• Support process/thread lock based on mode
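The new shared-storage module referenced throughout this commit (lightrag/kg/shared_storage.py) is not shown in the diff below. A minimal sketch of what its interface might look like, inferred purely from how the storage implementations call it, is given here; everything beyond the names used at the call sites (initialize_manager, is_multiprocess, get_storage_lock, get_scan_lock, get_scan_progress, get_namespace_data, get_namespace_object) is an assumption.

# Hypothetical sketch of lightrag/kg/shared_storage.py; the module body is not part of
# this diff, so the implementation below is inferred from the call sites only.
import threading
from multiprocessing import Manager

# Set to True by the API server when it starts more than one worker.
is_multiprocess = False

_manager = None                   # multiprocessing.Manager, created only in multi-process mode
_storage_lock = threading.Lock()  # thread-lock fallback for single-process mode
_scan_lock = threading.Lock()
_scan_progress = {
    "is_scanning": False,
    "current_file": "",
    "indexed_count": 0,
    "total_files": 0,
    "progress": 0,
}
_namespace_data = {}     # namespace -> shared dict
_namespace_objects = {}  # namespace -> holder exposing a .value attribute (multi-process only)


def initialize_manager():
    """Create the Manager and promote shared state to cross-process proxies."""
    global _manager, _storage_lock, _scan_lock, _scan_progress
    if _manager is None:
        _manager = Manager()
        _storage_lock = _manager.Lock()
        _scan_lock = _manager.Lock()
        _scan_progress = _manager.dict(_scan_progress)


def get_storage_lock():
    return _storage_lock


def get_scan_lock():
    return _scan_lock


def get_scan_progress():
    return _scan_progress


def get_namespace_data(namespace: str):
    """Return a dict shared across processes (or a plain dict in single-process mode)."""
    if namespace not in _namespace_data:
        _namespace_data[namespace] = _manager.dict() if _manager else {}
    return _namespace_data[namespace]


def get_namespace_object(namespace: str):
    """Return a holder whose .value carries a non-dict object (Faiss index, NanoVectorDB, graph)."""
    if _manager is not None:
        if namespace not in _namespace_objects:
            ns = _manager.Namespace()
            ns.value = None
            _namespace_objects[namespace] = ns
        return _namespace_objects[namespace]
    # Single-process callers check "if self._client is None" and build the object themselves.
    return None

Note that attributes read back from a Manager().Namespace() proxy are copies, so the actual module may handle heavyweight objects such as the Faiss index differently; the sketch only illustrates the shape of the API that the storage classes in this diff rely on.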

lightrag/api/lightrag_server.py CHANGED
@@ -406,9 +406,6 @@ def create_app(args):
 
 def get_application():
     """Factory function for creating the FastAPI application"""
-    from .utils_api import initialize_manager
-    initialize_manager()
-
     # Get args from environment variable
     args_json = os.environ.get('LIGHTRAG_ARGS')
     if not args_json:
@@ -428,6 +425,12 @@ def main():
     # Save args to environment variable for child processes
     os.environ['LIGHTRAG_ARGS'] = json.dumps(vars(args))
 
+    if args.workers > 1:
+        from lightrag.kg.shared_storage import initialize_manager
+        initialize_manager()
+        import lightrag.kg.shared_storage as shared_storage
+        shared_storage.is_multiprocess = True
+
     # Configure uvicorn logging
     logging.config.dictConfig({
         "version": 1,
lightrag/api/routers/document_routes.py CHANGED
@@ -18,12 +18,10 @@ from pydantic import BaseModel, Field, field_validator
18
 
19
  from lightrag import LightRAG
20
  from lightrag.base import DocProcessingStatus, DocStatus
21
- from ..utils_api import (
22
- get_api_key_dependency,
23
- scan_progress,
24
- update_scan_progress_if_not_scanning,
25
- update_scan_progress,
26
- reset_scan_progress,
27
  )
28
 
29
 
@@ -378,23 +376,51 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
378
 
379
  async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
380
  """Background task to scan and index documents"""
381
- if not update_scan_progress_if_not_scanning():
382
- ASCIIColors.info(
383
- "Skip document scanning(another scanning is active)"
384
- )
385
- return
386
 
387
  try:
388
  new_files = doc_manager.scan_directory_for_new_files()
389
  total_files = len(new_files)
390
- update_scan_progress("", total_files, 0) # Initialize progress
391
 
392
  logging.info(f"Found {total_files} new files to index.")
393
  for idx, file_path in enumerate(new_files):
394
  try:
395
- update_scan_progress(os.path.basename(file_path), total_files, idx)
396
  await pipeline_index_file(rag, file_path)
397
- update_scan_progress(os.path.basename(file_path), total_files, idx + 1)
398
 
399
  except Exception as e:
400
  logging.error(f"Error indexing file {file_path}: {str(e)}")
@@ -402,7 +428,13 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
402
  except Exception as e:
403
  logging.error(f"Error during scanning process: {str(e)}")
404
  finally:
405
- reset_scan_progress()
406
 
407
 
408
  def create_document_routes(
@@ -427,7 +459,7 @@ def create_document_routes(
427
  return {"status": "scanning_started"}
428
 
429
  @router.get("/scan-progress")
430
- async def get_scan_progress():
431
  """
432
  Get the current progress of the document scanning process.
433
 
@@ -439,7 +471,7 @@ def create_document_routes(
439
  - total_files: Total number of files to process
440
  - progress: Percentage of completion
441
  """
442
- return dict(scan_progress)
443
 
444
  @router.post("/upload", dependencies=[Depends(optional_api_key)])
445
  async def upload_to_input_dir(
 
18
 
19
  from lightrag import LightRAG
20
  from lightrag.base import DocProcessingStatus, DocStatus
21
+ from ..utils_api import get_api_key_dependency
22
+ from lightrag.kg.shared_storage import (
23
+ get_scan_progress,
24
+ get_scan_lock,
 
 
25
  )
26
 
27
 
 
376
 
377
  async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
378
  """Background task to scan and index documents"""
379
+ scan_progress = get_scan_progress()
380
+ scan_lock = get_scan_lock()
381
+
382
+ with scan_lock:
383
+ if scan_progress["is_scanning"]:
384
+ ASCIIColors.info(
385
+ "Skip document scanning(another scanning is active)"
386
+ )
387
+ return
388
+ scan_progress.update({
389
+ "is_scanning": True,
390
+ "current_file": "",
391
+ "indexed_count": 0,
392
+ "total_files": 0,
393
+ "progress": 0,
394
+ })
395
 
396
  try:
397
  new_files = doc_manager.scan_directory_for_new_files()
398
  total_files = len(new_files)
399
+ scan_progress.update({
400
+ "current_file": "",
401
+ "total_files": total_files,
402
+ "indexed_count": 0,
403
+ "progress": 0,
404
+ })
405
 
406
  logging.info(f"Found {total_files} new files to index.")
407
  for idx, file_path in enumerate(new_files):
408
  try:
409
+ progress = (idx / total_files * 100) if total_files > 0 else 0
410
+ scan_progress.update({
411
+ "current_file": os.path.basename(file_path),
412
+ "indexed_count": idx,
413
+ "progress": progress,
414
+ })
415
+
416
  await pipeline_index_file(rag, file_path)
417
+
418
+ progress = ((idx + 1) / total_files * 100) if total_files > 0 else 0
419
+ scan_progress.update({
420
+ "current_file": os.path.basename(file_path),
421
+ "indexed_count": idx + 1,
422
+ "progress": progress,
423
+ })
424
 
425
  except Exception as e:
426
  logging.error(f"Error indexing file {file_path}: {str(e)}")
 
428
  except Exception as e:
429
  logging.error(f"Error during scanning process: {str(e)}")
430
  finally:
431
+ scan_progress.update({
432
+ "is_scanning": False,
433
+ "current_file": "",
434
+ "indexed_count": 0,
435
+ "total_files": 0,
436
+ "progress": 0,
437
+ })
438
 
439
 
440
  def create_document_routes(
 
459
  return {"status": "scanning_started"}
460
 
461
  @router.get("/scan-progress")
462
+ async def get_scanning_progress():
463
  """
464
  Get the current progress of the document scanning process.
465
 
 
471
  - total_files: Total number of files to process
472
  - progress: Percentage of completion
473
  """
474
+ return dict(get_scan_progress())
475
 
476
  @router.post("/upload", dependencies=[Depends(optional_api_key)])
477
  async def upload_to_input_dir(
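A client can poll the scan-progress endpoint to follow a running scan; the host, port, and the /documents route prefix below are assumptions about a typical deployment, not part of this diff.

# Hypothetical polling client for the scan-progress endpoint; URL details are assumed.
import requests

resp = requests.get("http://localhost:9621/documents/scan-progress", timeout=10)
progress = resp.json()  # keys mirror the shared scan_progress dict
print(
    f"scanning={progress['is_scanning']} "
    f"{progress['indexed_count']}/{progress['total_files']} "
    f"({progress['progress']:.0f}%) current={progress['current_file']}"
)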
lightrag/api/utils_api.py CHANGED
@@ -6,7 +6,6 @@ import os
6
  import argparse
7
  from typing import Optional
8
  import sys
9
- from multiprocessing import Manager
10
  from ascii_colors import ASCIIColors
11
  from lightrag.api import __api_version__
12
  from fastapi import HTTPException, Security
@@ -17,66 +16,6 @@ from starlette.status import HTTP_403_FORBIDDEN
17
  # Load environment variables
18
  load_dotenv(override=True)
19
 
20
- # Global variables for manager and shared state
21
- manager = None
22
- scan_progress = None
23
- scan_lock = None
24
-
25
- def initialize_manager():
26
- """Initialize manager and shared state for cross-process communication"""
27
- global manager, scan_progress, scan_lock
28
- if manager is None:
29
- manager = Manager()
30
- scan_progress = manager.dict({
31
- "is_scanning": False,
32
- "current_file": "",
33
- "indexed_count": 0,
34
- "total_files": 0,
35
- "progress": 0,
36
- })
37
- scan_lock = manager.Lock()
38
-
39
- def update_scan_progress_if_not_scanning():
40
- """
41
- Atomically check if scanning is not in progress and update scan_progress if it's not.
42
- Returns True if the update was successful, False if scanning was already in progress.
43
- """
44
- with scan_lock:
45
- if not scan_progress["is_scanning"]:
46
- scan_progress.update({
47
- "is_scanning": True,
48
- "current_file": "",
49
- "indexed_count": 0,
50
- "total_files": 0,
51
- "progress": 0,
52
- })
53
- return True
54
- return False
55
-
56
- def update_scan_progress(current_file: str, total_files: int, indexed_count: int):
57
- """
58
- Atomically update scan progress information.
59
- """
60
- progress = (indexed_count / total_files * 100) if total_files > 0 else 0
61
- scan_progress.update({
62
- "current_file": current_file,
63
- "indexed_count": indexed_count,
64
- "total_files": total_files,
65
- "progress": progress,
66
- })
67
-
68
- def reset_scan_progress():
69
- """
70
- Atomically reset scan progress to initial state.
71
- """
72
- scan_progress.update({
73
- "is_scanning": False,
74
- "current_file": "",
75
- "indexed_count": 0,
76
- "total_files": 0,
77
- "progress": 0,
78
- })
79
-
80
 
81
  class OllamaServerInfos:
82
  # Constants for emulated Ollama model information
 
6
  import argparse
7
  from typing import Optional
8
  import sys
 
9
  from ascii_colors import ASCIIColors
10
  from lightrag.api import __api_version__
11
  from fastapi import HTTPException, Security
 
16
  # Load environment variables
17
  load_dotenv(override=True)
18
19
 
20
  class OllamaServerInfos:
21
  # Constants for emulated Ollama model information
lightrag/kg/faiss_impl.py CHANGED
@@ -2,48 +2,21 @@ import os
2
  import time
3
  import asyncio
4
  from typing import Any, final
5
- import threading
6
  import json
7
  import numpy as np
8
 
9
  from dataclasses import dataclass
10
  import pipmaster as pm
11
- from lightrag.api.utils_api import manager as main_process_manager
12
 
13
- from lightrag.utils import (
14
- logger,
15
- compute_mdhash_id,
16
- )
17
- from lightrag.base import (
18
- BaseVectorStorage,
19
- )
20
 
21
  if not pm.is_installed("faiss"):
22
  pm.install("faiss")
23
 
24
  import faiss # type: ignore
25
 
26
- # Global variables for shared memory management
27
- _init_lock = threading.Lock()
28
- _manager = None
29
- _shared_indices = None
30
- _shared_meta = None
31
-
32
-
33
- def _get_manager():
34
- """Get or create the global manager instance"""
35
- global _manager, _shared_indices, _shared_meta
36
- with _init_lock:
37
- if _manager is None:
38
- try:
39
- _manager = main_process_manager
40
- _shared_indices = _manager.dict()
41
- _shared_meta = _manager.dict()
42
- except Exception as e:
43
- logger.error(f"Failed to initialize shared memory manager: {e}")
44
- raise RuntimeError(f"Shared memory initialization failed: {e}")
45
- return _manager
46
-
47
 
48
  @final
49
  @dataclass
@@ -72,48 +45,29 @@ class FaissVectorDBStorage(BaseVectorStorage):
72
  self._max_batch_size = self.global_config["embedding_batch_num"]
73
  # Embedding dimension (e.g. 768) must match your embedding function
74
  self._dim = self.embedding_func.embedding_dim
 
75
 
76
- # Ensure manager is initialized
77
- _get_manager()
78
 
79
- # Get or create namespace index and metadata
80
- if self.namespace not in _shared_indices:
81
- with _init_lock:
82
- if self.namespace not in _shared_indices:
83
- try:
84
- # Create an empty Faiss index for inner product
85
- index = faiss.IndexFlatIP(self._dim)
86
- meta = {}
87
-
88
- # Load existing index if available
89
- if os.path.exists(self._faiss_index_file):
90
- try:
91
- index = faiss.read_index(self._faiss_index_file)
92
- with open(self._meta_file, "r", encoding="utf-8") as f:
93
- stored_dict = json.load(f)
94
- # Convert string keys back to int
95
- meta = {int(k): v for k, v in stored_dict.items()}
96
- logger.info(
97
- f"Faiss index loaded with {index.ntotal} vectors from {self._faiss_index_file}"
98
- )
99
- except Exception as e:
100
- logger.error(f"Failed to load Faiss index or metadata: {e}")
101
- logger.warning("Starting with an empty Faiss index.")
102
- index = faiss.IndexFlatIP(self._dim)
103
- meta = {}
104
-
105
- _shared_indices[self.namespace] = index
106
- _shared_meta[self.namespace] = meta
107
- except Exception as e:
108
- logger.error(f"Failed to initialize Faiss index for namespace {self.namespace}: {e}")
109
- raise RuntimeError(f"Faiss index initialization failed: {e}")
110
-
111
- try:
112
- self._index = _shared_indices[self.namespace]
113
- self._id_to_meta = _shared_meta[self.namespace]
114
- except Exception as e:
115
- logger.error(f"Failed to access shared memory: {e}")
116
- raise RuntimeError(f"Cannot access shared memory: {e}")
117
 
118
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
119
  """
@@ -168,32 +122,36 @@ class FaissVectorDBStorage(BaseVectorStorage):
168
  # Normalize embeddings for cosine similarity (in-place)
169
  faiss.normalize_L2(embeddings)
170
 
171
- # Upsert logic:
172
- # 1. Identify which vectors to remove if they exist
173
- # 2. Remove them
174
- # 3. Add the new vectors
175
- existing_ids_to_remove = []
176
- for meta, emb in zip(list_data, embeddings):
177
- faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
178
- if faiss_internal_id is not None:
179
- existing_ids_to_remove.append(faiss_internal_id)
180
-
181
- if existing_ids_to_remove:
182
- self._remove_faiss_ids(existing_ids_to_remove)
183
-
184
- # Step 2: Add new vectors
185
- start_idx = self._index.ntotal
186
- self._index.add(embeddings)
187
-
188
- # Step 3: Store metadata + vector for each new ID
189
- for i, meta in enumerate(list_data):
190
- fid = start_idx + i
191
- # Store the raw vector so we can rebuild if something is removed
192
- meta["__vector__"] = embeddings[i].tolist()
193
- self._id_to_meta[fid] = meta
194
-
195
- logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
196
- return [m["__id__"] for m in list_data]
 
 
 
 
197
 
198
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
199
  """
@@ -209,54 +167,57 @@ class FaissVectorDBStorage(BaseVectorStorage):
209
  )
210
 
211
  # Perform the similarity search
212
- distances, indices = self._index.search(embedding, top_k)
213
-
214
- distances = distances[0]
215
- indices = indices[0]
216
-
217
- results = []
218
- for dist, idx in zip(distances, indices):
219
- if idx == -1:
220
- # Faiss returns -1 if no neighbor
221
- continue
222
-
223
- # Cosine similarity threshold
224
- if dist < self.cosine_better_than_threshold:
225
- continue
226
-
227
- meta = self._id_to_meta.get(idx, {})
228
- results.append(
229
- {
230
- **meta,
231
- "id": meta.get("__id__"),
232
- "distance": float(dist),
233
- "created_at": meta.get("__created_at__"),
234
- }
235
- )
236
-
237
- return results
 
238
 
239
  @property
240
  def client_storage(self):
241
  # Return whatever structure LightRAG might need for debugging
242
- return {"data": list(self._id_to_meta.values())}
 
243
 
244
  async def delete(self, ids: list[str]):
245
  """
246
  Delete vectors for the provided custom IDs.
247
  """
248
  logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
249
- to_remove = []
250
- for cid in ids:
251
- fid = self._find_faiss_id_by_custom_id(cid)
252
- if fid is not None:
253
- to_remove.append(fid)
254
-
255
- if to_remove:
256
- self._remove_faiss_ids(to_remove)
257
- logger.info(
258
- f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
259
- )
 
260
 
261
  async def delete_entity(self, entity_name: str) -> None:
262
  entity_id = compute_mdhash_id(entity_name, prefix="ent-")
@@ -268,18 +229,20 @@ class FaissVectorDBStorage(BaseVectorStorage):
268
  Delete relations for a given entity by scanning metadata.
269
  """
270
  logger.debug(f"Searching relations for entity {entity_name}")
271
- relations = []
272
- for fid, meta in self._id_to_meta.items():
273
- if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
274
- relations.append(fid)
 
275
 
276
- logger.debug(f"Found {len(relations)} relations for {entity_name}")
277
- if relations:
278
- self._remove_faiss_ids(relations)
279
- logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
280
 
281
  async def index_done_callback(self) -> None:
282
- self._save_faiss_index()
 
283
 
284
  # --------------------------------------------------------------------------------
285
  # Internal helper methods
@@ -289,10 +252,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
289
  """
290
  Return the Faiss internal ID for a given custom ID, or None if not found.
291
  """
292
- for fid, meta in self._id_to_meta.items():
293
- if meta.get("__id__") == custom_id:
294
- return fid
295
- return None
 
296
 
297
  def _remove_faiss_ids(self, fid_list):
298
  """
@@ -300,39 +264,45 @@ class FaissVectorDBStorage(BaseVectorStorage):
300
  Because IndexFlatIP doesn't support 'removals',
301
  we rebuild the index excluding those vectors.
302
  """
303
- keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
304
-
305
- # Rebuild the index
306
- vectors_to_keep = []
307
- new_id_to_meta = {}
308
- for new_fid, old_fid in enumerate(keep_fids):
309
- vec_meta = self._id_to_meta[old_fid]
310
- vectors_to_keep.append(vec_meta["__vector__"]) # stored as list
311
- new_id_to_meta[new_fid] = vec_meta
312
-
313
- # Re-init index
314
- self._index = faiss.IndexFlatIP(self._dim)
315
- if vectors_to_keep:
316
- arr = np.array(vectors_to_keep, dtype=np.float32)
317
- self._index.add(arr)
318
-
319
- self._id_to_meta = new_id_to_meta
 
 
 
 
 
320
 
321
  def _save_faiss_index(self):
322
  """
323
  Save the current Faiss index + metadata to disk so it can persist across runs.
324
  """
325
- faiss.write_index(self._index, self._faiss_index_file)
 
326
 
327
- # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
328
- # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
329
- # We'll keep the int -> dict, but JSON requires string keys.
330
- serializable_dict = {}
331
- for fid, meta in self._id_to_meta.items():
332
- serializable_dict[str(fid)] = meta
333
 
334
- with open(self._meta_file, "w", encoding="utf-8") as f:
335
- json.dump(serializable_dict, f)
336
 
337
  def _load_faiss_index(self):
338
  """
@@ -345,22 +315,31 @@ class FaissVectorDBStorage(BaseVectorStorage):
345
 
346
  try:
347
  # Load the Faiss index
348
- self._index = faiss.read_index(self._faiss_index_file)
 
 
 
 
 
349
  # Load metadata
350
  with open(self._meta_file, "r", encoding="utf-8") as f:
351
  stored_dict = json.load(f)
352
 
353
  # Convert string keys back to int
354
- self._id_to_meta = {}
355
  for fid_str, meta in stored_dict.items():
356
  fid = int(fid_str)
357
  self._id_to_meta[fid] = meta
358
 
359
  logger.info(
360
- f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
361
  )
362
  except Exception as e:
363
  logger.error(f"Failed to load Faiss index or metadata: {e}")
364
  logger.warning("Starting with an empty Faiss index.")
365
- self._index = faiss.IndexFlatIP(self._dim)
366
- self._id_to_meta = {}
 
 
 
 
 
2
  import time
3
  import asyncio
4
  from typing import Any, final
 
5
  import json
6
  import numpy as np
7
 
8
  from dataclasses import dataclass
9
  import pipmaster as pm
 
10
 
11
+ from lightrag.utils import logger,compute_mdhash_id
12
+ from lightrag.base import BaseVectorStorage
13
+ from .shared_storage import get_namespace_data, get_storage_lock, get_namespace_object, is_multiprocess
 
 
 
 
14
 
15
  if not pm.is_installed("faiss"):
16
  pm.install("faiss")
17
 
18
  import faiss # type: ignore
19
 
 
 
 
 
 
20
 
21
  @final
22
  @dataclass
 
45
  self._max_batch_size = self.global_config["embedding_batch_num"]
46
  # Embedding dimension (e.g. 768) must match your embedding function
47
  self._dim = self.embedding_func.embedding_dim
48
+ self._storage_lock = get_storage_lock()
49
 
50
+ self._index = get_namespace_object('faiss_indices')
51
+ self._id_to_meta = get_namespace_data('faiss_meta')
52
 
53
+ with self._storage_lock:
54
+ if is_multiprocess:
55
+ if self._index.value is None:
56
+ # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
57
+ # If you have a large number of vectors, you might want IVF or other indexes.
58
+ # For demonstration, we use a simple IndexFlatIP.
59
+ self._index.value = faiss.IndexFlatIP(self._dim)
60
+ else:
61
+ if self._index is None:
62
+ self._index = faiss.IndexFlatIP(self._dim)
63
+
64
+ # Keep a local store for metadata, IDs, etc.
65
+ # Maps <int faiss_id> → metadata (including your original ID).
66
+ self._id_to_meta.update({})
67
+
68
+ # Attempt to load an existing index + metadata from disk
69
+ self._load_faiss_index()
70
+
 
 
 
 
71
 
72
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
73
  """
 
122
  # Normalize embeddings for cosine similarity (in-place)
123
  faiss.normalize_L2(embeddings)
124
 
125
+ with self._storage_lock:
126
+ # Upsert logic:
127
+ # 1. Identify which vectors to remove if they exist
128
+ # 2. Remove them
129
+ # 3. Add the new vectors
130
+ existing_ids_to_remove = []
131
+ for meta, emb in zip(list_data, embeddings):
132
+ faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
133
+ if faiss_internal_id is not None:
134
+ existing_ids_to_remove.append(faiss_internal_id)
135
+
136
+ if existing_ids_to_remove:
137
+ self._remove_faiss_ids(existing_ids_to_remove)
138
+
139
+ # Step 2: Add new vectors
140
+ start_idx = (self._index.value if is_multiprocess else self._index).ntotal
141
+ if is_multiprocess:
142
+ self._index.value.add(embeddings)
143
+ else:
144
+ self._index.add(embeddings)
145
+
146
+ # Step 3: Store metadata + vector for each new ID
147
+ for i, meta in enumerate(list_data):
148
+ fid = start_idx + i
149
+ # Store the raw vector so we can rebuild if something is removed
150
+ meta["__vector__"] = embeddings[i].tolist()
151
+ self._id_to_meta.update({fid: meta})
152
+
153
+ logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
154
+ return [m["__id__"] for m in list_data]
155
 
156
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
157
  """
 
167
  )
168
 
169
  # Perform the similarity search
170
+ with self._storage_lock:
171
+ distances, indices = (self._index.value if is_multiprocess else self._index).search(embedding, top_k)
172
+
173
+ distances = distances[0]
174
+ indices = indices[0]
175
+
176
+ results = []
177
+ for dist, idx in zip(distances, indices):
178
+ if idx == -1:
179
+ # Faiss returns -1 if no neighbor
180
+ continue
181
+
182
+ # Cosine similarity threshold
183
+ if dist < self.cosine_better_than_threshold:
184
+ continue
185
+
186
+ meta = self._id_to_meta.get(idx, {})
187
+ results.append(
188
+ {
189
+ **meta,
190
+ "id": meta.get("__id__"),
191
+ "distance": float(dist),
192
+ "created_at": meta.get("__created_at__"),
193
+ }
194
+ )
195
+
196
+ return results
197
 
198
  @property
199
  def client_storage(self):
200
  # Return whatever structure LightRAG might need for debugging
201
+ with self._storage_lock:
202
+ return {"data": list(self._id_to_meta.values())}
203
 
204
  async def delete(self, ids: list[str]):
205
  """
206
  Delete vectors for the provided custom IDs.
207
  """
208
  logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
209
+ with self._storage_lock:
210
+ to_remove = []
211
+ for cid in ids:
212
+ fid = self._find_faiss_id_by_custom_id(cid)
213
+ if fid is not None:
214
+ to_remove.append(fid)
215
+
216
+ if to_remove:
217
+ self._remove_faiss_ids(to_remove)
218
+ logger.debug(
219
+ f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
220
+ )
221
 
222
  async def delete_entity(self, entity_name: str) -> None:
223
  entity_id = compute_mdhash_id(entity_name, prefix="ent-")
 
229
  Delete relations for a given entity by scanning metadata.
230
  """
231
  logger.debug(f"Searching relations for entity {entity_name}")
232
+ with self._storage_lock:
233
+ relations = []
234
+ for fid, meta in self._id_to_meta.items():
235
+ if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
236
+ relations.append(fid)
237
 
238
+ logger.debug(f"Found {len(relations)} relations for {entity_name}")
239
+ if relations:
240
+ self._remove_faiss_ids(relations)
241
+ logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
242
 
243
  async def index_done_callback(self) -> None:
244
+ with self._storage_lock:
245
+ self._save_faiss_index()
246
 
247
  # --------------------------------------------------------------------------------
248
  # Internal helper methods
 
252
  """
253
  Return the Faiss internal ID for a given custom ID, or None if not found.
254
  """
255
+ with self._storage_lock:
256
+ for fid, meta in self._id_to_meta.items():
257
+ if meta.get("__id__") == custom_id:
258
+ return fid
259
+ return None
260
 
261
  def _remove_faiss_ids(self, fid_list):
262
  """
 
264
  Because IndexFlatIP doesn't support 'removals',
265
  we rebuild the index excluding those vectors.
266
  """
267
+ with self._storage_lock:
268
+ keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
269
+
270
+ # Rebuild the index
271
+ vectors_to_keep = []
272
+ new_id_to_meta = {}
273
+ for new_fid, old_fid in enumerate(keep_fids):
274
+ vec_meta = self._id_to_meta[old_fid]
275
+ vectors_to_keep.append(vec_meta["__vector__"]) # stored as list
276
+ new_id_to_meta[new_fid] = vec_meta
277
+
278
+ # Re-init index
279
+ new_index = faiss.IndexFlatIP(self._dim)
280
+ if vectors_to_keep:
281
+ arr = np.array(vectors_to_keep, dtype=np.float32)
282
+ new_index.add(arr)
283
+ if is_multiprocess:
284
+ self._index.value = new_index
285
+ else:
286
+ self._index = new_index
287
+
288
+ self._id_to_meta.update(new_id_to_meta)
289
 
290
  def _save_faiss_index(self):
291
  """
292
  Save the current Faiss index + metadata to disk so it can persist across runs.
293
  """
294
+ with self._storage_lock:
295
+ faiss.write_index(self._index.value if is_multiprocess else self._index, self._faiss_index_file)
296
 
297
+ # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
298
+ # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
299
+ # We'll keep the int -> dict, but JSON requires string keys.
300
+ serializable_dict = {}
301
+ for fid, meta in self._id_to_meta.items():
302
+ serializable_dict[str(fid)] = meta
303
 
304
+ with open(self._meta_file, "w", encoding="utf-8") as f:
305
+ json.dump(serializable_dict, f)
306
 
307
  def _load_faiss_index(self):
308
  """
 
315
 
316
  try:
317
  # Load the Faiss index
318
+ loaded_index = faiss.read_index(self._faiss_index_file)
319
+ if is_multiprocess:
320
+ self._index.value = loaded_index
321
+ else:
322
+ self._index = loaded_index
323
+
324
  # Load metadata
325
  with open(self._meta_file, "r", encoding="utf-8") as f:
326
  stored_dict = json.load(f)
327
 
328
  # Convert string keys back to int
329
+ self._id_to_meta.update({})
330
  for fid_str, meta in stored_dict.items():
331
  fid = int(fid_str)
332
  self._id_to_meta[fid] = meta
333
 
334
  logger.info(
335
+ f"Faiss index loaded with {loaded_index.ntotal} vectors from {self._faiss_index_file}"
336
  )
337
  except Exception as e:
338
  logger.error(f"Failed to load Faiss index or metadata: {e}")
339
  logger.warning("Starting with an empty Faiss index.")
340
+ new_index = faiss.IndexFlatIP(self._dim)
341
+ if is_multiprocess:
342
+ self._index.value = new_index
343
+ else:
344
+ self._index = new_index
345
+ self._id_to_meta.update({})
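Several methods above repeat the expression self._index.value if is_multiprocess else self._index. The NanoVectorDB storage later in this same commit factors the equivalent check into a _get_client() helper; a matching helper for the Faiss storage (a suggestion, not part of the diff) could be:

# Suggested helper for FaissVectorDBStorage, mirroring NanoVectorDBStorage._get_client();
# not part of this commit.
def _get_index(self):
    """Return the live Faiss index in either single- or multi-process mode."""
    return self._index.value if is_multiprocess else self._index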
lightrag/kg/json_doc_status_impl.py CHANGED
@@ -1,7 +1,6 @@
1
  from dataclasses import dataclass
2
  import os
3
  from typing import Any, Union, final
4
- import threading
5
 
6
  from lightrag.base import (
7
  DocProcessingStatus,
@@ -13,26 +12,7 @@ from lightrag.utils import (
13
  logger,
14
  write_json,
15
  )
16
- from lightrag.api.utils_api import manager as main_process_manager
17
-
18
- # Global variables for shared memory management
19
- _init_lock = threading.Lock()
20
- _manager = None
21
- _shared_doc_status_data = None
22
-
23
-
24
- def _get_manager():
25
- """Get or create the global manager instance"""
26
- global _manager, _shared_doc_status_data
27
- with _init_lock:
28
- if _manager is None:
29
- try:
30
- _manager = main_process_manager
31
- _shared_doc_status_data = _manager.dict()
32
- except Exception as e:
33
- logger.error(f"Failed to initialize shared memory manager: {e}")
34
- raise RuntimeError(f"Shared memory initialization failed: {e}")
35
- return _manager
36
 
37
 
38
  @final
@@ -43,45 +23,32 @@ class JsonDocStatusStorage(DocStatusStorage):
43
  def __post_init__(self):
44
  working_dir = self.global_config["working_dir"]
45
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
46
-
47
- # Ensure manager is initialized
48
- _get_manager()
49
-
50
- # Get or create namespace data
51
- if self.namespace not in _shared_doc_status_data:
52
- with _init_lock:
53
- if self.namespace not in _shared_doc_status_data:
54
- try:
55
- initial_data = load_json(self._file_name) or {}
56
- _shared_doc_status_data[self.namespace] = initial_data
57
- except Exception as e:
58
- logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}")
59
- raise RuntimeError(f"Shared data initialization failed: {e}")
60
-
61
- try:
62
- self._data = _shared_doc_status_data[self.namespace]
63
- logger.info(f"Loaded document status storage with {len(self._data)} records")
64
- except Exception as e:
65
- logger.error(f"Failed to access shared memory: {e}")
66
- raise RuntimeError(f"Cannot access shared memory: {e}")
67
 
68
  async def filter_keys(self, keys: set[str]) -> set[str]:
69
  """Return keys that should be processed (not in storage or not successfully processed)"""
70
- return set(keys) - set(self._data.keys())
 
71
 
72
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
73
  result: list[dict[str, Any]] = []
74
- for id in ids:
75
- data = self._data.get(id, None)
76
- if data:
77
- result.append(data)
 
78
  return result
79
 
80
  async def get_status_counts(self) -> dict[str, int]:
81
  """Get counts of documents in each status"""
82
  counts = {status.value: 0 for status in DocStatus}
83
- for doc in self._data.values():
84
- counts[doc["status"]] += 1
 
85
  return counts
86
 
87
  async def get_docs_by_status(
@@ -89,39 +56,46 @@ class JsonDocStatusStorage(DocStatusStorage):
89
  ) -> dict[str, DocProcessingStatus]:
90
  """Get all documents with a specific status"""
91
  result = {}
92
- for k, v in self._data.items():
93
- if v["status"] == status.value:
94
- try:
95
- # Make a copy of the data to avoid modifying the original
96
- data = v.copy()
97
- # If content is missing, use content_summary as content
98
- if "content" not in data and "content_summary" in data:
99
- data["content"] = data["content_summary"]
100
- result[k] = DocProcessingStatus(**data)
101
- except KeyError as e:
102
- logger.error(f"Missing required field for document {k}: {e}")
103
- continue
 
104
  return result
105
 
106
  async def index_done_callback(self) -> None:
107
- write_json(self._data, self._file_name)
 
 
108
 
109
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
110
  logger.info(f"Inserting {len(data)} to {self.namespace}")
111
  if not data:
112
  return
113
 
114
- self._data.update(data)
 
115
  await self.index_done_callback()
116
 
117
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
118
- return self._data.get(id)
 
119
 
120
  async def delete(self, doc_ids: list[str]):
121
- for doc_id in doc_ids:
122
- self._data.pop(doc_id, None)
 
123
  await self.index_done_callback()
124
 
125
  async def drop(self) -> None:
126
  """Drop the storage"""
127
- self._data.clear()
 
 
1
  from dataclasses import dataclass
2
  import os
3
  from typing import Any, Union, final
 
4
 
5
  from lightrag.base import (
6
  DocProcessingStatus,
 
12
  logger,
13
  write_json,
14
  )
15
+ from .shared_storage import get_namespace_data, get_storage_lock
 
 
 
16
 
17
 
18
  @final
 
23
  def __post_init__(self):
24
  working_dir = self.global_config["working_dir"]
25
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
26
+ self._storage_lock = get_storage_lock()
27
+ self._data = get_namespace_data(self.namespace)
28
+ with self._storage_lock:
29
+ self._data.update(load_json(self._file_name) or {})
30
+ logger.info(f"Loaded document status storage with {len(self._data)} records")
 
 
31
 
32
  async def filter_keys(self, keys: set[str]) -> set[str]:
33
  """Return keys that should be processed (not in storage or not successfully processed)"""
34
+ with self._storage_lock:
35
+ return set(keys) - set(self._data.keys())
36
 
37
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
38
  result: list[dict[str, Any]] = []
39
+ with self._storage_lock:
40
+ for id in ids:
41
+ data = self._data.get(id, None)
42
+ if data:
43
+ result.append(data)
44
  return result
45
 
46
  async def get_status_counts(self) -> dict[str, int]:
47
  """Get counts of documents in each status"""
48
  counts = {status.value: 0 for status in DocStatus}
49
+ with self._storage_lock:
50
+ for doc in self._data.values():
51
+ counts[doc["status"]] += 1
52
  return counts
53
 
54
  async def get_docs_by_status(
 
56
  ) -> dict[str, DocProcessingStatus]:
57
  """Get all documents with a specific status"""
58
  result = {}
59
+ with self._storage_lock:
60
+ for k, v in self._data.items():
61
+ if v["status"] == status.value:
62
+ try:
63
+ # Make a copy of the data to avoid modifying the original
64
+ data = v.copy()
65
+ # If content is missing, use content_summary as content
66
+ if "content" not in data and "content_summary" in data:
67
+ data["content"] = data["content_summary"]
68
+ result[k] = DocProcessingStatus(**data)
69
+ except KeyError as e:
70
+ logger.error(f"Missing required field for document {k}: {e}")
71
+ continue
72
  return result
73
 
74
  async def index_done_callback(self) -> None:
75
+ # File writes must hold the lock so that concurrent writes from multiple processes cannot corrupt the file
76
+ with self._storage_lock:
77
+ write_json(self._data, self._file_name)
78
 
79
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
80
  logger.info(f"Inserting {len(data)} to {self.namespace}")
81
  if not data:
82
  return
83
 
84
+ with self._storage_lock:
85
+ self._data.update(data)
86
  await self.index_done_callback()
87
 
88
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
89
+ with self._storage_lock:
90
+ return self._data.get(id)
91
 
92
  async def delete(self, doc_ids: list[str]):
93
+ with self._storage_lock:
94
+ for doc_id in doc_ids:
95
+ self._data.pop(doc_id, None)
96
  await self.index_done_callback()
97
 
98
  async def drop(self) -> None:
99
  """Drop the storage"""
100
+ with self._storage_lock:
101
+ self._data.clear()
lightrag/kg/json_kv_impl.py CHANGED
@@ -1,8 +1,6 @@
1
- import asyncio
2
  import os
3
  from dataclasses import dataclass
4
  from typing import Any, final
5
- import threading
6
 
7
  from lightrag.base import (
8
  BaseKVStorage,
@@ -12,26 +10,7 @@ from lightrag.utils import (
12
  logger,
13
  write_json,
14
  )
15
- from lightrag.api.utils_api import manager as main_process_manager
16
-
17
- # Global variables for shared memory management
18
- _init_lock = threading.Lock()
19
- _manager = None
20
- _shared_kv_data = None
21
-
22
-
23
- def _get_manager():
24
- """Get or create the global manager instance"""
25
- global _manager, _shared_kv_data
26
- with _init_lock:
27
- if _manager is None:
28
- try:
29
- _manager = main_process_manager
30
- _shared_kv_data = _manager.dict()
31
- except Exception as e:
32
- logger.error(f"Failed to initialize shared memory manager: {e}")
33
- raise RuntimeError(f"Shared memory initialization failed: {e}")
34
- return _manager
35
 
36
 
37
  @final
@@ -39,57 +18,49 @@ def _get_manager():
39
  class JsonKVStorage(BaseKVStorage):
40
  def __post_init__(self):
41
  working_dir = self.global_config["working_dir"]
42
- self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
43
- self._lock = asyncio.Lock()
44
-
45
- # Ensure manager is initialized
46
- _get_manager()
47
-
48
- # Get or create namespace data
49
- if self.namespace not in _shared_kv_data:
50
- with _init_lock:
51
- if self.namespace not in _shared_kv_data:
52
- try:
53
- initial_data = load_json(self._file_name) or {}
54
- _shared_kv_data[self.namespace] = initial_data
55
- except Exception as e:
56
- logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}")
57
- raise RuntimeError(f"Shared data initialization failed: {e}")
58
-
59
- try:
60
- self._data = _shared_kv_data[self.namespace]
61
- logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
62
- except Exception as e:
63
- logger.error(f"Failed to access shared memory: {e}")
64
- raise RuntimeError(f"Cannot access shared memory: {e}")
65
 
66
  async def index_done_callback(self) -> None:
67
- write_json(self._data, self._file_name)
 
 
68
 
69
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
70
- return self._data.get(id)
 
71
 
72
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
73
- return [
74
- (
75
- {k: v for k, v in self._data[id].items()}
76
- if self._data.get(id, None)
77
- else None
78
- )
79
- for id in ids
80
- ]
 
81
 
82
  async def filter_keys(self, keys: set[str]) -> set[str]:
83
- return set(keys) - set(self._data.keys())
 
84
 
85
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
86
  logger.info(f"Inserting {len(data)} to {self.namespace}")
87
  if not data:
88
  return
89
- left_data = {k: v for k, v in data.items() if k not in self._data}
90
- self._data.update(left_data)
 
91
 
92
  async def delete(self, ids: list[str]) -> None:
93
- for doc_id in ids:
94
- self._data.pop(doc_id, None)
 
95
  await self.index_done_callback()
 
 
1
  import os
2
  from dataclasses import dataclass
3
  from typing import Any, final
 
4
 
5
  from lightrag.base import (
6
  BaseKVStorage,
 
10
  logger,
11
  write_json,
12
  )
13
+ from .shared_storage import get_namespace_data, get_storage_lock
 
 
14
 
15
 
16
  @final
 
18
  class JsonKVStorage(BaseKVStorage):
19
  def __post_init__(self):
20
  working_dir = self.global_config["working_dir"]
21
+ self._storage_lock = get_storage_lock()
22
+ self._data = get_namespace_data(self.namespace)
23
+ with self._storage_lock:
24
+ if not self._data:
25
+ self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
26
+ self._data: dict[str, Any] = load_json(self._file_name) or {}
27
+ logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
28
+
 
 
 
29
 
30
  async def index_done_callback(self) -> None:
31
+ # File writes must hold the lock so that concurrent writes from multiple processes cannot corrupt the file
32
+ with self._storage_lock:
33
+ write_json(self._data, self._file_name)
34
 
35
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
36
+ with self._storage_lock:
37
+ return self._data.get(id)
38
 
39
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
40
+ with self._storage_lock:
41
+ return [
42
+ (
43
+ {k: v for k, v in self._data[id].items()}
44
+ if self._data.get(id, None)
45
+ else None
46
+ )
47
+ for id in ids
48
+ ]
49
 
50
  async def filter_keys(self, keys: set[str]) -> set[str]:
51
+ with self._storage_lock:
52
+ return set(keys) - set(self._data.keys())
53
 
54
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
55
  logger.info(f"Inserting {len(data)} to {self.namespace}")
56
  if not data:
57
  return
58
+ with self._storage_lock:
59
+ left_data = {k: v for k, v in data.items() if k not in self._data}
60
+ self._data.update(left_data)
61
 
62
  async def delete(self, ids: list[str]) -> None:
63
+ with self._storage_lock:
64
+ for doc_id in ids:
65
+ self._data.pop(doc_id, None)
66
  await self.index_done_callback()
lightrag/kg/nano_vector_db_impl.py CHANGED
@@ -3,50 +3,29 @@ import os
3
  from typing import Any, final
4
  from dataclasses import dataclass
5
  import numpy as np
6
- import threading
7
  import time
8
 
9
  from lightrag.utils import (
10
  logger,
11
  compute_mdhash_id,
12
  )
13
- from lightrag.api.utils_api import manager as main_process_manager
14
  import pipmaster as pm
15
- from lightrag.base import (
16
- BaseVectorStorage,
17
- )
18
 
19
  if not pm.is_installed("nano-vectordb"):
20
  pm.install("nano-vectordb")
21
 
22
  from nano_vectordb import NanoVectorDB
23
 
24
- # Global variables for shared memory management
25
- _init_lock = threading.Lock()
26
- _manager = None
27
- _shared_vector_clients = None
28
-
29
-
30
- def _get_manager():
31
- """Get or create the global manager instance"""
32
- global _manager, _shared_vector_clients
33
- with _init_lock:
34
- if _manager is None:
35
- try:
36
- _manager = main_process_manager
37
- _shared_vector_clients = _manager.dict()
38
- except Exception as e:
39
- logger.error(f"Failed to initialize shared memory manager: {e}")
40
- raise RuntimeError(f"Shared memory initialization failed: {e}")
41
- return _manager
42
-
43
 
44
  @final
45
  @dataclass
46
  class NanoVectorDBStorage(BaseVectorStorage):
47
  def __post_init__(self):
48
  # Initialize lock only for file operations
49
- self._save_lock = asyncio.Lock()
 
50
  # Use global config value if specified, otherwise use default
51
  kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
52
  cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -61,28 +40,27 @@ class NanoVectorDBStorage(BaseVectorStorage):
61
  )
62
  self._max_batch_size = self.global_config["embedding_batch_num"]
63
 
64
- # Ensure manager is initialized
65
- _get_manager()
66
 
67
- # Get or create namespace client
68
- if self.namespace not in _shared_vector_clients:
69
- with _init_lock:
70
- if self.namespace not in _shared_vector_clients:
71
- try:
72
- client = NanoVectorDB(
73
- self.embedding_func.embedding_dim,
74
- storage_file=self._client_file_name
75
- )
76
- _shared_vector_clients[self.namespace] = client
77
- except Exception as e:
78
- logger.error(f"Failed to initialize vector DB client for namespace {self.namespace}: {e}")
79
- raise RuntimeError(f"Vector DB client initialization failed: {e}")
80
 
81
- try:
82
- self._client = _shared_vector_clients[self.namespace]
83
- except Exception as e:
84
- logger.error(f"Failed to access shared memory: {e}")
85
- raise RuntimeError(f"Cannot access shared memory: {e}")
 
 
86
 
87
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
88
  logger.info(f"Inserting {len(data)} to {self.namespace}")
@@ -104,6 +82,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
104
  for i in range(0, len(contents), self._max_batch_size)
105
  ]
106
 
 
107
  embedding_tasks = [self.embedding_func(batch) for batch in batches]
108
  embeddings_list = await asyncio.gather(*embedding_tasks)
109
 
@@ -111,7 +90,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
111
  if len(embeddings) == len(list_data):
112
  for i, d in enumerate(list_data):
113
  d["__vector__"] = embeddings[i]
114
- results = self._client.upsert(datas=list_data)
 
 
115
  return results
116
  else:
117
  # sometimes the embedding is not returned correctly. just log it.
@@ -120,27 +101,32 @@ class NanoVectorDBStorage(BaseVectorStorage):
120
  )
121
 
122
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
 
123
  embedding = await self.embedding_func([query])
124
  embedding = embedding[0]
125
- results = self._client.query(
126
- query=embedding,
127
- top_k=top_k,
128
- better_than_threshold=self.cosine_better_than_threshold,
129
- )
130
- results = [
131
- {
132
- **dp,
133
- "id": dp["__id__"],
134
- "distance": dp["__metrics__"],
135
- "created_at": dp.get("__created_at__"),
136
- }
137
- for dp in results
138
- ]
 
 
 
139
  return results
140
 
141
  @property
142
  def client_storage(self):
143
- return getattr(self._client, "_NanoVectorDB__storage")
 
144
 
145
  async def delete(self, ids: list[str]):
146
  """Delete vectors with specified IDs
@@ -149,8 +135,10 @@ class NanoVectorDBStorage(BaseVectorStorage):
149
  ids: List of vector IDs to be deleted
150
  """
151
  try:
152
- self._client.delete(ids)
153
- logger.info(
 
 
154
  f"Successfully deleted {len(ids)} vectors from {self.namespace}"
155
  )
156
  except Exception as e:
@@ -162,35 +150,42 @@ class NanoVectorDBStorage(BaseVectorStorage):
162
  logger.debug(
163
  f"Attempting to delete entity {entity_name} with ID {entity_id}"
164
  )
165
- # Check if the entity exists
166
- if self._client.get([entity_id]):
167
- await self.delete([entity_id])
168
- logger.debug(f"Successfully deleted entity {entity_name}")
169
- else:
170
- logger.debug(f"Entity {entity_name} not found in storage")
 
 
 
171
  except Exception as e:
172
  logger.error(f"Error deleting entity {entity_name}: {e}")
173
 
174
  async def delete_entity_relation(self, entity_name: str) -> None:
175
  try:
176
- relations = [
177
- dp
178
- for dp in self.client_storage["data"]
179
- if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
180
- ]
181
- logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
182
- ids_to_delete = [relation["__id__"] for relation in relations]
183
-
184
- if ids_to_delete:
185
- await self.delete(ids_to_delete)
186
- logger.debug(
187
- f"Deleted {len(ids_to_delete)} relations for {entity_name}"
188
- )
189
- else:
190
- logger.debug(f"No relations found for entity {entity_name}")
 
 
 
191
  except Exception as e:
192
  logger.error(f"Error deleting relations for {entity_name}: {e}")
193
 
194
  async def index_done_callback(self) -> None:
195
- async with self._save_lock:
196
- self._client.save()
 
 
3
  from typing import Any, final
4
  from dataclasses import dataclass
5
  import numpy as np
 
6
  import time
7
 
8
  from lightrag.utils import (
9
  logger,
10
  compute_mdhash_id,
11
  )
 
12
  import pipmaster as pm
13
+ from lightrag.base import BaseVectorStorage
14
+ from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess
 
15
 
16
  if not pm.is_installed("nano-vectordb"):
17
  pm.install("nano-vectordb")
18
 
19
  from nano_vectordb import NanoVectorDB
20
 
 
21
 
22
  @final
23
  @dataclass
24
  class NanoVectorDBStorage(BaseVectorStorage):
25
  def __post_init__(self):
26
  # Initialize lock only for file operations
27
+ self._storage_lock = get_storage_lock()
28
+
29
  # Use global config value if specified, otherwise use default
30
  kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
31
  cosine_threshold = kwargs.get("cosine_better_than_threshold")
 
40
  )
41
  self._max_batch_size = self.global_config["embedding_batch_num"]
42
 
43
+ self._client = get_namespace_object(self.namespace)
 
44
 
45
+ with self._storage_lock:
46
+ if is_multiprocess:
47
+ if self._client.value is None:
48
+ self._client.value = NanoVectorDB(
49
+ self.embedding_func.embedding_dim, storage_file=self._client_file_name
50
+ )
51
+ else:
52
+ if self._client is None:
53
+ self._client = NanoVectorDB(
54
+ self.embedding_func.embedding_dim, storage_file=self._client_file_name
55
+ )
 
 
56
 
57
+ logger.info(f"Initialized vector DB client for namespace {self.namespace}")
58
+
59
+ def _get_client(self):
60
+ """Get the appropriate client instance based on multiprocess mode"""
61
+ if is_multiprocess:
62
+ return self._client.value
63
+ return self._client
64
 
65
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
66
  logger.info(f"Inserting {len(data)} to {self.namespace}")
 
82
  for i in range(0, len(contents), self._max_batch_size)
83
  ]
84
 
85
+ # Execute embedding outside of lock to avoid long lock times
86
  embedding_tasks = [self.embedding_func(batch) for batch in batches]
87
  embeddings_list = await asyncio.gather(*embedding_tasks)
88
 
 
90
  if len(embeddings) == len(list_data):
91
  for i, d in enumerate(list_data):
92
  d["__vector__"] = embeddings[i]
93
+ with self._storage_lock:
94
+ client = self._get_client()
95
+ results = client.upsert(datas=list_data)
96
  return results
97
  else:
98
  # sometimes the embedding is not returned correctly. just log it.
 
101
  )
102
 
103
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
104
+ # Execute embedding outside of lock to avoid long lock times
105
  embedding = await self.embedding_func([query])
106
  embedding = embedding[0]
107
+
108
+ with self._storage_lock:
109
+ client = self._get_client()
110
+ results = client.query(
111
+ query=embedding,
112
+ top_k=top_k,
113
+ better_than_threshold=self.cosine_better_than_threshold,
114
+ )
115
+ results = [
116
+ {
117
+ **dp,
118
+ "id": dp["__id__"],
119
+ "distance": dp["__metrics__"],
120
+ "created_at": dp.get("__created_at__"),
121
+ }
122
+ for dp in results
123
+ ]
124
  return results
125
 
126
  @property
127
  def client_storage(self):
128
+ client = self._get_client()
129
+ return getattr(client, "_NanoVectorDB__storage")
130
 
131
  async def delete(self, ids: list[str]):
132
  """Delete vectors with specified IDs
 
135
  ids: List of vector IDs to be deleted
136
  """
137
  try:
138
+ with self._storage_lock:
139
+ client = self._get_client()
140
+ client.delete(ids)
141
+ logger.debug(
142
  f"Successfully deleted {len(ids)} vectors from {self.namespace}"
143
  )
144
  except Exception as e:
 
150
  logger.debug(
151
  f"Attempting to delete entity {entity_name} with ID {entity_id}"
152
  )
153
+
154
+ with self._storage_lock:
155
+ client = self._get_client()
156
+ # Check if the entity exists
157
+ if client.get([entity_id]):
158
+ client.delete([entity_id])
159
+ logger.debug(f"Successfully deleted entity {entity_name}")
160
+ else:
161
+ logger.debug(f"Entity {entity_name} not found in storage")
162
  except Exception as e:
163
  logger.error(f"Error deleting entity {entity_name}: {e}")
164
 
165
  async def delete_entity_relation(self, entity_name: str) -> None:
166
  try:
167
+ with self._storage_lock:
168
+ client = self._get_client()
169
+ storage = getattr(client, "_NanoVectorDB__storage")
170
+ relations = [
171
+ dp
172
+ for dp in storage["data"]
173
+ if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
174
+ ]
175
+ logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
176
+ ids_to_delete = [relation["__id__"] for relation in relations]
177
+
178
+ if ids_to_delete:
179
+ client.delete(ids_to_delete)
180
+ logger.debug(
181
+ f"Deleted {len(ids_to_delete)} relations for {entity_name}"
182
+ )
183
+ else:
184
+ logger.debug(f"No relations found for entity {entity_name}")
185
  except Exception as e:
186
  logger.error(f"Error deleting relations for {entity_name}: {e}")
187
 
188
  async def index_done_callback(self) -> None:
189
+ with self._storage_lock:
190
+ client = self._get_client()
191
+ client.save()
lightrag/kg/networkx_impl.py CHANGED
@@ -1,18 +1,13 @@
1
  import os
2
  from dataclasses import dataclass
3
  from typing import Any, final
4
- import threading
5
  import numpy as np
6
 
7
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
8
- from lightrag.utils import (
9
- logger,
10
- )
11
- from lightrag.api.utils_api import manager as main_process_manager
12
-
13
- from lightrag.base import (
14
- BaseGraphStorage,
15
- )
16
  import pipmaster as pm
17
 
18
  if not pm.is_installed("networkx"):
@@ -24,25 +19,6 @@ if not pm.is_installed("graspologic"):
24
  import networkx as nx
25
  from graspologic import embed
26
 
27
- # Global variables for shared memory management
28
- _init_lock = threading.Lock()
29
- _manager = None
30
- _shared_graphs = None
31
-
32
-
33
- def _get_manager():
34
- """Get or create the global manager instance"""
35
- global _manager, _shared_graphs
36
- with _init_lock:
37
- if _manager is None:
38
- try:
39
- _manager = main_process_manager
40
- _shared_graphs = _manager.dict()
41
- except Exception as e:
42
- logger.error(f"Failed to initialize shared memory manager: {e}")
43
- raise RuntimeError(f"Shared memory initialization failed: {e}")
44
- return _manager
45
-
46
 
47
  @final
48
  @dataclass
@@ -97,76 +73,98 @@ class NetworkXStorage(BaseGraphStorage):
97
  self._graphml_xml_file = os.path.join(
98
  self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
99
  )
100
-
101
- # Ensure manager is initialized
102
- _get_manager()
103
-
104
- # Get or create namespace graph
105
- if self.namespace not in _shared_graphs:
106
- with _init_lock:
107
- if self.namespace not in _shared_graphs:
108
- try:
109
- preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
110
- if preloaded_graph is not None:
111
- logger.info(
112
- f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
113
- )
114
- _shared_graphs[self.namespace] = preloaded_graph or nx.Graph()
115
- except Exception as e:
116
- logger.error(f"Failed to initialize graph for namespace {self.namespace}: {e}")
117
- raise RuntimeError(f"Graph initialization failed: {e}")
118
-
119
- try:
120
- self._graph = _shared_graphs[self.namespace]
121
- self._node_embed_algorithms = {
122
  "node2vec": self._node2vec_embed,
123
- }
124
- except Exception as e:
125
- logger.error(f"Failed to access shared memory: {e}")
126
- raise RuntimeError(f"Cannot access shared memory: {e}")
 
 
 
127
 
128
  async def index_done_callback(self) -> None:
129
- NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
 
 
130
 
131
  async def has_node(self, node_id: str) -> bool:
132
- return self._graph.has_node(node_id)
 
 
133
 
134
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
135
- return self._graph.has_edge(source_node_id, target_node_id)
 
 
136
 
137
  async def get_node(self, node_id: str) -> dict[str, str] | None:
138
- return self._graph.nodes.get(node_id)
 
 
139
 
140
  async def node_degree(self, node_id: str) -> int:
141
- return self._graph.degree(node_id)
 
 
142
 
143
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
144
- return self._graph.degree(src_id) + self._graph.degree(tgt_id)
 
 
145
 
146
  async def get_edge(
147
  self, source_node_id: str, target_node_id: str
148
  ) -> dict[str, str] | None:
149
- return self._graph.edges.get((source_node_id, target_node_id))
 
 
150
 
151
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
152
- if self._graph.has_node(source_node_id):
153
- return list(self._graph.edges(source_node_id))
154
- return None
 
 
155
 
156
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
157
- self._graph.add_node(node_id, **node_data)
 
 
158
 
159
  async def upsert_edge(
160
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
161
  ) -> None:
162
- self._graph.add_edge(source_node_id, target_node_id, **edge_data)
 
 
163
 
164
  async def delete_node(self, node_id: str) -> None:
165
- if self._graph.has_node(node_id):
166
- self._graph.remove_node(node_id)
167
- logger.info(f"Node {node_id} deleted from the graph.")
168
- else:
169
- logger.warning(f"Node {node_id} not found in the graph for deletion.")
 
 
170
 
171
  async def embed_nodes(
172
  self, algorithm: str
@@ -175,14 +173,15 @@ class NetworkXStorage(BaseGraphStorage):
175
  raise ValueError(f"Node embedding algorithm {algorithm} not supported")
176
  return await self._node_embed_algorithms[algorithm]()
177
 
178
- # @TODO: NOT USED
179
  async def _node2vec_embed(self):
180
- embeddings, nodes = embed.node2vec_embed(
181
- self._graph,
182
- **self.global_config["node2vec_params"],
183
- )
184
-
185
- nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
 
186
  return embeddings, nodes_ids
187
 
188
  def remove_nodes(self, nodes: list[str]):
@@ -191,9 +190,11 @@ class NetworkXStorage(BaseGraphStorage):
191
  Args:
192
  nodes: List of node IDs to be deleted
193
  """
194
- for node in nodes:
195
- if self._graph.has_node(node):
196
- self._graph.remove_node(node)
 
 
197
 
198
  def remove_edges(self, edges: list[tuple[str, str]]):
199
  """Delete multiple edges
@@ -201,9 +202,11 @@ class NetworkXStorage(BaseGraphStorage):
201
  Args:
202
  edges: List of edges to be deleted, each edge is a (source, target) tuple
203
  """
204
- for source, target in edges:
205
- if self._graph.has_edge(source, target):
206
- self._graph.remove_edge(source, target)
 
 
207
 
208
  async def get_all_labels(self) -> list[str]:
209
  """
@@ -211,9 +214,11 @@ class NetworkXStorage(BaseGraphStorage):
211
  Returns:
212
  [label1, label2, ...] # Alphabetically sorted label list
213
  """
214
- labels = set()
215
- for node in self._graph.nodes():
216
- labels.add(str(node)) # Add node id as a label
 
 
217
 
218
  # Return sorted list
219
  return sorted(list(labels))
@@ -235,87 +240,86 @@ class NetworkXStorage(BaseGraphStorage):
235
  seen_nodes = set()
236
  seen_edges = set()
237
 
238
- # Handle special case for "*" label
239
- if node_label == "*":
240
- # For "*", return the entire graph including all nodes and edges
241
- subgraph = (
242
- self._graph.copy()
243
- ) # Create a copy to avoid modifying the original graph
244
- else:
245
- # Find nodes with matching node id (partial match)
246
- nodes_to_explore = []
247
- for n, attr in self._graph.nodes(data=True):
248
- if node_label in str(n): # Use partial matching
249
- nodes_to_explore.append(n)
250
-
251
- if not nodes_to_explore:
252
- logger.warning(f"No nodes found with label {node_label}")
253
- return result
254
-
255
- # Get subgraph using ego_graph
256
- subgraph = nx.ego_graph(self._graph, nodes_to_explore[0], radius=max_depth)
257
-
258
- # Check if number of nodes exceeds max_graph_nodes
259
- max_graph_nodes = 500
260
- if len(subgraph.nodes()) > max_graph_nodes:
261
- origin_nodes = len(subgraph.nodes())
262
- node_degrees = dict(subgraph.degree())
263
- top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
264
- :max_graph_nodes
265
- ]
266
- top_node_ids = [node[0] for node in top_nodes]
267
- # Create new subgraph with only top nodes
268
- subgraph = subgraph.subgraph(top_node_ids)
269
- logger.info(
270
- f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
271
- )
 
272
 
273
- # Add nodes to result
274
- for node in subgraph.nodes():
275
- if str(node) in seen_nodes:
276
- continue
277
-
278
- node_data = dict(subgraph.nodes[node])
279
- # Get entity_type as labels
280
- labels = []
281
- if "entity_type" in node_data:
282
- if isinstance(node_data["entity_type"], list):
283
- labels.extend(node_data["entity_type"])
284
- else:
285
- labels.append(node_data["entity_type"])
286
-
287
- # Create node with properties
288
- node_properties = {k: v for k, v in node_data.items()}
289
-
290
- result.nodes.append(
291
- KnowledgeGraphNode(
292
- id=str(node), labels=[str(node)], properties=node_properties
 
293
  )
294
- )
295
- seen_nodes.add(str(node))
296
-
297
- # Add edges to result
298
- for edge in subgraph.edges():
299
- source, target = edge
300
- edge_id = f"{source}-{target}"
301
- if edge_id in seen_edges:
302
- continue
303
-
304
- edge_data = dict(subgraph.edges[edge])
305
-
306
- # Create edge with complete information
307
- result.edges.append(
308
- KnowledgeGraphEdge(
309
- id=edge_id,
310
- type="DIRECTED",
311
- source=str(source),
312
- target=str(target),
313
- properties=edge_data,
314
  )
315
- )
316
- seen_edges.add(edge_id)
317
-
318
- # logger.info(result.edges)
319
 
320
  logger.info(
321
  f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
 
1
  import os
2
  from dataclasses import dataclass
3
  from typing import Any, final
 
4
  import numpy as np
5
 
6
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
7
+ from lightrag.utils import logger
8
+ from lightrag.base import BaseGraphStorage
9
+ from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess
10
+
 
 
 
 
11
  import pipmaster as pm
12
 
13
  if not pm.is_installed("networkx"):
 
19
  import networkx as nx
20
  from graspologic import embed
21

22
 
23
  @final
24
  @dataclass
 
73
  self._graphml_xml_file = os.path.join(
74
  self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
75
  )
76
+ self._storage_lock = get_storage_lock()
77
+ self._graph = get_namespace_object(self.namespace)
78
+ with self._storage_lock:
79
+ if is_multiprocess:
80
+ if self._graph.value is None:
81
+ preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
82
+ self._graph.value = preloaded_graph or nx.Graph()
83
+ logger.info(
84
+ f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
85
+ )
86
+ else:
87
+ if self._graph is None:
88
+ preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
89
+ self._graph = preloaded_graph or nx.Graph()
90
+ logger.info(
91
+ f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
92
+ )
93
+
94
+ self._node_embed_algorithms = {
 
 
 
95
  "node2vec": self._node2vec_embed,
96
+ }
97
+
98
+ def _get_graph(self):
99
+ """Get the appropriate graph instance based on multiprocess mode"""
100
+ if is_multiprocess:
101
+ return self._graph.value
102
+ return self._graph
103
 
104
  async def index_done_callback(self) -> None:
105
+ with self._storage_lock:
106
+ graph = self._get_graph()
107
+ NetworkXStorage.write_nx_graph(graph, self._graphml_xml_file)
108
 
109
  async def has_node(self, node_id: str) -> bool:
110
+ with self._storage_lock:
111
+ graph = self._get_graph()
112
+ return graph.has_node(node_id)
113
 
114
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
115
+ with self._storage_lock:
116
+ graph = self._get_graph()
117
+ return graph.has_edge(source_node_id, target_node_id)
118
 
119
  async def get_node(self, node_id: str) -> dict[str, str] | None:
120
+ with self._storage_lock:
121
+ graph = self._get_graph()
122
+ return graph.nodes.get(node_id)
123
 
124
  async def node_degree(self, node_id: str) -> int:
125
+ with self._storage_lock:
126
+ graph = self._get_graph()
127
+ return graph.degree(node_id)
128
 
129
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
130
+ with self._storage_lock:
131
+ graph = self._get_graph()
132
+ return graph.degree(src_id) + graph.degree(tgt_id)
133
 
134
  async def get_edge(
135
  self, source_node_id: str, target_node_id: str
136
  ) -> dict[str, str] | None:
137
+ with self._storage_lock:
138
+ graph = self._get_graph()
139
+ return graph.edges.get((source_node_id, target_node_id))
140
 
141
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
142
+ with self._storage_lock:
143
+ graph = self._get_graph()
144
+ if graph.has_node(source_node_id):
145
+ return list(graph.edges(source_node_id))
146
+ return None
147
 
148
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
149
+ with self._storage_lock:
150
+ graph = self._get_graph()
151
+ graph.add_node(node_id, **node_data)
152
 
153
  async def upsert_edge(
154
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
155
  ) -> None:
156
+ with self._storage_lock:
157
+ graph = self._get_graph()
158
+ graph.add_edge(source_node_id, target_node_id, **edge_data)
159
 
160
  async def delete_node(self, node_id: str) -> None:
161
+ with self._storage_lock:
162
+ graph = self._get_graph()
163
+ if graph.has_node(node_id):
164
+ graph.remove_node(node_id)
165
+ logger.debug(f"Node {node_id} deleted from the graph.")
166
+ else:
167
+ logger.warning(f"Node {node_id} not found in the graph for deletion.")
168
 
169
  async def embed_nodes(
170
  self, algorithm: str
 
173
  raise ValueError(f"Node embedding algorithm {algorithm} not supported")
174
  return await self._node_embed_algorithms[algorithm]()
175
 
176
+ # TODO: NOT USED
177
  async def _node2vec_embed(self):
178
+ with self._storage_lock:
179
+ graph = self._get_graph()
180
+ embeddings, nodes = embed.node2vec_embed(
181
+ graph,
182
+ **self.global_config["node2vec_params"],
183
+ )
184
+ nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
185
  return embeddings, nodes_ids
186
 
187
  def remove_nodes(self, nodes: list[str]):
 
190
  Args:
191
  nodes: List of node IDs to be deleted
192
  """
193
+ with self._storage_lock:
194
+ graph = self._get_graph()
195
+ for node in nodes:
196
+ if graph.has_node(node):
197
+ graph.remove_node(node)
198
 
199
  def remove_edges(self, edges: list[tuple[str, str]]):
200
  """Delete multiple edges
 
202
  Args:
203
  edges: List of edges to be deleted, each edge is a (source, target) tuple
204
  """
205
+ with self._storage_lock:
206
+ graph = self._get_graph()
207
+ for source, target in edges:
208
+ if graph.has_edge(source, target):
209
+ graph.remove_edge(source, target)
210
 
211
  async def get_all_labels(self) -> list[str]:
212
  """
 
214
  Returns:
215
  [label1, label2, ...] # Alphabetically sorted label list
216
  """
217
+ with self._storage_lock:
218
+ graph = self._get_graph()
219
+ labels = set()
220
+ for node in graph.nodes():
221
+ labels.add(str(node)) # Add node id as a label
222
 
223
  # Return sorted list
224
  return sorted(list(labels))
 
240
  seen_nodes = set()
241
  seen_edges = set()
242
 
243
+ with self._storage_lock:
244
+ graph = self._get_graph()
245
+
246
+ # Handle special case for "*" label
247
+ if node_label == "*":
248
+ # For "*", return the entire graph including all nodes and edges
249
+ subgraph = graph.copy() # Create a copy to avoid modifying the original graph
250
+ else:
251
+ # Find nodes with matching node id (partial match)
252
+ nodes_to_explore = []
253
+ for n, attr in graph.nodes(data=True):
254
+ if node_label in str(n): # Use partial matching
255
+ nodes_to_explore.append(n)
256
+
257
+ if not nodes_to_explore:
258
+ logger.warning(f"No nodes found with label {node_label}")
259
+ return result
260
+
261
+ # Get subgraph using ego_graph
262
+ subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
263
+
264
+ # Check if number of nodes exceeds max_graph_nodes
265
+ max_graph_nodes = 500
266
+ if len(subgraph.nodes()) > max_graph_nodes:
267
+ origin_nodes = len(subgraph.nodes())
268
+ node_degrees = dict(subgraph.degree())
269
+ top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
270
+ :max_graph_nodes
271
+ ]
272
+ top_node_ids = [node[0] for node in top_nodes]
273
+ # Create new subgraph with only top nodes
274
+ subgraph = subgraph.subgraph(top_node_ids)
275
+ logger.info(
276
+ f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
277
+ )
278
 
279
+ # Add nodes to result
280
+ for node in subgraph.nodes():
281
+ if str(node) in seen_nodes:
282
+ continue
283
+
284
+ node_data = dict(subgraph.nodes[node])
285
+ # Get entity_type as labels
286
+ labels = []
287
+ if "entity_type" in node_data:
288
+ if isinstance(node_data["entity_type"], list):
289
+ labels.extend(node_data["entity_type"])
290
+ else:
291
+ labels.append(node_data["entity_type"])
292
+
293
+ # Create node with properties
294
+ node_properties = {k: v for k, v in node_data.items()}
295
+
296
+ result.nodes.append(
297
+ KnowledgeGraphNode(
298
+ id=str(node), labels=[str(node)], properties=node_properties
299
+ )
300
  )
301
+ seen_nodes.add(str(node))
302
+
303
+ # Add edges to result
304
+ for edge in subgraph.edges():
305
+ source, target = edge
306
+ edge_id = f"{source}-{target}"
307
+ if edge_id in seen_edges:
308
+ continue
309
+
310
+ edge_data = dict(subgraph.edges[edge])
311
+
312
+ # Create edge with complete information
313
+ result.edges.append(
314
+ KnowledgeGraphEdge(
315
+ id=edge_id,
316
+ type="DIRECTED",
317
+ source=str(source),
318
+ target=str(target),
319
+ properties=edge_data,
320
+ )
321
  )
322
+ seen_edges.add(edge_id)
 
 
 
323
 
324
  logger.info(
325
  f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
lightrag/kg/shared_storage.py ADDED
@@ -0,0 +1,94 @@
1
+ from multiprocessing.synchronize import Lock as ProcessLock
2
+ from threading import Lock as ThreadLock
3
+ from multiprocessing import Manager
4
+ from typing import Any, Dict, Optional, Union
5
+
6
+ # Type definitions
7
+ LockType = Union[ProcessLock, ThreadLock]
8
+
9
+ # Module-level globals
10
+ _shared_data: Optional[Dict[str, Any]] = None
11
+ _namespace_objects: Optional[Dict[str, Any]] = None
12
+ _global_lock: Optional[LockType] = None
13
+ is_multiprocess = False
14
+ manager = None
15
+
16
+ def initialize_manager():
17
+ """Initialize manager, only for multiple processes where workers > 1"""
18
+ global manager
19
+ if manager is None:
20
+ manager = Manager()
21
+
22
+ def _get_global_lock() -> LockType:
23
+ global _global_lock, is_multiprocess
24
+
25
+ if _global_lock is None:
26
+ if is_multiprocess:
27
+ _global_lock = manager.Lock()
28
+ else:
29
+ _global_lock = ThreadLock()
30
+
31
+ return _global_lock
32
+
33
+ def get_storage_lock() -> LockType:
34
+ """return storage lock for data consistency"""
35
+ return _get_global_lock()
36
+
37
+ def get_scan_lock() -> LockType:
38
+ """return scan_progress lock for data consistency"""
39
+ return get_storage_lock()
40
+
41
+ def get_shared_data() -> Dict[str, Any]:
42
+ """
43
+ Return the shared data store used by all storage types.
44
+ The multiprocess-safe structure is created lazily, only when needed, for better performance.
45
+ """
46
+ global _shared_data, is_multiprocess
47
+
48
+ if _shared_data is None:
49
+ lock = _get_global_lock()
50
+ with lock:
51
+ if _shared_data is None:
52
+ if is_multiprocess:
53
+ _shared_data = manager.dict()
54
+ else:
55
+ _shared_data = {}
56
+
57
+ return _shared_data
58
+
59
+ def get_namespace_object(namespace: str) -> Any:
60
+ """Get an object for specific namespace"""
61
+ global _namespace_objects, is_multiprocess
62
+
63
+ if _namespace_objects is None:
64
+ lock = _get_global_lock()
65
+ with lock:
66
+ if _namespace_objects is None:
67
+ _namespace_objects = {}
68
+
69
+ if namespace not in _namespace_objects:
70
+ lock = _get_global_lock()
71
+ with lock:
72
+ if namespace not in _namespace_objects:
73
+ if is_multiprocess:
74
+ _namespace_objects[namespace] = manager.Value('O', None)
75
+ else:
76
+ _namespace_objects[namespace] = None
77
+
78
+ return _namespace_objects[namespace]
79
+
80
+ def get_namespace_data(namespace: str) -> Dict[str, Any]:
81
+ """get storage space for specific storage type(namespace)"""
82
+ shared_data = get_shared_data()
83
+ lock = _get_global_lock()
84
+
85
+ if namespace not in shared_data:
86
+ with lock:
87
+ if namespace not in shared_data:
88
+ shared_data[namespace] = {}
89
+
90
+ return shared_data[namespace]
91
+
92
+ def get_scan_progress() -> Dict[str, Any]:
93
+ """get storage space for document scanning progress data"""
94
+ return get_namespace_data('scan_progress')
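
The module above amounts to one lazily created lock plus per-namespace containers. A short usage sketch in the default single-process mode follows (no initialize_manager() call, so the lock is a plain threading.Lock and no manager process is spawned); the namespace string and the progress keys used here are illustrative, not names defined by the module.

# Hedged sketch: consuming shared_storage from a storage backend or route handler.
# Assumes the package layout above, i.e. lightrag/kg/shared_storage.py is importable.
import lightrag.kg.shared_storage as shared_storage

lock = shared_storage.get_storage_lock()         # threading.Lock in single-process mode
progress = shared_storage.get_scan_progress()    # dict for the 'scan_progress' namespace

with lock:
    progress["current_file"] = ""                # illustrative keys only
    progress["total_files"] = 0

graph_slot = shared_storage.get_namespace_object("demo_graph")
print(graph_slot)  # None here; in multi-process mode it would be a manager Value proxy
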
lightrag/lightrag.py CHANGED
@@ -266,13 +266,7 @@ class LightRAG:
266
 
267
  _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
268
 
269
- def __post_init__(self):
270
- # Initialize manager if needed
271
- from lightrag.api.utils_api import manager, initialize_manager
272
- if manager is None:
273
- initialize_manager()
274
- logger.info("Initialized manager for single process mode")
275
-
276
  os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
277
  set_logger(self.log_file_path, self.log_level)
278
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
 
266
 
267
  _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
268
 
269
+ def __post_init__(self):
 
 
 
 
 
 
270
  os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
271
  set_logger(self.log_file_path, self.log_level)
272
  logger.info(f"Logger initialized for working directory: {self.working_dir}")