yangdx commited on
Commit
980a2a9
·
1 Parent(s): ccc35ac

feat(storage): Add shared memory support for file-based storage implementations

Browse files

This commit adds multiprocessing shared memory support to file-based storage implementations:
- JsonDocStatusStorage
- JsonKVStorage
- NanoVectorDBStorage
- NetworkXStorage

Each storage module now uses module-level global variables with multiprocessing.Manager() to ensure data consistency across multiple uvicorn workers. All processes will see
updates immediately when data is modified through ainsert function.

lightrag/kg/json_doc_status_impl.py CHANGED
@@ -1,6 +1,8 @@
1
  from dataclasses import dataclass
2
  import os
3
  from typing import Any, Union, final
 
 
4
 
5
  from lightrag.base import (
6
  DocProcessingStatus,
@@ -13,6 +15,25 @@ from lightrag.utils import (
13
  write_json,
14
  )
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  @final
18
  @dataclass
@@ -22,8 +43,27 @@ class JsonDocStatusStorage(DocStatusStorage):
22
  def __post_init__(self):
23
  working_dir = self.global_config["working_dir"]
24
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
25
- self._data: dict[str, Any] = load_json(self._file_name) or {}
26
- logger.info(f"Loaded document status storage with {len(self._data)} records")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  async def filter_keys(self, keys: set[str]) -> set[str]:
29
  """Return keys that should be processed (not in storage or not successfully processed)"""
 
1
  from dataclasses import dataclass
2
  import os
3
  from typing import Any, Union, final
4
+ import threading
5
+ from multiprocessing import Manager
6
 
7
  from lightrag.base import (
8
  DocProcessingStatus,
 
15
  write_json,
16
  )
17
 
18
+ # Global variables for shared memory management
19
+ _init_lock = threading.Lock()
20
+ _manager = None
21
+ _shared_doc_status_data = None
22
+
23
+
24
+ def _get_manager():
25
+ """Get or create the global manager instance"""
26
+ global _manager, _shared_doc_status_data
27
+ with _init_lock:
28
+ if _manager is None:
29
+ try:
30
+ _manager = Manager()
31
+ _shared_doc_status_data = _manager.dict()
32
+ except Exception as e:
33
+ logger.error(f"Failed to initialize shared memory manager: {e}")
34
+ raise RuntimeError(f"Shared memory initialization failed: {e}")
35
+ return _manager
36
+
37
 
38
  @final
39
  @dataclass
 
43
  def __post_init__(self):
44
  working_dir = self.global_config["working_dir"]
45
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
46
+
47
+ # Ensure manager is initialized
48
+ _get_manager()
49
+
50
+ # Get or create namespace data
51
+ if self.namespace not in _shared_doc_status_data:
52
+ with _init_lock:
53
+ if self.namespace not in _shared_doc_status_data:
54
+ try:
55
+ initial_data = load_json(self._file_name) or {}
56
+ _shared_doc_status_data[self.namespace] = initial_data
57
+ except Exception as e:
58
+ logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}")
59
+ raise RuntimeError(f"Shared data initialization failed: {e}")
60
+
61
+ try:
62
+ self._data = _shared_doc_status_data[self.namespace]
63
+ logger.info(f"Loaded document status storage with {len(self._data)} records")
64
+ except Exception as e:
65
+ logger.error(f"Failed to access shared memory: {e}")
66
+ raise RuntimeError(f"Cannot access shared memory: {e}")
67
 
68
  async def filter_keys(self, keys: set[str]) -> set[str]:
69
  """Return keys that should be processed (not in storage or not successfully processed)"""
lightrag/kg/json_kv_impl.py CHANGED
@@ -2,6 +2,8 @@ import asyncio
2
  import os
3
  from dataclasses import dataclass
4
  from typing import Any, final
 
 
5
 
6
  from lightrag.base import (
7
  BaseKVStorage,
@@ -12,6 +14,25 @@ from lightrag.utils import (
12
  write_json,
13
  )
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  @final
17
  @dataclass
@@ -19,9 +40,28 @@ class JsonKVStorage(BaseKVStorage):
19
  def __post_init__(self):
20
  working_dir = self.global_config["working_dir"]
21
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
22
- self._data: dict[str, Any] = load_json(self._file_name) or {}
23
  self._lock = asyncio.Lock()
24
- logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  async def index_done_callback(self) -> None:
27
  write_json(self._data, self._file_name)
 
2
  import os
3
  from dataclasses import dataclass
4
  from typing import Any, final
5
+ import threading
6
+ from multiprocessing import Manager
7
 
8
  from lightrag.base import (
9
  BaseKVStorage,
 
14
  write_json,
15
  )
16
 
17
+ # Global variables for shared memory management
18
+ _init_lock = threading.Lock()
19
+ _manager = None
20
+ _shared_kv_data = None
21
+
22
+
23
+ def _get_manager():
24
+ """Get or create the global manager instance"""
25
+ global _manager, _shared_kv_data
26
+ with _init_lock:
27
+ if _manager is None:
28
+ try:
29
+ _manager = Manager()
30
+ _shared_kv_data = _manager.dict()
31
+ except Exception as e:
32
+ logger.error(f"Failed to initialize shared memory manager: {e}")
33
+ raise RuntimeError(f"Shared memory initialization failed: {e}")
34
+ return _manager
35
+
36
 
37
  @final
38
  @dataclass
 
40
  def __post_init__(self):
41
  working_dir = self.global_config["working_dir"]
42
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
 
43
  self._lock = asyncio.Lock()
44
+
45
+ # Ensure manager is initialized
46
+ _get_manager()
47
+
48
+ # Get or create namespace data
49
+ if self.namespace not in _shared_kv_data:
50
+ with _init_lock:
51
+ if self.namespace not in _shared_kv_data:
52
+ try:
53
+ initial_data = load_json(self._file_name) or {}
54
+ _shared_kv_data[self.namespace] = initial_data
55
+ except Exception as e:
56
+ logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}")
57
+ raise RuntimeError(f"Shared data initialization failed: {e}")
58
+
59
+ try:
60
+ self._data = _shared_kv_data[self.namespace]
61
+ logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
62
+ except Exception as e:
63
+ logger.error(f"Failed to access shared memory: {e}")
64
+ raise RuntimeError(f"Cannot access shared memory: {e}")
65
 
66
  async def index_done_callback(self) -> None:
67
  write_json(self._data, self._file_name)
lightrag/kg/nano_vector_db_impl.py CHANGED
@@ -3,6 +3,8 @@ import os
3
  from typing import Any, final
4
  from dataclasses import dataclass
5
  import numpy as np
 
 
6
 
7
  import time
8
 
@@ -20,6 +22,25 @@ if not pm.is_installed("nano-vectordb"):
20
 
21
  from nano_vectordb import NanoVectorDB
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  @final
25
  @dataclass
@@ -40,9 +61,29 @@ class NanoVectorDBStorage(BaseVectorStorage):
40
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
41
  )
42
  self._max_batch_size = self.global_config["embedding_batch_num"]
43
- self._client = NanoVectorDB(
44
- self.embedding_func.embedding_dim, storage_file=self._client_file_name
45
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
48
  logger.info(f"Inserting {len(data)} to {self.namespace}")
 
3
  from typing import Any, final
4
  from dataclasses import dataclass
5
  import numpy as np
6
+ import threading
7
+ from multiprocessing import Manager
8
 
9
  import time
10
 
 
22
 
23
  from nano_vectordb import NanoVectorDB
24
 
25
+ # Global variables for shared memory management
26
+ _init_lock = threading.Lock()
27
+ _manager = None
28
+ _shared_vector_clients = None
29
+
30
+
31
+ def _get_manager():
32
+ """Get or create the global manager instance"""
33
+ global _manager, _shared_vector_clients
34
+ with _init_lock:
35
+ if _manager is None:
36
+ try:
37
+ _manager = Manager()
38
+ _shared_vector_clients = _manager.dict()
39
+ except Exception as e:
40
+ logger.error(f"Failed to initialize shared memory manager: {e}")
41
+ raise RuntimeError(f"Shared memory initialization failed: {e}")
42
+ return _manager
43
+
44
 
45
  @final
46
  @dataclass
 
61
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
62
  )
63
  self._max_batch_size = self.global_config["embedding_batch_num"]
64
+
65
+ # Ensure manager is initialized
66
+ _get_manager()
67
+
68
+ # Get or create namespace client
69
+ if self.namespace not in _shared_vector_clients:
70
+ with _init_lock:
71
+ if self.namespace not in _shared_vector_clients:
72
+ try:
73
+ client = NanoVectorDB(
74
+ self.embedding_func.embedding_dim,
75
+ storage_file=self._client_file_name
76
+ )
77
+ _shared_vector_clients[self.namespace] = client
78
+ except Exception as e:
79
+ logger.error(f"Failed to initialize vector DB client for namespace {self.namespace}: {e}")
80
+ raise RuntimeError(f"Vector DB client initialization failed: {e}")
81
+
82
+ try:
83
+ self._client = _shared_vector_clients[self.namespace]
84
+ except Exception as e:
85
+ logger.error(f"Failed to access shared memory: {e}")
86
+ raise RuntimeError(f"Cannot access shared memory: {e}")
87
 
88
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
89
  logger.info(f"Inserting {len(data)} to {self.namespace}")
lightrag/kg/networkx_impl.py CHANGED
@@ -1,10 +1,11 @@
1
  import os
2
  from dataclasses import dataclass
3
  from typing import Any, final
 
 
4
 
5
  import numpy as np
6
 
7
-
8
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
9
  from lightrag.utils import (
10
  logger,
@@ -24,6 +25,25 @@ if not pm.is_installed("graspologic"):
24
  import networkx as nx
25
  from graspologic import embed
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  @final
29
  @dataclass
@@ -78,15 +98,33 @@ class NetworkXStorage(BaseGraphStorage):
78
  self._graphml_xml_file = os.path.join(
79
  self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
80
  )
81
- preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
82
- if preloaded_graph is not None:
83
- logger.info(
84
- f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
85
- )
86
- self._graph = preloaded_graph or nx.Graph()
87
- self._node_embed_algorithms = {
88
- "node2vec": self._node2vec_embed,
89
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  async def index_done_callback(self) -> None:
92
  NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
 
1
  import os
2
  from dataclasses import dataclass
3
  from typing import Any, final
4
+ import threading
5
+ from multiprocessing import Manager
6
 
7
  import numpy as np
8
 
 
9
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
10
  from lightrag.utils import (
11
  logger,
 
25
  import networkx as nx
26
  from graspologic import embed
27
 
28
+ # Global variables for shared memory management
29
+ _init_lock = threading.Lock()
30
+ _manager = None
31
+ _shared_graphs = None
32
+
33
+
34
+ def _get_manager():
35
+ """Get or create the global manager instance"""
36
+ global _manager, _shared_graphs
37
+ with _init_lock:
38
+ if _manager is None:
39
+ try:
40
+ _manager = Manager()
41
+ _shared_graphs = _manager.dict()
42
+ except Exception as e:
43
+ logger.error(f"Failed to initialize shared memory manager: {e}")
44
+ raise RuntimeError(f"Shared memory initialization failed: {e}")
45
+ return _manager
46
+
47
 
48
  @final
49
  @dataclass
 
98
  self._graphml_xml_file = os.path.join(
99
  self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
100
  )
101
+
102
+ # Ensure manager is initialized
103
+ _get_manager()
104
+
105
+ # Get or create namespace graph
106
+ if self.namespace not in _shared_graphs:
107
+ with _init_lock:
108
+ if self.namespace not in _shared_graphs:
109
+ try:
110
+ preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
111
+ if preloaded_graph is not None:
112
+ logger.info(
113
+ f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
114
+ )
115
+ _shared_graphs[self.namespace] = preloaded_graph or nx.Graph()
116
+ except Exception as e:
117
+ logger.error(f"Failed to initialize graph for namespace {self.namespace}: {e}")
118
+ raise RuntimeError(f"Graph initialization failed: {e}")
119
+
120
+ try:
121
+ self._graph = _shared_graphs[self.namespace]
122
+ self._node_embed_algorithms = {
123
+ "node2vec": self._node2vec_embed,
124
+ }
125
+ except Exception as e:
126
+ logger.error(f"Failed to access shared memory: {e}")
127
+ raise RuntimeError(f"Cannot access shared memory: {e}")
128
 
129
  async def index_done_callback(self) -> None:
130
  NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)