yangdx committed
Commit · 5d78930
Parent(s): d2b7a97

Refactor storage implementations to support both single and multi-process modes

• Add shared storage management module
• Support process/thread lock based on mode

Files changed:
- lightrag/api/lightrag_server.py +6 -3
- lightrag/api/routers/document_routes.py +49 -17
- lightrag/api/utils_api.py +0 -61
- lightrag/kg/faiss_impl.py +157 -178
- lightrag/kg/json_doc_status_impl.py +41 -67
- lightrag/kg/json_kv_impl.py +31 -60
- lightrag/kg/nano_vector_db_impl.py +81 -86
- lightrag/kg/networkx_impl.py +169 -165
- lightrag/kg/shared_storage.py +94 -0
- lightrag/lightrag.py +1 -7
lightrag/api/lightrag_server.py
CHANGED

@@ -406,9 +406,6 @@ def create_app(args):
 
 def get_application():
     """Factory function for creating the FastAPI application"""
-    from .utils_api import initialize_manager
-    initialize_manager()
-
     # Get args from environment variable
     args_json = os.environ.get('LIGHTRAG_ARGS')
     if not args_json:
@@ -428,6 +425,12 @@ def main():
     # Save args to environment variable for child processes
     os.environ['LIGHTRAG_ARGS'] = json.dumps(vars(args))
 
+    if args.workers > 1:
+        from lightrag.kg.shared_storage import initialize_manager
+        initialize_manager()
+        import lightrag.kg.shared_storage as shared_storage
+        shared_storage.is_multiprocess = True
+
     # Configure uvicorn logging
     logging.config.dictConfig({
         "version": 1,
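The new lightrag/kg/shared_storage.py (94 lines added) is not expanded in this view. The sketch below is only a plausible reconstruction inferred from the helpers the files in this commit import (initialize_manager, is_multiprocess, get_storage_lock, get_scan_lock, get_scan_progress, get_namespace_data, get_namespace_object); the bodies are assumptions, not the committed code.

# Hypothetical sketch of lightrag/kg/shared_storage.py; the real module is not shown in this diff.
import threading
from multiprocessing import Manager

# Set to True by the server when it starts more than one worker
is_multiprocess = False

_manager = None
_storage_lock = threading.Lock()
_scan_lock = threading.Lock()
_scan_progress = {
    "is_scanning": False,
    "current_file": "",
    "indexed_count": 0,
    "total_files": 0,
    "progress": 0,
}
_namespace_data = {}     # namespace -> dict-like payload (KV data, doc status, Faiss metadata)
_namespace_objects = {}  # namespace -> holder for a non-dict object (Faiss index, NetworkX graph)


def initialize_manager():
    """Swap the thread-level primitives for Manager-backed proxies (multi-worker mode)."""
    global _manager, _storage_lock, _scan_lock, _scan_progress, _namespace_data
    if _manager is None:
        _manager = Manager()
        _storage_lock = _manager.Lock()
        _scan_lock = _manager.Lock()
        _scan_progress = _manager.dict(_scan_progress)
        _namespace_data = _manager.dict()


def get_storage_lock():
    # Process lock in multi-process mode, thread lock otherwise
    return _storage_lock


def get_scan_lock():
    return _scan_lock


def get_scan_progress():
    return _scan_progress


def get_namespace_data(namespace: str):
    """Return the shared dict for a namespace, creating it on first use."""
    if namespace not in _namespace_data:
        _namespace_data[namespace] = _manager.dict() if _manager else {}
    return _namespace_data[namespace]


def get_namespace_object(namespace: str):
    """In multi-process mode return a '.value' holder proxy; in single-process mode
    return None so the caller keeps the object on its own instance. Anything stored
    in the holder must be picklable to cross the process boundary."""
    if is_multiprocess:
        if namespace not in _namespace_objects:
            holder = _manager.Namespace()
            holder.value = None
            _namespace_objects[namespace] = holder
        return _namespace_objects[namespace]
    return None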
lightrag/api/routers/document_routes.py
CHANGED

@@ -18,12 +18,10 @@ from pydantic import BaseModel, Field, field_validator
 
 from lightrag import LightRAG
 from lightrag.base import DocProcessingStatus, DocStatus
-from ..utils_api import (
-    get_api_key_dependency,
-    update_scan_progress,
-    reset_scan_progress,
-)
+from ..utils_api import get_api_key_dependency
+from lightrag.kg.shared_storage import (
+    get_scan_progress,
+    get_scan_lock,
+)
 
@@ -378,23 +376,51 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
 
 async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
     """Background task to scan and index documents"""
+    scan_progress = get_scan_progress()
+    scan_lock = get_scan_lock()
+
+    with scan_lock:
+        if scan_progress["is_scanning"]:
+            ASCIIColors.info(
+                "Skip document scanning(another scanning is active)"
+            )
+            return
+        scan_progress.update({
+            "is_scanning": True,
+            "current_file": "",
+            "indexed_count": 0,
+            "total_files": 0,
+            "progress": 0,
+        })
 
     try:
         new_files = doc_manager.scan_directory_for_new_files()
         total_files = len(new_files)
+        scan_progress.update({
+            "current_file": "",
+            "total_files": total_files,
+            "indexed_count": 0,
+            "progress": 0,
+        })
 
         logging.info(f"Found {total_files} new files to index.")
        for idx, file_path in enumerate(new_files):
            try:
+                progress = (idx / total_files * 100) if total_files > 0 else 0
+                scan_progress.update({
+                    "current_file": os.path.basename(file_path),
+                    "indexed_count": idx,
+                    "progress": progress,
+                })
+
                await pipeline_index_file(rag, file_path)
+
+                progress = ((idx + 1) / total_files * 100) if total_files > 0 else 0
+                scan_progress.update({
+                    "current_file": os.path.basename(file_path),
+                    "indexed_count": idx + 1,
+                    "progress": progress,
+                })
 
            except Exception as e:
                logging.error(f"Error indexing file {file_path}: {str(e)}")
@@ -402,7 +428,13 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
    except Exception as e:
        logging.error(f"Error during scanning process: {str(e)}")
    finally:
+        scan_progress.update({
+            "is_scanning": False,
+            "current_file": "",
+            "indexed_count": 0,
+            "total_files": 0,
+            "progress": 0,
+        })
 
 
 def create_document_routes(
@@ -427,7 +459,7 @@ def create_document_routes(
         return {"status": "scanning_started"}
 
     @router.get("/scan-progress")
-    async def get_scan_progress():
+    async def get_scanning_progress():
         """
         Get the current progress of the document scanning process.
 
@@ -439,7 +471,7 @@ def create_document_routes(
         - total_files: Total number of files to process
         - progress: Percentage of completion
         """
-        return dict(scan_progress)
+        return dict(get_scan_progress())
 
     @router.post("/upload", dependencies=[Depends(optional_api_key)])
     async def upload_to_input_dir(
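Clients can follow a background scan by polling the scan-progress route added above. The snippet below is an illustrative client only; the host, port and the /documents mount prefix are assumptions about a typical deployment rather than anything this diff specifies.

import time
import requests

BASE_URL = "http://localhost:9621"  # hypothetical server address; adjust to your deployment

def wait_for_scan(poll_seconds: float = 2.0) -> None:
    """Poll scan progress until the background scanning task reports completion."""
    while True:
        resp = requests.get(f"{BASE_URL}/documents/scan-progress", timeout=10)
        progress = resp.json()
        print(
            f"{progress['indexed_count']}/{progress['total_files']} files "
            f"({progress['progress']:.0f}%) current={progress['current_file'] or '-'}"
        )
        if not progress["is_scanning"]:
            break
        time.sleep(poll_seconds)

if __name__ == "__main__":
    wait_for_scan()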
lightrag/api/utils_api.py
CHANGED

@@ -6,7 +6,6 @@ import os
 import argparse
 from typing import Optional
 import sys
-from multiprocessing import Manager
 from ascii_colors import ASCIIColors
 from lightrag.api import __api_version__
 from fastapi import HTTPException, Security
@@ -17,66 +16,6 @@ from starlette.status import HTTP_403_FORBIDDEN
 # Load environment variables
 load_dotenv(override=True)
 
-# Global variables for manager and shared state
-manager = None
-scan_progress = None
-scan_lock = None
-
-def initialize_manager():
-    """Initialize manager and shared state for cross-process communication"""
-    global manager, scan_progress, scan_lock
-    if manager is None:
-        manager = Manager()
-        scan_progress = manager.dict({
-            "is_scanning": False,
-            "current_file": "",
-            "indexed_count": 0,
-            "total_files": 0,
-            "progress": 0,
-        })
-        scan_lock = manager.Lock()
-
-def update_scan_progress_if_not_scanning():
-    """
-    Atomically check if scanning is not in progress and update scan_progress if it's not.
-    Returns True if the update was successful, False if scanning was already in progress.
-    """
-    with scan_lock:
-        if not scan_progress["is_scanning"]:
-            scan_progress.update({
-                "is_scanning": True,
-                "current_file": "",
-                "indexed_count": 0,
-                "total_files": 0,
-                "progress": 0,
-            })
-            return True
-        return False
-
-def update_scan_progress(current_file: str, total_files: int, indexed_count: int):
-    """
-    Atomically update scan progress information.
-    """
-    progress = (indexed_count / total_files * 100) if total_files > 0 else 0
-    scan_progress.update({
-        "current_file": current_file,
-        "indexed_count": indexed_count,
-        "total_files": total_files,
-        "progress": progress,
-    })
-
-def reset_scan_progress():
-    """
-    Atomically reset scan progress to initial state.
-    """
-    scan_progress.update({
-        "is_scanning": False,
-        "current_file": "",
-        "indexed_count": 0,
-        "total_files": 0,
-        "progress": 0,
-    })
-
 
 class OllamaServerInfos:
     # Constants for emulated Ollama model information
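The removed update_scan_progress_if_not_scanning helper was a lock-guarded test-and-set; the same pattern now lives inline at the top of run_scanning_process (see the routers diff above). A minimal standalone form of the pattern, with hypothetical names:

def try_start_scan(scan_progress, scan_lock) -> bool:
    """Return True if this caller won the right to scan, False if a scan is already running."""
    with scan_lock:
        if scan_progress["is_scanning"]:
            return False
        scan_progress.update({
            "is_scanning": True,
            "current_file": "",
            "indexed_count": 0,
            "total_files": 0,
            "progress": 0,
        })
        return True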
lightrag/kg/faiss_impl.py
CHANGED

@@ -2,48 +2,21 @@ import os
 import time
 import asyncio
 from typing import Any, final
-import threading
 import json
 import numpy as np
 
 from dataclasses import dataclass
 import pipmaster as pm
-from lightrag.api.utils_api import manager as main_process_manager
 
-from lightrag.utils import (
-    logger,
-    compute_mdhash_id,
-)
-from lightrag.base import (
-    BaseVectorStorage,
-)
+from lightrag.utils import logger,compute_mdhash_id
+from lightrag.base import BaseVectorStorage
+from .shared_storage import get_namespace_data, get_storage_lock, get_namespace_object, is_multiprocess
 
 if not pm.is_installed("faiss"):
     pm.install("faiss")
 
 import faiss  # type: ignore
 
-# Global variables for shared memory management
-_init_lock = threading.Lock()
-_manager = None
-_shared_indices = None
-_shared_meta = None
-
-
-def _get_manager():
-    """Get or create the global manager instance"""
-    global _manager, _shared_indices, _shared_meta
-    with _init_lock:
-        if _manager is None:
-            try:
-                _manager = main_process_manager
-                _shared_indices = _manager.dict()
-                _shared_meta = _manager.dict()
-            except Exception as e:
-                logger.error(f"Failed to initialize shared memory manager: {e}")
-                raise RuntimeError(f"Shared memory initialization failed: {e}")
-    return _manager
-
 
 @final
 @dataclass
@@ -72,48 +45,29 @@ class FaissVectorDBStorage(BaseVectorStorage):
         self._max_batch_size = self.global_config["embedding_batch_num"]
         # Embedding dimension (e.g. 768) must match your embedding function
         self._dim = self.embedding_func.embedding_dim
+        self._storage_lock = get_storage_lock()
 
+        self._index = get_namespace_object('faiss_indices')
+        self._id_to_meta = get_namespace_data('faiss_meta')
 
+        with self._storage_lock:
+            if is_multiprocess:
+                if self._index.value is None:
+                    # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
+                    # If you have a large number of vectors, you might want IVF or other indexes.
+                    # For demonstration, we use a simple IndexFlatIP.
+                    self._index.value = faiss.IndexFlatIP(self._dim)
+            else:
+                if self._index is None:
+                    self._index = faiss.IndexFlatIP(self._dim)
+
+            # Keep a local store for metadata, IDs, etc.
+            # Maps <int faiss_id> → metadata (including your original ID).
+            self._id_to_meta.update({})
+
+            # Attempt to load an existing index + metadata from disk
+            self._load_faiss_index()
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """
@@ -168,32 +122,36 @@ class FaissVectorDBStorage(BaseVectorStorage):
         # Normalize embeddings for cosine similarity (in-place)
         faiss.normalize_L2(embeddings)
 
+        with self._storage_lock:
+            # Upsert logic:
+            # 1. Identify which vectors to remove if they exist
+            # 2. Remove them
+            # 3. Add the new vectors
+            existing_ids_to_remove = []
+            for meta, emb in zip(list_data, embeddings):
+                faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
+                if faiss_internal_id is not None:
+                    existing_ids_to_remove.append(faiss_internal_id)
+
+            if existing_ids_to_remove:
+                self._remove_faiss_ids(existing_ids_to_remove)
+
+            # Step 2: Add new vectors
+            start_idx = (self._index.value if is_multiprocess else self._index).ntotal
+            if is_multiprocess:
+                self._index.value.add(embeddings)
+            else:
+                self._index.add(embeddings)
+
+            # Step 3: Store metadata + vector for each new ID
+            for i, meta in enumerate(list_data):
+                fid = start_idx + i
+                # Store the raw vector so we can rebuild if something is removed
+                meta["__vector__"] = embeddings[i].tolist()
+                self._id_to_meta.update({fid: meta})
+
+            logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
+            return [m["__id__"] for m in list_data]
 
     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         """
@@ -209,54 +167,57 @@ class FaissVectorDBStorage(BaseVectorStorage):
         )
 
         # Perform the similarity search
+        with self._storage_lock:
+            distances, indices = (self._index.value if is_multiprocess else self._index).search(embedding, top_k)
+
+            distances = distances[0]
+            indices = indices[0]
+
+            results = []
+            for dist, idx in zip(distances, indices):
+                if idx == -1:
+                    # Faiss returns -1 if no neighbor
+                    continue
+
+                # Cosine similarity threshold
+                if dist < self.cosine_better_than_threshold:
+                    continue
+
+                meta = self._id_to_meta.get(idx, {})
+                results.append(
+                    {
+                        **meta,
+                        "id": meta.get("__id__"),
+                        "distance": float(dist),
+                        "created_at": meta.get("__created_at__"),
+                    }
+                )
+
+            return results
 
     @property
     def client_storage(self):
         # Return whatever structure LightRAG might need for debugging
+        with self._storage_lock:
+            return {"data": list(self._id_to_meta.values())}
 
     async def delete(self, ids: list[str]):
         """
         Delete vectors for the provided custom IDs.
         """
         logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
+        with self._storage_lock:
+            to_remove = []
+            for cid in ids:
+                fid = self._find_faiss_id_by_custom_id(cid)
+                if fid is not None:
+                    to_remove.append(fid)
+
+            if to_remove:
+                self._remove_faiss_ids(to_remove)
+            logger.debug(
+                f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
+            )
 
     async def delete_entity(self, entity_name: str) -> None:
         entity_id = compute_mdhash_id(entity_name, prefix="ent-")
@@ -268,18 +229,20 @@ class FaissVectorDBStorage(BaseVectorStorage):
         Delete relations for a given entity by scanning metadata.
         """
         logger.debug(f"Searching relations for entity {entity_name}")
+        with self._storage_lock:
+            relations = []
+            for fid, meta in self._id_to_meta.items():
+                if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
+                    relations.append(fid)
 
+            logger.debug(f"Found {len(relations)} relations for {entity_name}")
+            if relations:
+                self._remove_faiss_ids(relations)
+                logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
 
     async def index_done_callback(self) -> None:
-        self._save_faiss_index()
+        with self._storage_lock:
+            self._save_faiss_index()
 
     # --------------------------------------------------------------------------------
     # Internal helper methods
@@ -289,10 +252,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
         """
         Return the Faiss internal ID for a given custom ID, or None if not found.
         """
+        with self._storage_lock:
+            for fid, meta in self._id_to_meta.items():
+                if meta.get("__id__") == custom_id:
+                    return fid
+            return None
 
     def _remove_faiss_ids(self, fid_list):
         """
@@ -300,39 +264,45 @@ class FaissVectorDBStorage(BaseVectorStorage):
         Because IndexFlatIP doesn't support 'removals',
         we rebuild the index excluding those vectors.
         """
+        with self._storage_lock:
+            keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
+
+            # Rebuild the index
+            vectors_to_keep = []
+            new_id_to_meta = {}
+            for new_fid, old_fid in enumerate(keep_fids):
+                vec_meta = self._id_to_meta[old_fid]
+                vectors_to_keep.append(vec_meta["__vector__"])  # stored as list
+                new_id_to_meta[new_fid] = vec_meta
+
+            # Re-init index
+            new_index = faiss.IndexFlatIP(self._dim)
+            if vectors_to_keep:
+                arr = np.array(vectors_to_keep, dtype=np.float32)
+                new_index.add(arr)
+            if is_multiprocess:
+                self._index.value = new_index
+            else:
+                self._index = new_index
+
+            self._id_to_meta.update(new_id_to_meta)
 
     def _save_faiss_index(self):
         """
         Save the current Faiss index + metadata to disk so it can persist across runs.
         """
+        with self._storage_lock:
+            faiss.write_index(self._index.value if is_multiprocess else self._index, self._faiss_index_file)
 
+            # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
+            # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
+            # We'll keep the int -> dict, but JSON requires string keys.
+            serializable_dict = {}
+            for fid, meta in self._id_to_meta.items():
+                serializable_dict[str(fid)] = meta
 
+            with open(self._meta_file, "w", encoding="utf-8") as f:
+                json.dump(serializable_dict, f)
 
     def _load_faiss_index(self):
         """
@@ -345,22 +315,31 @@ class FaissVectorDBStorage(BaseVectorStorage):
 
         try:
             # Load the Faiss index
+            loaded_index = faiss.read_index(self._faiss_index_file)
+            if is_multiprocess:
+                self._index.value = loaded_index
+            else:
+                self._index = loaded_index
+
             # Load metadata
             with open(self._meta_file, "r", encoding="utf-8") as f:
                 stored_dict = json.load(f)
 
             # Convert string keys back to int
+            self._id_to_meta.update({})
             for fid_str, meta in stored_dict.items():
                 fid = int(fid_str)
                 self._id_to_meta[fid] = meta
 
             logger.info(
+                f"Faiss index loaded with {loaded_index.ntotal} vectors from {self._faiss_index_file}"
             )
         except Exception as e:
             logger.error(f"Failed to load Faiss index or metadata: {e}")
             logger.warning("Starting with an empty Faiss index.")
+            new_index = faiss.IndexFlatIP(self._dim)
+            if is_multiprocess:
+                self._index.value = new_index
+            else:
+                self._index = new_index
+            self._id_to_meta.update({})
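Unlike the nano-vectordb and networkx backends below, which factor the mode check into _get_client()/_get_graph() helpers, the Faiss implementation repeats the `self._index.value if is_multiprocess else self._index` ternary at each call site. A hypothetical helper in the same style (not part of this commit) would read:

    # Hypothetical refactor for FaissVectorDBStorage, mirroring _get_client()/_get_graph()
    def _get_index(self):
        """Return the live Faiss index regardless of single/multi-process mode."""
        if is_multiprocess:
            return self._index.value
        return self._index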
lightrag/kg/json_doc_status_impl.py
CHANGED

@@ -1,7 +1,6 @@
 from dataclasses import dataclass
 import os
 from typing import Any, Union, final
-import threading
 
 from lightrag.base import (
     DocProcessingStatus,
@@ -13,26 +12,7 @@ from lightrag.utils import (
     logger,
     write_json,
 )
-from lightrag.api.utils_api import manager as main_process_manager
-
-# Global variables for shared memory management
-_init_lock = threading.Lock()
-_manager = None
-_shared_doc_status_data = None
-
-
-def _get_manager():
-    """Get or create the global manager instance"""
-    global _manager, _shared_doc_status_data
-    with _init_lock:
-        if _manager is None:
-            try:
-                _manager = main_process_manager
-                _shared_doc_status_data = _manager.dict()
-            except Exception as e:
-                logger.error(f"Failed to initialize shared memory manager: {e}")
-                raise RuntimeError(f"Shared memory initialization failed: {e}")
-    return _manager
+from .shared_storage import get_namespace_data, get_storage_lock
 
 
 @final
@@ -43,45 +23,32 @@ class JsonDocStatusStorage(DocStatusStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
+        self._storage_lock = get_storage_lock()
+        self._data = get_namespace_data(self.namespace)
+        with self._storage_lock:
+            self._data.update(load_json(self._file_name) or {})
+            logger.info(f"Loaded document status storage with {len(self._data)} records")
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return keys that should be processed (not in storage or not successfully processed)"""
+        with self._storage_lock:
+            return set(keys) - set(self._data.keys())
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         result: list[dict[str, Any]] = []
+        with self._storage_lock:
+            for id in ids:
+                data = self._data.get(id, None)
+                if data:
+                    result.append(data)
         return result
 
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         counts = {status.value: 0 for status in DocStatus}
+        with self._storage_lock:
+            for doc in self._data.values():
+                counts[doc["status"]] += 1
         return counts
 
     async def get_docs_by_status(
@@ -89,39 +56,46 @@ class JsonDocStatusStorage(DocStatusStorage):
     ) -> dict[str, DocProcessingStatus]:
         """Get all documents with a specific status"""
         result = {}
+        with self._storage_lock:
+            for k, v in self._data.items():
+                if v["status"] == status.value:
+                    try:
+                        # Make a copy of the data to avoid modifying the original
+                        data = v.copy()
+                        # If content is missing, use content_summary as content
+                        if "content" not in data and "content_summary" in data:
+                            data["content"] = data["content_summary"]
+                        result[k] = DocProcessingStatus(**data)
+                    except KeyError as e:
+                        logger.error(f"Missing required field for document {k}: {e}")
+                        continue
         return result
 
     async def index_done_callback(self) -> None:
+        # File writes must be serialized under the lock so concurrent processes cannot corrupt the file
+        with self._storage_lock:
+            write_json(self._data, self._file_name)
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
 
+        with self._storage_lock:
+            self._data.update(data)
         await self.index_done_callback()
 
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+        with self._storage_lock:
+            return self._data.get(id)
 
     async def delete(self, doc_ids: list[str]):
+        with self._storage_lock:
+            for doc_id in doc_ids:
+                self._data.pop(doc_id, None)
         await self.index_done_callback()
 
     async def drop(self) -> None:
         """Drop the storage"""
+        with self._storage_lock:
+            self._data.clear()
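One behavior these storage classes rely on when the namespace dicts are Manager proxies (multi-worker mode): a proxy dict only notices top-level assignments, which is why updates go through .update()/.pop() on the proxy itself instead of mutating nested values in place. A standalone illustration, not LightRAG code:

from multiprocessing import Manager

def main() -> None:
    manager = Manager()
    shared = manager.dict()

    shared["doc-1"] = {"status": "pending"}   # top-level assignment goes through the proxy
    shared["doc-1"]["status"] = "processed"   # mutates a local copy; the proxy never sees it
    print(shared["doc-1"]["status"])          # -> "pending"

    entry = shared["doc-1"]
    entry["status"] = "processed"
    shared["doc-1"] = entry                   # publish by reassigning the whole value
    print(shared["doc-1"]["status"])          # -> "processed"

if __name__ == "__main__":
    main()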
lightrag/kg/json_kv_impl.py
CHANGED

@@ -1,8 +1,6 @@
-import asyncio
 import os
 from dataclasses import dataclass
 from typing import Any, final
-import threading
 
 from lightrag.base import (
     BaseKVStorage,
@@ -12,26 +10,7 @@ from lightrag.utils import (
     logger,
     write_json,
 )
-from lightrag.api.utils_api import manager as main_process_manager
-
-# Global variables for shared memory management
-_init_lock = threading.Lock()
-_manager = None
-_shared_kv_data = None
-
-
-def _get_manager():
-    """Get or create the global manager instance"""
-    global _manager, _shared_kv_data
-    with _init_lock:
-        if _manager is None:
-            try:
-                _manager = main_process_manager
-                _shared_kv_data = _manager.dict()
-            except Exception as e:
-                logger.error(f"Failed to initialize shared memory manager: {e}")
-                raise RuntimeError(f"Shared memory initialization failed: {e}")
-    return _manager
+from .shared_storage import get_namespace_data, get_storage_lock
 
 
 @final
@@ -39,57 +18,49 @@ def _get_manager():
 class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
         working_dir = self.global_config["working_dir"]
+        self._storage_lock = get_storage_lock()
+        self._data = get_namespace_data(self.namespace)
+        with self._storage_lock:
+            if not self._data:
+                self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
+                self._data: dict[str, Any] = load_json(self._file_name) or {}
+                logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
 
     async def index_done_callback(self) -> None:
+        # File writes must be serialized under the lock so concurrent processes cannot corrupt the file
+        with self._storage_lock:
+            write_json(self._data, self._file_name)
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        with self._storage_lock:
+            return self._data.get(id)
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        with self._storage_lock:
+            return [
+                (
+                    {k: v for k, v in self._data[id].items()}
+                    if self._data.get(id, None)
+                    else None
+                )
+                for id in ids
+            ]
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
+        with self._storage_lock:
+            return set(keys) - set(self._data.keys())
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
+        with self._storage_lock:
+            left_data = {k: v for k, v in data.items() if k not in self._data}
+            self._data.update(left_data)
 
     async def delete(self, ids: list[str]) -> None:
+        with self._storage_lock:
+            for doc_id in ids:
+                self._data.pop(doc_id, None)
         await self.index_done_callback()
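The commit message's "process/thread lock based on mode" works because both lock types expose the same context-manager interface, so every storage class can write `with self._storage_lock:` without knowing which mode it runs in. A standalone sketch of the idea (names are illustrative, not the shared_storage API):

import threading
from multiprocessing import Manager

def make_storage_lock(manager=None):
    """Return a process-level lock proxy when a Manager is supplied, a thread lock otherwise."""
    if manager is not None:
        return manager.Lock()  # proxy to a lock living in the manager process, shared by workers
    return threading.Lock()

if __name__ == "__main__":
    mgr = Manager()                        # multi-worker mode creates this once, up front
    process_lock = make_storage_lock(mgr)
    thread_lock = make_storage_lock()      # single-process mode
    for lock in (process_lock, thread_lock):
        with lock:                         # identical usage either way
            pass                           # critical section: read-modify-write of shared state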
lightrag/kg/nano_vector_db_impl.py
CHANGED

@@ -3,50 +3,29 @@ import os
 from typing import Any, final
 from dataclasses import dataclass
 import numpy as np
-import threading
 import time
 
 from lightrag.utils import (
     logger,
     compute_mdhash_id,
 )
-from lightrag.api.utils_api import manager as main_process_manager
 import pipmaster as pm
-from lightrag.base import (
-    BaseVectorStorage,
-)
+from lightrag.base import BaseVectorStorage
+from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess
 
 if not pm.is_installed("nano-vectordb"):
     pm.install("nano-vectordb")
 
 from nano_vectordb import NanoVectorDB
 
-# Global variables for shared memory management
-_init_lock = threading.Lock()
-_manager = None
-_shared_vector_clients = None
-
-
-def _get_manager():
-    """Get or create the global manager instance"""
-    global _manager, _shared_vector_clients
-    with _init_lock:
-        if _manager is None:
-            try:
-                _manager = main_process_manager
-                _shared_vector_clients = _manager.dict()
-            except Exception as e:
-                logger.error(f"Failed to initialize shared memory manager: {e}")
-                raise RuntimeError(f"Shared memory initialization failed: {e}")
-    return _manager
-
 
 @final
 @dataclass
 class NanoVectorDBStorage(BaseVectorStorage):
     def __post_init__(self):
         # Initialize lock only for file operations
+        self._storage_lock = get_storage_lock()
+
         # Use global config value if specified, otherwise use default
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -61,28 +40,27 @@ class NanoVectorDBStorage(BaseVectorStorage):
         )
         self._max_batch_size = self.global_config["embedding_batch_num"]
 
-        _get_manager()
+        self._client = get_namespace_object(self.namespace)
 
+        with self._storage_lock:
+            if is_multiprocess:
+                if self._client.value is None:
+                    self._client.value = NanoVectorDB(
+                        self.embedding_func.embedding_dim, storage_file=self._client_file_name
+                    )
+            else:
+                if self._client is None:
+                    self._client = NanoVectorDB(
+                        self.embedding_func.embedding_dim, storage_file=self._client_file_name
+                    )
 
+        logger.info(f"Initialized vector DB client for namespace {self.namespace}")
+
+    def _get_client(self):
+        """Get the appropriate client instance based on multiprocess mode"""
+        if is_multiprocess:
+            return self._client.value
+        return self._client
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} to {self.namespace}")
@@ -104,6 +82,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
             for i in range(0, len(contents), self._max_batch_size)
         ]
 
+        # Execute embedding outside of lock to avoid long lock times
         embedding_tasks = [self.embedding_func(batch) for batch in batches]
         embeddings_list = await asyncio.gather(*embedding_tasks)
 
@@ -111,7 +90,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
         if len(embeddings) == len(list_data):
             for i, d in enumerate(list_data):
                 d["__vector__"] = embeddings[i]
+            with self._storage_lock:
+                client = self._get_client()
+                results = client.upsert(datas=list_data)
             return results
         else:
             # sometimes the embedding is not returned correctly. just log it.
@@ -120,27 +101,32 @@ class NanoVectorDBStorage(BaseVectorStorage):
         )
 
     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
+        # Execute embedding outside of lock to avoid long lock times
         embedding = await self.embedding_func([query])
         embedding = embedding[0]
+
+        with self._storage_lock:
+            client = self._get_client()
+            results = client.query(
+                query=embedding,
+                top_k=top_k,
+                better_than_threshold=self.cosine_better_than_threshold,
+            )
+            results = [
+                {
+                    **dp,
+                    "id": dp["__id__"],
+                    "distance": dp["__metrics__"],
+                    "created_at": dp.get("__created_at__"),
+                }
+                for dp in results
+            ]
         return results
 
     @property
     def client_storage(self):
+        client = self._get_client()
+        return getattr(client, "_NanoVectorDB__storage")
 
     async def delete(self, ids: list[str]):
        """Delete vectors with specified IDs
@@ -149,8 +135,10 @@ class NanoVectorDBStorage(BaseVectorStorage):
            ids: List of vector IDs to be deleted
        """
        try:
+            with self._storage_lock:
+                client = self._get_client()
+                client.delete(ids)
+            logger.debug(
                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
            )
        except Exception as e:
@@ -162,35 +150,42 @@ class NanoVectorDBStorage(BaseVectorStorage):
            logger.debug(
                f"Attempting to delete entity {entity_name} with ID {entity_id}"
            )
+
+            with self._storage_lock:
+                client = self._get_client()
+                # Check if the entity exists
+                if client.get([entity_id]):
+                    client.delete([entity_id])
+                    logger.debug(f"Successfully deleted entity {entity_name}")
+                else:
+                    logger.debug(f"Entity {entity_name} not found in storage")
        except Exception as e:
            logger.error(f"Error deleting entity {entity_name}: {e}")
 
     async def delete_entity_relation(self, entity_name: str) -> None:
        try:
+            with self._storage_lock:
+                client = self._get_client()
+                storage = getattr(client, "_NanoVectorDB__storage")
+                relations = [
+                    dp
+                    for dp in storage["data"]
+                    if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
+                ]
+                logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
+                ids_to_delete = [relation["__id__"] for relation in relations]
+
+                if ids_to_delete:
+                    client.delete(ids_to_delete)
+                    logger.debug(
+                        f"Deleted {len(ids_to_delete)} relations for {entity_name}"
+                    )
+                else:
+                    logger.debug(f"No relations found for entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting relations for {entity_name}: {e}")
 
     async def index_done_callback(self) -> None:
+        with self._storage_lock:
+            client = self._get_client()
+            client.save()
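client_storage and delete_entity_relation above reach NanoVectorDB's private __storage attribute through its name-mangled form. A standalone illustration of Python name mangling (not LightRAG code):

class Demo:
    def __init__(self):
        self.__storage = {"data": []}   # stored on the instance as _Demo__storage

d = Demo()
print(getattr(d, "_Demo__storage"))     # {'data': []} - same pattern as _NanoVectorDB__storage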
lightrag/kg/networkx_impl.py
CHANGED
@@ -1,18 +1,13 @@
|
|
1 |
import os
|
2 |
from dataclasses import dataclass
|
3 |
from typing import Any, final
|
4 |
-
import threading
|
5 |
import numpy as np
|
6 |
|
7 |
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
8 |
-
from lightrag.utils import
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
from lightrag.base import (
|
14 |
-
BaseGraphStorage,
|
15 |
-
)
|
16 |
import pipmaster as pm
|
17 |
|
18 |
if not pm.is_installed("networkx"):
|
@@ -24,25 +19,6 @@ if not pm.is_installed("graspologic"):
|
|
24 |
import networkx as nx
|
25 |
from graspologic import embed
|
26 |
|
27 |
-
# Global variables for shared memory management
|
28 |
-
_init_lock = threading.Lock()
|
29 |
-
_manager = None
|
30 |
-
_shared_graphs = None
|
31 |
-
|
32 |
-
|
33 |
-
def _get_manager():
|
34 |
-
"""Get or create the global manager instance"""
|
35 |
-
global _manager, _shared_graphs
|
36 |
-
with _init_lock:
|
37 |
-
if _manager is None:
|
38 |
-
try:
|
39 |
-
_manager = main_process_manager
|
40 |
-
_shared_graphs = _manager.dict()
|
41 |
-
except Exception as e:
|
42 |
-
logger.error(f"Failed to initialize shared memory manager: {e}")
|
43 |
-
raise RuntimeError(f"Shared memory initialization failed: {e}")
|
44 |
-
return _manager
|
45 |
-
|
46 |
|
47 |
@final
|
48 |
@dataclass
|
@@ -97,76 +73,98 @@ class NetworkXStorage(BaseGraphStorage):
|
|
97 |
self._graphml_xml_file = os.path.join(
|
98 |
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
|
99 |
)
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
try:
|
120 |
-
self._graph = _shared_graphs[self.namespace]
|
121 |
-
self._node_embed_algorithms = {
|
122 |
"node2vec": self._node2vec_embed,
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
127 |
|
128 |
async def index_done_callback(self) -> None:
|
129 |
-
|
|
|
|
|
130 |
|
131 |
async def has_node(self, node_id: str) -> bool:
|
132 |
-
|
|
|
|
|
133 |
|
134 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
135 |
-
|
|
|
|
|
136 |
|
137 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
138 |
-
|
|
|
|
|
139 |
|
140 |
async def node_degree(self, node_id: str) -> int:
|
141 |
-
|
|
|
|
|
142 |
|
143 |
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
144 |
-
|
|
|
|
|
145 |
|
146 |
async def get_edge(
|
147 |
self, source_node_id: str, target_node_id: str
|
148 |
) -> dict[str, str] | None:
|
149 |
-
|
|
|
|
|
150 |
|
151 |
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
152 |
-
|
153 |
-
|
154 |
-
|
|
|
|
|
155 |
|
156 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
157 |
-
self.
|
|
|
|
|
158 |
|
159 |
async def upsert_edge(
|
160 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
161 |
) -> None:
|
162 |
-
self.
|
|
|
|
|
163 |
|
164 |
async def delete_node(self, node_id: str) -> None:
|
165 |
-
|
166 |
-
self.
|
167 |
-
|
168 |
-
|
169 |
-
|
|
|
|
|
170 |
|
171 |
async def embed_nodes(
|
172 |
self, algorithm: str
|
@@ -175,14 +173,15 @@ class NetworkXStorage(BaseGraphStorage):
|
|
175 |
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
176 |
return await self._node_embed_algorithms[algorithm]()
|
177 |
|
178 |
-
#
|
179 |
async def _node2vec_embed(self):
|
180 |
-
|
181 |
-
self.
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
|
|
186 |
return embeddings, nodes_ids
|
187 |
|
188 |
def remove_nodes(self, nodes: list[str]):
|
@@ -191,9 +190,11 @@ class NetworkXStorage(BaseGraphStorage):
|
|
191 |
Args:
|
192 |
nodes: List of node IDs to be deleted
|
193 |
"""
|
194 |
-
|
195 |
-
|
196 |
-
|
|
|
|
|
197 |
|
198 |
def remove_edges(self, edges: list[tuple[str, str]]):
|
199 |
"""Delete multiple edges
|
@@ -201,9 +202,11 @@ class NetworkXStorage(BaseGraphStorage):
|
|
201 |
Args:
|
202 |
edges: List of edges to be deleted, each edge is a (source, target) tuple
|
203 |
"""
|
204 |
-
|
205 |
-
|
206 |
-
|
|
|
|
|
207 |
|
208 |
async def get_all_labels(self) -> list[str]:
|
209 |
"""
|
@@ -211,9 +214,11 @@ class NetworkXStorage(BaseGraphStorage):
|
|
211 |
Returns:
|
212 |
[label1, label2, ...] # Alphabetically sorted label list
|
213 |
"""
|
214 |
-
|
215 |
-
|
216 |
-
labels
|
|
|
|
|
217 |
|
218 |
# Return sorted list
|
219 |
return sorted(list(labels))
|
@@ -235,87 +240,86 @@ class NetworkXStorage(BaseGraphStorage):
|
|
235 |
seen_nodes = set()
|
236 |
seen_edges = set()
|
237 |
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
:
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
|
|
272 |
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
|
|
293 |
)
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
)
|
315 |
-
|
316 |
-
seen_edges.add(edge_id)
|
317 |
-
|
318 |
-
# logger.info(result.edges)
|
319 |
|
320 |
logger.info(
|
321 |
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
|
|
(new-version column, lightrag/kg/networkx_impl.py)

  1 |   import os
  2 |   from dataclasses import dataclass
  3 |   from typing import Any, final
  4 |   import numpy as np
  5 |
  6 |   from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
  7 | + from lightrag.utils import logger
  8 | + from lightrag.base import BaseGraphStorage
  9 | + from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess
 10 | +
 11 |   import pipmaster as pm
 12 |
 13 |   if not pm.is_installed("networkx"):

 19 |   import networkx as nx
 20 |   from graspologic import embed
 21 |
 22 |
 23 |   @final
 24 |   @dataclass

 73 |           self._graphml_xml_file = os.path.join(
 74 |               self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
 75 |           )
 76 | +         self._storage_lock = get_storage_lock()
 77 | +         self._graph = get_namespace_object(self.namespace)
 78 | +         with self._storage_lock:
 79 | +             if is_multiprocess:
 80 | +                 if self._graph.value is None:
 81 | +                     preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
 82 | +                     self._graph.value = preloaded_graph or nx.Graph()
 83 | +                     logger.info(
 84 | +                         f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
 85 | +                     )
 86 | +             else:
 87 | +                 if self._graph is None:
 88 | +                     preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
 89 | +                     self._graph = preloaded_graph or nx.Graph()
 90 | +                     logger.info(
 91 | +                         f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
 92 | +                     )
 93 | +
 94 | +         self._node_embed_algorithms = {
 95 |               "node2vec": self._node2vec_embed,
 96 | +         }
 97 | +
 98 | +     def _get_graph(self):
 99 | +         """Get the appropriate graph instance based on multiprocess mode"""
100 | +         if is_multiprocess:
101 | +             return self._graph.value
102 | +         return self._graph
103 |
104 |       async def index_done_callback(self) -> None:
105 | +         with self._storage_lock:
106 | +             graph = self._get_graph()
107 | +             NetworkXStorage.write_nx_graph(graph, self._graphml_xml_file)
108 |
109 |       async def has_node(self, node_id: str) -> bool:
110 | +         with self._storage_lock:
111 | +             graph = self._get_graph()
112 | +             return graph.has_node(node_id)
113 |
114 |       async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
115 | +         with self._storage_lock:
116 | +             graph = self._get_graph()
117 | +             return graph.has_edge(source_node_id, target_node_id)
118 |
119 |       async def get_node(self, node_id: str) -> dict[str, str] | None:
120 | +         with self._storage_lock:
121 | +             graph = self._get_graph()
122 | +             return graph.nodes.get(node_id)
123 |
124 |       async def node_degree(self, node_id: str) -> int:
125 | +         with self._storage_lock:
126 | +             graph = self._get_graph()
127 | +             return graph.degree(node_id)
128 |
129 |       async def edge_degree(self, src_id: str, tgt_id: str) -> int:
130 | +         with self._storage_lock:
131 | +             graph = self._get_graph()
132 | +             return graph.degree(src_id) + graph.degree(tgt_id)
133 |
134 |       async def get_edge(
135 |           self, source_node_id: str, target_node_id: str
136 |       ) -> dict[str, str] | None:
137 | +         with self._storage_lock:
138 | +             graph = self._get_graph()
139 | +             return graph.edges.get((source_node_id, target_node_id))
140 |
141 |       async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
142 | +         with self._storage_lock:
143 | +             graph = self._get_graph()
144 | +             if graph.has_node(source_node_id):
145 | +                 return list(graph.edges(source_node_id))
146 | +             return None
147 |
148 |       async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
149 | +         with self._storage_lock:
150 | +             graph = self._get_graph()
151 | +             graph.add_node(node_id, **node_data)
152 |
153 |       async def upsert_edge(
154 |           self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
155 |       ) -> None:
156 | +         with self._storage_lock:
157 | +             graph = self._get_graph()
158 | +             graph.add_edge(source_node_id, target_node_id, **edge_data)
159 |
160 |       async def delete_node(self, node_id: str) -> None:
161 | +         with self._storage_lock:
162 | +             graph = self._get_graph()
163 | +             if graph.has_node(node_id):
164 | +                 graph.remove_node(node_id)
165 | +                 logger.debug(f"Node {node_id} deleted from the graph.")
166 | +             else:
167 | +                 logger.warning(f"Node {node_id} not found in the graph for deletion.")
168 |
169 |       async def embed_nodes(
170 |           self, algorithm: str

173 |               raise ValueError(f"Node embedding algorithm {algorithm} not supported")
174 |           return await self._node_embed_algorithms[algorithm]()
175 |
176 | +     # TODO: NOT USED
177 |       async def _node2vec_embed(self):
178 | +         with self._storage_lock:
179 | +             graph = self._get_graph()
180 | +             embeddings, nodes = embed.node2vec_embed(
181 | +                 graph,
182 | +                 **self.global_config["node2vec_params"],
183 | +             )
184 | +             nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
185 |           return embeddings, nodes_ids
186 |
187 |       def remove_nodes(self, nodes: list[str]):

190 |           Args:
191 |               nodes: List of node IDs to be deleted
192 |           """
193 | +         with self._storage_lock:
194 | +             graph = self._get_graph()
195 | +             for node in nodes:
196 | +                 if graph.has_node(node):
197 | +                     graph.remove_node(node)
198 |
199 |       def remove_edges(self, edges: list[tuple[str, str]]):
200 |           """Delete multiple edges

202 |           Args:
203 |               edges: List of edges to be deleted, each edge is a (source, target) tuple
204 |           """
205 | +         with self._storage_lock:
206 | +             graph = self._get_graph()
207 | +             for source, target in edges:
208 | +                 if graph.has_edge(source, target):
209 | +                     graph.remove_edge(source, target)
210 |
211 |       async def get_all_labels(self) -> list[str]:
212 |           """

214 |           Returns:
215 |               [label1, label2, ...]  # Alphabetically sorted label list
216 |           """
217 | +         with self._storage_lock:
218 | +             graph = self._get_graph()
219 | +             labels = set()
220 | +             for node in graph.nodes():
221 | +                 labels.add(str(node))  # Add node id as a label
222 |
223 |           # Return sorted list
224 |           return sorted(list(labels))

240 |           seen_nodes = set()
241 |           seen_edges = set()
242 |
243 | +         with self._storage_lock:
244 | +             graph = self._get_graph()
245 | +
246 | +             # Handle special case for "*" label
247 | +             if node_label == "*":
248 | +                 # For "*", return the entire graph including all nodes and edges
249 | +                 subgraph = graph.copy()  # Create a copy to avoid modifying the original graph
250 | +             else:
251 | +                 # Find nodes with matching node id (partial match)
252 | +                 nodes_to_explore = []
253 | +                 for n, attr in graph.nodes(data=True):
254 | +                     if node_label in str(n):  # Use partial matching
255 | +                         nodes_to_explore.append(n)
256 | +
257 | +                 if not nodes_to_explore:
258 | +                     logger.warning(f"No nodes found with label {node_label}")
259 | +                     return result
260 | +
261 | +                 # Get subgraph using ego_graph
262 | +                 subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
263 | +
264 | +             # Check if number of nodes exceeds max_graph_nodes
265 | +             max_graph_nodes = 500
266 | +             if len(subgraph.nodes()) > max_graph_nodes:
267 | +                 origin_nodes = len(subgraph.nodes())
268 | +                 node_degrees = dict(subgraph.degree())
269 | +                 top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
270 | +                     :max_graph_nodes
271 | +                 ]
272 | +                 top_node_ids = [node[0] for node in top_nodes]
273 | +                 # Create new subgraph with only top nodes
274 | +                 subgraph = subgraph.subgraph(top_node_ids)
275 | +                 logger.info(
276 | +                     f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
277 | +                 )
278 |
279 | +             # Add nodes to result
280 | +             for node in subgraph.nodes():
281 | +                 if str(node) in seen_nodes:
282 | +                     continue
283 | +
284 | +                 node_data = dict(subgraph.nodes[node])
285 | +                 # Get entity_type as labels
286 | +                 labels = []
287 | +                 if "entity_type" in node_data:
288 | +                     if isinstance(node_data["entity_type"], list):
289 | +                         labels.extend(node_data["entity_type"])
290 | +                     else:
291 | +                         labels.append(node_data["entity_type"])
292 | +
293 | +                 # Create node with properties
294 | +                 node_properties = {k: v for k, v in node_data.items()}
295 | +
296 | +                 result.nodes.append(
297 | +                     KnowledgeGraphNode(
298 | +                         id=str(node), labels=[str(node)], properties=node_properties
299 | +                     )
300 |                   )
301 | +                 seen_nodes.add(str(node))
302 | +
303 | +             # Add edges to result
304 | +             for edge in subgraph.edges():
305 | +                 source, target = edge
306 | +                 edge_id = f"{source}-{target}"
307 | +                 if edge_id in seen_edges:
308 | +                     continue
309 | +
310 | +                 edge_data = dict(subgraph.edges[edge])
311 | +
312 | +                 # Create edge with complete information
313 | +                 result.edges.append(
314 | +                     KnowledgeGraphEdge(
315 | +                         id=edge_id,
316 | +                         type="DIRECTED",
317 | +                         source=str(source),
318 | +                         target=str(target),
319 | +                         properties=edge_data,
320 | +                     )
321 |                   )
322 | +                 seen_edges.add(edge_id)
323 |
324 |           logger.info(
325 |               f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
lightrag/kg/shared_storage.py
ADDED
@@ -0,0 +1,94 @@
  1 | + from multiprocessing.synchronize import Lock as ProcessLock
  2 | + from threading import Lock as ThreadLock
  3 | + from multiprocessing import Manager
  4 | + from typing import Any, Dict, Optional, Union
  5 | +
  6 | + # Define type variables
  7 | + LockType = Union[ProcessLock, ThreadLock]
  8 | +
  9 | + # Global variables
 10 | + _shared_data: Optional[Dict[str, Any]] = None
 11 | + _namespace_objects: Optional[Dict[str, Any]] = None
 12 | + _global_lock: Optional[LockType] = None
 13 | + is_multiprocess = False
 14 | + manager = None
 15 | +
 16 | + def initialize_manager():
 17 | +     """Initialize the multiprocessing manager; only needed when workers > 1"""
 18 | +     global manager
 19 | +     if manager is None:
 20 | +         manager = Manager()
 21 | +
 22 | + def _get_global_lock() -> LockType:
 23 | +     global _global_lock, is_multiprocess
 24 | +
 25 | +     if _global_lock is None:
 26 | +         if is_multiprocess:
 27 | +             _global_lock = manager.Lock()
 28 | +         else:
 29 | +             _global_lock = ThreadLock()
 30 | +
 31 | +     return _global_lock
 32 | +
 33 | + def get_storage_lock() -> LockType:
 34 | +     """Return the storage lock used for data consistency"""
 35 | +     return _get_global_lock()
 36 | +
 37 | + def get_scan_lock() -> LockType:
 38 | +     """Return the scan_progress lock used for data consistency"""
 39 | +     return get_storage_lock()
 40 | +
 41 | + def get_shared_data() -> Dict[str, Any]:
 42 | +     """
 43 | +     Return the shared data container used by all storage types.
 44 | +     The multiprocess-safe shared container is created lazily, only when needed, for better performance.
 45 | +     """
 46 | +     global _shared_data, is_multiprocess
 47 | +
 48 | +     if _shared_data is None:
 49 | +         lock = _get_global_lock()
 50 | +         with lock:
 51 | +             if _shared_data is None:
 52 | +                 if is_multiprocess:
 53 | +                     _shared_data = manager.dict()
 54 | +                 else:
 55 | +                     _shared_data = {}
 56 | +
 57 | +     return _shared_data
 58 | +
 59 | + def get_namespace_object(namespace: str) -> Any:
 60 | +     """Get the shared object for a specific namespace"""
 61 | +     global _namespace_objects, is_multiprocess
 62 | +
 63 | +     if _namespace_objects is None:
 64 | +         lock = _get_global_lock()
 65 | +         with lock:
 66 | +             if _namespace_objects is None:
 67 | +                 _namespace_objects = {}
 68 | +
 69 | +     if namespace not in _namespace_objects:
 70 | +         lock = _get_global_lock()
 71 | +         with lock:
 72 | +             if namespace not in _namespace_objects:
 73 | +                 if is_multiprocess:
 74 | +                     _namespace_objects[namespace] = manager.Value('O', None)
 75 | +                 else:
 76 | +                     _namespace_objects[namespace] = None
 77 | +
 78 | +     return _namespace_objects[namespace]
 79 | +
 80 | + def get_namespace_data(namespace: str) -> Dict[str, Any]:
 81 | +     """Get the storage space for a specific storage type (namespace)"""
 82 | +     shared_data = get_shared_data()
 83 | +     lock = _get_global_lock()
 84 | +
 85 | +     if namespace not in shared_data:
 86 | +         with lock:
 87 | +             if namespace not in shared_data:
 88 | +                 shared_data[namespace] = {}
 89 | +
 90 | +     return shared_data[namespace]
 91 | +
 92 | + def get_scan_progress() -> Dict[str, Any]:
 93 | +     """Get the storage space for document scanning progress data"""
 94 | +     return get_namespace_data('scan_progress')
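In the default single-process mode no Manager is required: the module lazily falls back to a threading.Lock and plain dictionaries. A minimal usage sketch follows (not part of the commit; the progress keys used below are illustrative, not defined by the module).

# Illustrative sketch only - single-process (default) mode.
from lightrag.kg.shared_storage import get_storage_lock, get_scan_progress

progress = get_scan_progress()        # lazily creates the 'scan_progress' namespace
with get_storage_lock():              # a threading.Lock unless is_multiprocess is set
    progress["total_files"] = 10      # illustrative keys
    progress["indexed_count"] = 0

# Multi-worker deployments call initialize_manager() and set
# shared_storage.is_multiprocess = True before the first lock or namespace is
# created, so the same calls return a manager Lock and a manager-backed dict.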
lightrag/lightrag.py
CHANGED
@@ -266,13 +266,7 @@ class LightRAG:
(old version)

266 |
267 |       _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
268 |
269 | -     def __post_init__(self):
270 | -         # Initialize manager if needed
271 | -         from lightrag.api.utils_api import manager, initialize_manager
272 | -         if manager is None:
273 | -             initialize_manager()
274 | -             logger.info("Initialized manager for single process mode")
275 | -
276 |           os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
277 |           set_logger(self.log_file_path, self.log_level)
278 |           logger.info(f"Logger initialized for working directory: {self.working_dir}")

(new version)

266 |
267 |       _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
268 |
269 | +     def __post_init__(self):
270 |           os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
271 |           set_logger(self.log_file_path, self.log_level)
272 |           logger.info(f"Logger initialized for working directory: {self.working_dir}")
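With the manager bootstrap removed from __post_init__, choosing the process mode is now the host application's responsibility rather than the library's. A hedged sketch of what an embedding script might do (only working_dir is shown; the worker count and the assumption that the remaining constructor arguments keep their defaults are illustrative):

# Illustrative sketch only; not part of the commit.
import lightrag.kg.shared_storage as shared_storage
from lightrag import LightRAG

WORKERS = 4                                    # hypothetical worker count

if WORKERS > 1:
    shared_storage.initialize_manager()        # create the multiprocessing Manager once
    shared_storage.is_multiprocess = True      # storages then use process locks and shared objects

rag = LightRAG(working_dir="./rag_storage")    # __post_init__ no longer initializes the manager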