yangdx commited on
Commit
f5df192
·
1 Parent(s): 9047e41

Refactor initialization logic for vector, KV and graph storage implementations

Browse files

• Add try_initialize_namespace check
• Move init code out of storage locks
• Reduce redundant init conditions
• Simplify initialization flow
• Make initialization more thread-safe

lightrag/kg/faiss_impl.py CHANGED
@@ -15,6 +15,7 @@ from .shared_storage import (
15
  get_storage_lock,
16
  get_namespace_object,
17
  is_multiprocess,
 
18
  )
19
 
20
  if not pm.is_installed("faiss"):
@@ -52,26 +53,26 @@ class FaissVectorDBStorage(BaseVectorStorage):
52
  self._dim = self.embedding_func.embedding_dim
53
  self._storage_lock = get_storage_lock()
54
 
 
 
55
  self._index = get_namespace_object("faiss_indices")
56
  self._id_to_meta = get_namespace_data("faiss_meta")
57
 
58
- with self._storage_lock:
59
  if is_multiprocess:
60
- if self._index.value is None:
61
- # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
62
- # If you have a large number of vectors, you might want IVF or other indexes.
63
- # For demonstration, we use a simple IndexFlatIP.
64
- self._index.value = faiss.IndexFlatIP(self._dim)
65
- # Keep a local store for metadata, IDs, etc.
66
- # Maps <int faiss_id> → metadata (including your original ID).
67
- self._id_to_meta.update({})
68
- # Attempt to load an existing index + metadata from disk
69
- self._load_faiss_index()
70
  else:
71
- if self._index is None:
72
- self._index = faiss.IndexFlatIP(self._dim)
73
- self._id_to_meta.update({})
74
- self._load_faiss_index()
75
 
76
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
77
  """
 
15
  get_storage_lock,
16
  get_namespace_object,
17
  is_multiprocess,
18
+ try_initialize_namespace,
19
  )
20
 
21
  if not pm.is_installed("faiss"):
 
53
  self._dim = self.embedding_func.embedding_dim
54
  self._storage_lock = get_storage_lock()
55
 
56
+ # check need_init must before get_namespace_object/get_namespace_data
57
+ need_init = try_initialize_namespace("faiss_indices")
58
  self._index = get_namespace_object("faiss_indices")
59
  self._id_to_meta = get_namespace_data("faiss_meta")
60
 
61
+ if need_init:
62
  if is_multiprocess:
63
+ # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
64
+ # If you have a large number of vectors, you might want IVF or other indexes.
65
+ # For demonstration, we use a simple IndexFlatIP.
66
+ self._index.value = faiss.IndexFlatIP(self._dim)
67
+ # Keep a local store for metadata, IDs, etc.
68
+ # Maps <int faiss_id> → metadata (including your original ID).
69
+ self._id_to_meta.update({})
70
+ # Attempt to load an existing index + metadata from disk
71
+ self._load_faiss_index()
 
72
  else:
73
+ self._index = faiss.IndexFlatIP(self._dim)
74
+ self._id_to_meta.update({})
75
+ self._load_faiss_index()
 
76
 
77
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
78
  """
lightrag/kg/json_kv_impl.py CHANGED
@@ -10,7 +10,7 @@ from lightrag.utils import (
10
  logger,
11
  write_json,
12
  )
13
- from .shared_storage import get_namespace_data, get_storage_lock
14
 
15
 
16
  @final
@@ -20,11 +20,15 @@ class JsonKVStorage(BaseKVStorage):
20
  working_dir = self.global_config["working_dir"]
21
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
22
  self._storage_lock = get_storage_lock()
 
 
 
23
  self._data = get_namespace_data(self.namespace)
24
- with self._storage_lock:
25
- if not self._data:
26
- self._data: dict[str, Any] = load_json(self._file_name) or {}
27
- logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
 
28
 
29
  async def index_done_callback(self) -> None:
30
  # 文件写入需要加锁,防止多个进程同时写入导致文件损坏
 
10
  logger,
11
  write_json,
12
  )
13
+ from .shared_storage import get_namespace_data, get_storage_lock, try_initialize_namespace
14
 
15
 
16
  @final
 
20
  working_dir = self.global_config["working_dir"]
21
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
22
  self._storage_lock = get_storage_lock()
23
+
24
+ # check need_init must before get_namespace_data
25
+ need_init = try_initialize_namespace(self.namespace)
26
  self._data = get_namespace_data(self.namespace)
27
+ if need_init:
28
+ loaded_data = load_json(self._file_name) or {}
29
+ with self._storage_lock:
30
+ self._data.update(loaded_data)
31
+ logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
32
 
33
  async def index_done_callback(self) -> None:
34
  # 文件写入需要加锁,防止多个进程同时写入导致文件损坏
lightrag/kg/nano_vector_db_impl.py CHANGED
@@ -11,7 +11,7 @@ from lightrag.utils import (
11
  )
12
  import pipmaster as pm
13
  from lightrag.base import BaseVectorStorage
14
- from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess
15
 
16
  if not pm.is_installed("nano-vectordb"):
17
  pm.install("nano-vectordb")
@@ -40,27 +40,27 @@ class NanoVectorDBStorage(BaseVectorStorage):
40
  )
41
  self._max_batch_size = self.global_config["embedding_batch_num"]
42
 
 
 
43
  self._client = get_namespace_object(self.namespace)
44
 
45
- with self._storage_lock:
46
  if is_multiprocess:
47
- if self._client.value is None:
48
- self._client.value = NanoVectorDB(
49
- self.embedding_func.embedding_dim,
50
- storage_file=self._client_file_name,
51
- )
52
- logger.info(
53
- f"Initialized vector DB client for namespace {self.namespace}"
54
- )
55
  else:
56
- if self._client is None:
57
- self._client = NanoVectorDB(
58
- self.embedding_func.embedding_dim,
59
- storage_file=self._client_file_name,
60
- )
61
- logger.info(
62
- f"Initialized vector DB client for namespace {self.namespace}"
63
- )
64
 
65
  def _get_client(self):
66
  """Get the appropriate client instance based on multiprocess mode"""
 
11
  )
12
  import pipmaster as pm
13
  from lightrag.base import BaseVectorStorage
14
+ from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess, try_initialize_namespace
15
 
16
  if not pm.is_installed("nano-vectordb"):
17
  pm.install("nano-vectordb")
 
40
  )
41
  self._max_batch_size = self.global_config["embedding_batch_num"]
42
 
43
+ # check need_init must before get_namespace_object
44
+ need_init = try_initialize_namespace(self.namespace)
45
  self._client = get_namespace_object(self.namespace)
46
 
47
+ if need_init:
48
  if is_multiprocess:
49
+ self._client.value = NanoVectorDB(
50
+ self.embedding_func.embedding_dim,
51
+ storage_file=self._client_file_name,
52
+ )
53
+ logger.info(
54
+ f"Initialized vector DB client for namespace {self.namespace}"
55
+ )
 
56
  else:
57
+ self._client = NanoVectorDB(
58
+ self.embedding_func.embedding_dim,
59
+ storage_file=self._client_file_name,
60
+ )
61
+ logger.info(
62
+ f"Initialized vector DB client for namespace {self.namespace}"
63
+ )
 
64
 
65
  def _get_client(self):
66
  """Get the appropriate client instance based on multiprocess mode"""
lightrag/kg/networkx_impl.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
6
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
7
  from lightrag.utils import logger
8
  from lightrag.base import BaseGraphStorage
9
- from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess
10
 
11
  import pipmaster as pm
12
 
@@ -74,32 +74,34 @@ class NetworkXStorage(BaseGraphStorage):
74
  self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
75
  )
76
  self._storage_lock = get_storage_lock()
 
 
 
77
  self._graph = get_namespace_object(self.namespace)
78
- with self._storage_lock:
 
79
  if is_multiprocess:
80
- if self._graph.value is None:
81
- preloaded_graph = NetworkXStorage.load_nx_graph(
82
- self._graphml_xml_file
 
 
 
 
83
  )
84
- self._graph.value = preloaded_graph or nx.Graph()
85
- if preloaded_graph:
86
- logger.info(
87
- f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
88
- )
89
- else:
90
- logger.info("Created new empty graph")
91
  else:
92
- if self._graph is None:
93
- preloaded_graph = NetworkXStorage.load_nx_graph(
94
- self._graphml_xml_file
 
 
 
 
95
  )
96
- self._graph = preloaded_graph or nx.Graph()
97
- if preloaded_graph:
98
- logger.info(
99
- f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
100
- )
101
- else:
102
- logger.info("Created new empty graph")
103
 
104
  self._node_embed_algorithms = {
105
  "node2vec": self._node2vec_embed,
 
6
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
7
  from lightrag.utils import logger
8
  from lightrag.base import BaseGraphStorage
9
+ from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess, try_initialize_namespace
10
 
11
  import pipmaster as pm
12
 
 
74
  self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
75
  )
76
  self._storage_lock = get_storage_lock()
77
+
78
+ # check need_init must before get_namespace_object
79
+ need_init = try_initialize_namespace(self.namespace)
80
  self._graph = get_namespace_object(self.namespace)
81
+
82
+ if need_init:
83
  if is_multiprocess:
84
+ preloaded_graph = NetworkXStorage.load_nx_graph(
85
+ self._graphml_xml_file
86
+ )
87
+ self._graph.value = preloaded_graph or nx.Graph()
88
+ if preloaded_graph:
89
+ logger.info(
90
+ f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
91
  )
92
+ else:
93
+ logger.info("Created new empty graph")
 
 
 
 
 
94
  else:
95
+ preloaded_graph = NetworkXStorage.load_nx_graph(
96
+ self._graphml_xml_file
97
+ )
98
+ self._graph = preloaded_graph or nx.Graph()
99
+ if preloaded_graph:
100
+ logger.info(
101
+ f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
102
  )
103
+ else:
104
+ logger.info("Created new empty graph")
 
 
 
 
 
105
 
106
  self._node_embed_algorithms = {
107
  "node2vec": self._node2vec_embed,