yangdx
committed on
Commit · 8a706cb
1 Parent(s): 911c794
revert vector and graph use local data (single process)
- lightrag/kg/faiss_impl.py +132 -170
- lightrag/kg/nano_vector_db_impl.py +46 -75
- lightrag/kg/networkx_impl.py +122 -158
- lightrag/kg/shared_storage.py +4 -57
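The common thread across all four diffs below: each storage backend goes back to holding its data as plain instance attributes guarded by a `threading.Lock`, instead of fetching it through the cross-process `shared_storage` namespace helpers (`get_namespace_object`, `get_namespace_data`, `is_multiprocess`). A minimal sketch of the resulting single-process pattern (the `LocalStorage` class and its names are illustrative, not LightRAG API):

```python
from threading import Lock


class LocalStorage:
    """Per-process storage: plain attributes guarded by a thread lock."""

    def __init__(self):
        # Before the revert, this data lived in a cross-process shared
        # namespace object; now each process keeps its own copy.
        self._storage_lock = Lock()
        self._data = {}

    def save(self):
        # The lock only serializes threads within one process,
        # which is exactly the "single process" scope of this commit.
        with self._storage_lock:
            return dict(self._data)
```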
lightrag/kg/faiss_impl.py
CHANGED
@@ -10,19 +10,12 @@ import pipmaster as pm
 from lightrag.utils import logger, compute_mdhash_id
 from lightrag.base import BaseVectorStorage
-from .shared_storage import (
-    get_namespace_data,
-    get_storage_lock,
-    get_namespace_object,
-    is_multiprocess,
-    try_initialize_namespace,
-)

 if not pm.is_installed("faiss"):
     pm.install("faiss")

 import faiss  # type: ignore
+from threading import Lock as ThreadLock

 @final
 @dataclass
@@ -51,35 +44,29 @@ class FaissVectorDBStorage(BaseVectorStorage):
         self._max_batch_size = self.global_config["embedding_batch_num"]
         # Embedding dimension (e.g. 768) must match your embedding function
         self._dim = self.embedding_func.embedding_dim
-        self._storage_lock = get_storage_lock()
-        …
-            self._id_to_meta.update({})
-            # Attempt to load an existing index + metadata from disk
-            self._load_faiss_index()
-        else:
-            self._index = faiss.IndexFlatIP(self._dim)
-            self._id_to_meta.update({})
-            self._load_faiss_index()
+        self._storage_lock = ThreadLock()
+
+        # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
+        # If you have a large number of vectors, you might want IVF or other indexes.
+        # For demonstration, we use a simple IndexFlatIP.
+        self._index = faiss.IndexFlatIP(self._dim)
+
+        # Keep a local store for metadata, IDs, etc.
+        # Maps <int faiss_id> → metadata (including your original ID).
+        self._id_to_meta = {}
+
+        # Attempt to load an existing index + metadata from disk
+        with self._storage_lock:
+            self._load_faiss_index()

     def _get_index(self):
-        """…"""
-        if is_multiprocess:
-            return self._index.value
+        """Check if the storage should be reloaded"""
         return self._index
+
+    async def index_done_callback(self) -> None:
+        with self._storage_lock:
+            self._save_faiss_index()

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """
@@ -134,34 +121,33 @@ class FaissVectorDBStorage(BaseVectorStorage):
         # Normalize embeddings for cosine similarity (in-place)
         faiss.normalize_L2(embeddings)

-        …
-        return [m["__id__"] for m in list_data]
+        # Upsert logic:
+        # 1. Identify which vectors to remove if they exist
+        # 2. Remove them
+        # 3. Add the new vectors
+        existing_ids_to_remove = []
+        for meta, emb in zip(list_data, embeddings):
+            faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
+            if faiss_internal_id is not None:
+                existing_ids_to_remove.append(faiss_internal_id)
+
+        if existing_ids_to_remove:
+            self._remove_faiss_ids(existing_ids_to_remove)
+
+        # Step 2: Add new vectors
+        index = self._get_index()
+        start_idx = index.ntotal
+        index.add(embeddings)
+
+        # Step 3: Store metadata + vector for each new ID
+        for i, meta in enumerate(list_data):
+            fid = start_idx + i
+            # Store the raw vector so we can rebuild if something is removed
+            meta["__vector__"] = embeddings[i].tolist()
+            self._id_to_meta.update({fid: meta})
+
+        logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
+        return [m["__id__"] for m in list_data]
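`IndexFlatIP` assigns sequential internal ids on `add`, so the upsert above must first drop rows whose custom id already exists and then append, using `ntotal` to map the new internal ids back to metadata. A standalone sketch of the append-and-track step (assuming faiss and numpy; the toy 4-dim data is made up):

```python
import faiss
import numpy as np

dim = 4
index = faiss.IndexFlatIP(dim)
id_to_meta = {}

vecs = np.random.rand(3, dim).astype(np.float32)
faiss.normalize_L2(vecs)  # normalized vectors: inner product == cosine similarity

start = index.ntotal  # next internal id faiss will assign
index.add(vecs)
for i in range(len(vecs)):
    # keep the raw vector so the index can be rebuilt after deletions
    id_to_meta[start + i] = {"__id__": f"doc-{i}", "__vector__": vecs[i].tolist()}

print(index.ntotal, len(id_to_meta))  # 3 3
```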
@@ -177,57 +163,54 @@ class FaissVectorDBStorage(BaseVectorStorage):
         )

         # Perform the similarity search
-        …
-        )
+        distances, indices = self._get_index().search(embedding, top_k)
+
+        distances = distances[0]
+        indices = indices[0]
+
+        results = []
+        for dist, idx in zip(distances, indices):
+            if idx == -1:
+                # Faiss returns -1 if no neighbor
+                continue
+
+            # Cosine similarity threshold
+            if dist < self.cosine_better_than_threshold:
+                continue
+
+            meta = self._id_to_meta.get(idx, {})
+            results.append(
+                {
+                    **meta,
+                    "id": meta.get("__id__"),
+                    "distance": float(dist),
+                    "created_at": meta.get("__created_at__"),
+                }
+            )
+
+        return results

     @property
     def client_storage(self):
         # Return whatever structure LightRAG might need for debugging
-        …
-        return {"data": list(self._id_to_meta.values())}
+        return {"data": list(self._id_to_meta.values())}

     async def delete(self, ids: list[str]):
         """
         Delete vectors for the provided custom IDs.
         """
         logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
-        …
-        )
+        to_remove = []
+        for cid in ids:
+            fid = self._find_faiss_id_by_custom_id(cid)
+            if fid is not None:
+                to_remove.append(fid)
+
+        if to_remove:
+            self._remove_faiss_ids(to_remove)
+        logger.debug(
+            f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
+        )
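The query path leans on two Faiss conventions visible above: `search` pads with `-1` indices when fewer than `top_k` neighbors exist, and with L2-normalized vectors the inner-product score equals cosine similarity, so a plain threshold comparison works. A self-contained sketch (the 0.2 threshold is arbitrary):

```python
import faiss
import numpy as np

dim = 4
index = faiss.IndexFlatIP(dim)
data = np.random.rand(2, dim).astype(np.float32)
faiss.normalize_L2(data)
index.add(data)

query = data[:1].copy()                      # query with a known vector
distances, indices = index.search(query, 5)  # top_k > ntotal pads with -1

threshold = 0.2
hits = [
    (int(idx), float(dist))
    for dist, idx in zip(distances[0], indices[0])
    if idx != -1 and dist >= threshold       # skip padding, apply cosine cutoff
]
print(hits)  # e.g. [(0, 1.0), ...] — the exact-match score is ~1.0
```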
@@ -239,23 +222,18 @@ class FaissVectorDBStorage(BaseVectorStorage):
     async def delete_entity(self, entity_name: str) -> None:
         entity_id = compute_mdhash_id(entity_name, prefix="ent-")
         …
         Delete relations for a given entity by scanning metadata.
         """
         logger.debug(f"Searching relations for entity {entity_name}")
-        …
-        logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
-
-    async def index_done_callback(self) -> None:
-        with self._storage_lock:
-            self._save_faiss_index()
+        relations = []
+        for fid, meta in self._id_to_meta.items():
+            if (
+                meta.get("src_id") == entity_name
+                or meta.get("tgt_id") == entity_name
+            ):
+                relations.append(fid)
+
+        logger.debug(f"Found {len(relations)} relations for {entity_name}")
+        if relations:
+            self._remove_faiss_ids(relations)
+            logger.debug(f"Deleted {len(relations)} relations for {entity_name}")

     # --------------------------------------------------------------------------------
     # Internal helper methods
@@ -265,11 +243,10 @@ class FaissVectorDBStorage(BaseVectorStorage):
         """
         Return the Faiss internal ID for a given custom ID, or None if not found.
         """
-        …
-        return None
+        for fid, meta in self._id_to_meta.items():
+            if meta.get("__id__") == custom_id:
+                return fid
+        return None

     def _remove_faiss_ids(self, fid_list):
         """
@@ -277,48 +254,42 @@ class FaissVectorDBStorage(BaseVectorStorage):
         Because IndexFlatIP doesn't support 'removals',
         we rebuild the index excluding those vectors.
         """
-        with self._storage_lock:
-            …
-            # Rebuild the index
-            vectors_to_keep = []
-            new_id_to_meta = {}
-            for new_fid, old_fid in enumerate(keep_fids):
-                vec_meta = self._id_to_meta[old_fid]
-                vectors_to_keep.append(vec_meta["__vector__"])  # stored as list
-                new_id_to_meta[new_fid] = vec_meta
-
-            # Re-init index
-            new_index = faiss.IndexFlatIP(self._dim)
-            if vectors_to_keep:
-                arr = np.array(vectors_to_keep, dtype=np.float32)
-                …
-            else:
-                …
-            self._index = new_index
-            self._id_to_meta.update(new_id_to_meta)
+        keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
+
+        # Rebuild the index
+        vectors_to_keep = []
+        new_id_to_meta = {}
+        for new_fid, old_fid in enumerate(keep_fids):
+            vec_meta = self._id_to_meta[old_fid]
+            vectors_to_keep.append(vec_meta["__vector__"])  # stored as list
+            new_id_to_meta[new_fid] = vec_meta
+
+        with self._storage_lock:
+            # Re-init index
+            self._index = faiss.IndexFlatIP(self._dim)
+            if vectors_to_keep:
+                arr = np.array(vectors_to_keep, dtype=np.float32)
+                self._index.add(arr)
+
+            self._id_to_meta = new_id_to_meta

     def _save_faiss_index(self):
         """
         Save the current Faiss index + metadata to disk so it can persist across runs.
         """
-        …
-        faiss.write_index(
-            self._get_index(),
-            self._faiss_index_file,
-        )
-        …
+        faiss.write_index(self._index, self._faiss_index_file)
+
+        # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
+        # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
+        # We'll keep the int -> dict, but JSON requires string keys.
+        serializable_dict = {}
+        for fid, meta in self._id_to_meta.items():
+            serializable_dict[str(fid)] = meta

-            with open(self._meta_file, "w", encoding="utf-8") as f:
-                json.dump(serializable_dict, f)
+        with open(self._meta_file, "w", encoding="utf-8") as f:
+            json.dump(serializable_dict, f)
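Because `IndexFlatIP` supports no per-id removal, `_remove_faiss_ids` rebuilds the whole index from the raw vectors cached in metadata, renumbering internal ids densely; deletes are therefore O(n) in index size. A toy version of the rebuild (the `id_to_meta` layout mirrors the code above; values are made up):

```python
import faiss
import numpy as np

dim = 4
id_to_meta = {
    0: {"__id__": "a", "__vector__": [1.0, 0.0, 0.0, 0.0]},
    1: {"__id__": "b", "__vector__": [0.0, 1.0, 0.0, 0.0]},
    2: {"__id__": "c", "__vector__": [0.0, 0.0, 1.0, 0.0]},
}
to_remove = [1]

keep = [fid for fid in id_to_meta if fid not in to_remove]
new_index = faiss.IndexFlatIP(dim)
new_meta = {}
vectors = []
for new_fid, old_fid in enumerate(keep):
    meta = id_to_meta[old_fid]
    vectors.append(meta["__vector__"])
    new_meta[new_fid] = meta  # internal ids are renumbered densely
if vectors:
    new_index.add(np.array(vectors, dtype=np.float32))

print(new_index.ntotal, sorted(m["__id__"] for m in new_meta.values()))  # 2 ['a', 'c']
```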
@@ -331,31 +302,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
     def _load_faiss_index(self):
         """
         …
         try:
             # Load the Faiss index
-            …
-            if is_multiprocess:
-                self._index.value = loaded_index
-            else:
-                self._index = loaded_index
+            self._index = faiss.read_index(self._faiss_index_file)
             # Load metadata
             with open(self._meta_file, "r", encoding="utf-8") as f:
                 stored_dict = json.load(f)

             # Convert string keys back to int
-            …
+            self._id_to_meta = {}
             for fid_str, meta in stored_dict.items():
                 fid = int(fid_str)
                 self._id_to_meta[fid] = meta

             logger.info(
-                f"Faiss index loaded with {…}"
+                f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
             )
         except Exception as e:
             logger.error(f"Failed to load Faiss index or metadata: {e}")
             logger.warning("Starting with an empty Faiss index.")
-            …
-                self._index.value = new_index
-            else:
-                self._index = new_index
-            self._id_to_meta.update({})
+            self._index = faiss.IndexFlatIP(self._dim)
+            self._id_to_meta = {}
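The save/load pair persists two artifacts side by side: the binary index via `faiss.write_index`/`faiss.read_index`, and the metadata as JSON, whose keys must be strings on disk and are converted back to `int` on load. A round-trip sketch (file names are arbitrary):

```python
import json
import faiss
import numpy as np

dim = 4
index = faiss.IndexFlatIP(dim)
index.add(np.eye(dim, dtype=np.float32))
id_to_meta = {0: {"__id__": "a"}, 1: {"__id__": "b"}, 2: {"__id__": "c"}, 3: {"__id__": "d"}}

faiss.write_index(index, "demo.index")
with open("demo_meta.json", "w", encoding="utf-8") as f:
    json.dump({str(fid): m for fid, m in id_to_meta.items()}, f)  # JSON keys must be str

restored = faiss.read_index("demo.index")
with open("demo_meta.json", "r", encoding="utf-8") as f:
    restored_meta = {int(k): v for k, v in json.load(f).items()}  # back to int keys

assert restored.ntotal == len(restored_meta) == 4
```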
lightrag/kg/nano_vector_db_impl.py
CHANGED
@@ -11,25 +11,19 @@ from lightrag.utils import (
 )
 import pipmaster as pm
 from lightrag.base import BaseVectorStorage
-from .shared_storage import (
-    get_storage_lock,
-    get_namespace_object,
-    is_multiprocess,
-    try_initialize_namespace,
-)

 if not pm.is_installed("nano-vectordb"):
     pm.install("nano-vectordb")

 from nano_vectordb import NanoVectorDB
+from threading import Lock as ThreadLock

 @final
 @dataclass
 class NanoVectorDBStorage(BaseVectorStorage):
     def __post_init__(self):
         # Initialize lock only for file operations
-        self._storage_lock = get_storage_lock()
+        self._storage_lock = ThreadLock()

         # Use global config value if specified, otherwise use default
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -45,32 +39,14 @@ class NanoVectorDBStorage(BaseVectorStorage):
         )
         self._max_batch_size = self.global_config["embedding_batch_num"]

-        …
-        if is_multiprocess:
-            self._client.value = NanoVectorDB(
-                self.embedding_func.embedding_dim,
-                storage_file=self._client_file_name,
-            )
-            logger.info(
-                f"Initialized vector DB client for namespace {self.namespace}"
-            )
-        else:
-            self._client = NanoVectorDB(
-                self.embedding_func.embedding_dim,
-                storage_file=self._client_file_name,
-            )
-            logger.info(
-                f"Initialized vector DB client for namespace {self.namespace}"
-            )
+        with self._storage_lock:
+            self._client = NanoVectorDB(
+                self.embedding_func.embedding_dim,
+                storage_file=self._client_file_name,
+            )

     def _get_client(self):
-        """…"""
-        if is_multiprocess:
-            return self._client.value
+        """Check if the storage should be reloaded"""
         return self._client

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
@@ -101,8 +77,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         if len(embeddings) == len(list_data):
             for i, d in enumerate(list_data):
                 d["__vector__"] = embeddings[i]
-            …
-                results = self._get_client().upsert(datas=list_data)
+            results = self._get_client().upsert(datas=list_data)
             return results
         else:
             # sometimes the embedding is not returned correctly. just log it.
@@ -115,21 +90,20 @@ class NanoVectorDBStorage(BaseVectorStorage):
         embedding = await self.embedding_func([query])
         embedding = embedding[0]

-        …
-        ]
+        results = self._get_client().query(
+            query=embedding,
+            top_k=top_k,
+            better_than_threshold=self.cosine_better_than_threshold,
+        )
+        results = [
+            {
+                **dp,
+                "id": dp["__id__"],
+                "distance": dp["__metrics__"],
+                "created_at": dp.get("__created_at__"),
+            }
+            for dp in results
+        ]
         return results

     @property
@@ -143,8 +117,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
             ids: List of vector IDs to be deleted
         """
         try:
-            …
-                self._get_client().delete(ids)
+            self._get_client().delete(ids)
             logger.debug(
                 f"Successfully deleted {len(ids)} vectors from {self.namespace}"
             )
@@ -158,37 +131,35 @@ class NanoVectorDBStorage(BaseVectorStorage):
                 f"Attempting to delete entity {entity_name} with ID {entity_id}"
             )

-            …
-                logger.debug(f"Entity {entity_name} not found in storage")
+            # Check if the entity exists
+            if self._get_client().get([entity_id]):
+                self._get_client().delete([entity_id])
+                logger.debug(f"Successfully deleted entity {entity_name}")
+            else:
+                logger.debug(f"Entity {entity_name} not found in storage")
         except Exception as e:
             logger.error(f"Error deleting entity {entity_name}: {e}")

     async def delete_entity_relation(self, entity_name: str) -> None:
         try:
-            …
-            if ids_to_delete:
-                self._get_client().delete(ids_to_delete)
-                logger.debug(
-                    f"Deleted {len(ids_to_delete)} relations for {entity_name}"
-                )
-            else:
-                logger.debug(f"No relations found for entity {entity_name}")
+            storage = getattr(self._get_client(), "_NanoVectorDB__storage")
+            relations = [
+                dp
+                for dp in storage["data"]
+                if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
+            ]
+            logger.debug(
+                f"Found {len(relations)} relations for entity {entity_name}"
+            )
+            ids_to_delete = [relation["__id__"] for relation in relations]
+
+            if ids_to_delete:
+                self._get_client().delete(ids_to_delete)
+                logger.debug(
+                    f"Deleted {len(ids_to_delete)} relations for {entity_name}"
+                )
+            else:
+                logger.debug(f"No relations found for entity {entity_name}")
         except Exception as e:
             logger.error(f"Error deleting relations for {entity_name}: {e}")
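One subtlety in `delete_entity_relation` above: NanoVectorDB keeps its records in a double-underscore attribute, so the code reaches it via Python's name mangling as `_NanoVectorDB__storage`. The mangling rule itself, demonstrated on an illustrative `Box` class:

```python
class Box:
    def __init__(self):
        # double-underscore names are mangled to _ClassName__name
        self.__storage = {"data": [{"src_id": "A", "tgt_id": "B", "__id__": "rel-1"}]}


box = Box()
# box.__storage would raise AttributeError from outside the class;
# the mangled name still works:
storage = getattr(box, "_Box__storage")
matches = [dp for dp in storage["data"] if dp["src_id"] == "A" or dp["tgt_id"] == "A"]
print([dp["__id__"] for dp in matches])  # ['rel-1']
```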
lightrag/kg/networkx_impl.py
CHANGED
@@ -6,12 +6,6 @@ import numpy as np
 from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 from lightrag.utils import logger
 from lightrag.base import BaseGraphStorage
-from .shared_storage import (
-    get_storage_lock,
-    get_namespace_object,
-    is_multiprocess,
-    try_initialize_namespace,
-)

 import pipmaster as pm
@@ -23,7 +17,7 @@ if not pm.is_installed("graspologic"):
 import networkx as nx
 from graspologic import embed
-
+from threading import Lock as ThreadLock

 @final
 @dataclass
@@ -78,38 +72,23 @@ class NetworkXStorage(BaseGraphStorage):
         self._graphml_xml_file = os.path.join(
             self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
         )
-        self._storage_lock = get_storage_lock()
-
-        # check need_init must before get_namespace_object
-        need_init = try_initialize_namespace(self.namespace)
-        self._graph = get_namespace_object(self.namespace)
-
-        if need_init:
-            if is_multiprocess:
-                preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
-                self._graph.value = preloaded_graph or nx.Graph()
-                if preloaded_graph:
-                    logger.info(
-                        f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
-                    )
-            else:
-                preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
-                self._graph = preloaded_graph or nx.Graph()
-                if preloaded_graph:
-                    logger.info(
-                        f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
-                    )
+        self._storage_lock = ThreadLock()

+        with self._storage_lock:
+            preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
+            if preloaded_graph is not None:
+                logger.info(
+                    f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
+                )
+            else:
                 logger.info("Created new empty graph")
+            self._graph = preloaded_graph or nx.Graph()
         self._node_embed_algorithms = {
             "node2vec": self._node2vec_embed,
         }

     def _get_graph(self):
-        """…"""
-        if is_multiprocess:
-            return self._graph.value
+        """Check if the storage should be reloaded"""
         return self._graph

     async def index_done_callback(self) -> None:
@@ -117,54 +96,44 @@ class NetworkXStorage(BaseGraphStorage):
         NetworkXStorage.write_nx_graph(self._get_graph(), self._graphml_xml_file)

     async def has_node(self, node_id: str) -> bool:
-        …
+        return self._get_graph().has_node(node_id)

     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        …
+        return self._get_graph().has_edge(source_node_id, target_node_id)

     async def get_node(self, node_id: str) -> dict[str, str] | None:
-        …
+        return self._get_graph().nodes.get(node_id)

     async def node_degree(self, node_id: str) -> int:
-        …
+        return self._get_graph().degree(node_id)

     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        …
+        return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id)

     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
-        …
+        return self._get_graph().edges.get((source_node_id, target_node_id))

     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
-        …
+        if self._get_graph().has_node(source_node_id):
+            return list(self._get_graph().edges(source_node_id))
         return None

     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
-        …
+        self._get_graph().add_node(node_id, **node_data)

     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ) -> None:
-        …
+        self._get_graph().add_edge(source_node_id, target_node_id, **edge_data)

     async def delete_node(self, node_id: str) -> None:
-        …
-            logger.warning(f"Node {node_id} not found in the graph for deletion.")
+        if self._get_graph().has_node(node_id):
+            self._get_graph().remove_node(node_id)
+            logger.debug(f"Node {node_id} deleted from the graph.")
+        else:
+            logger.warning(f"Node {node_id} not found in the graph for deletion.")

     async def embed_nodes(
         self, algorithm: str
@@ -175,13 +144,12 @@ class NetworkXStorage(BaseGraphStorage):
     # TODO: NOT USED
     async def _node2vec_embed(self):
-        …
-        nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
+        graph = self._get_graph()
+        embeddings, nodes = embed.node2vec_embed(
+            graph,
+            **self.global_config["node2vec_params"],
+        )
+        nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
         return embeddings, nodes_ids

     def remove_nodes(self, nodes: list[str]):
@@ -190,11 +158,10 @@ class NetworkXStorage(BaseGraphStorage):
         Args:
             nodes: List of node IDs to be deleted
         """
-        …
-            graph.remove_node(node)
+        graph = self._get_graph()
+        for node in nodes:
+            if graph.has_node(node):
+                graph.remove_node(node)

     def remove_edges(self, edges: list[tuple[str, str]]):
         """Delete multiple edges
@@ -202,11 +169,10 @@ class NetworkXStorage(BaseGraphStorage):
         Args:
             edges: List of edges to be deleted, each edge is a (source, target) tuple
         """
-        …
-            graph.remove_edge(source, target)
+        graph = self._get_graph()
+        for source, target in edges:
+            if graph.has_edge(source, target):
+                graph.remove_edge(source, target)

     async def get_all_labels(self) -> list[str]:
         """
@@ -214,10 +180,9 @@ class NetworkXStorage(BaseGraphStorage):
         Returns:
             [label1, label2, ...]  # Alphabetically sorted label list
         """
-        …
-            labels.add(str(node))  # Add node id as a label
+        labels = set()
+        for node in self._get_graph().nodes():
+            labels.add(str(node))  # Add node id as a label

         # Return sorted list
         return sorted(list(labels))
@@ -239,88 +204,87 @@ class NetworkXStorage(BaseGraphStorage):
         seen_nodes = set()
         seen_edges = set()

-        …
+        graph = self._get_graph()
+
+        # Handle special case for "*" label
+        if node_label == "*":
+            # For "*", return the entire graph including all nodes and edges
+            subgraph = (
+                graph.copy()
+            )  # Create a copy to avoid modifying the original graph
+        else:
+            # Find nodes with matching node id (partial match)
+            nodes_to_explore = []
+            for n, attr in graph.nodes(data=True):
+                if node_label in str(n):  # Use partial matching
+                    nodes_to_explore.append(n)
+
+            if not nodes_to_explore:
+                logger.warning(f"No nodes found with label {node_label}")
+                return result
+
+            # Get subgraph using ego_graph
+            subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
+
+            # Check if number of nodes exceeds max_graph_nodes
+            max_graph_nodes = 500
+            if len(subgraph.nodes()) > max_graph_nodes:
+                origin_nodes = len(subgraph.nodes())
+                node_degrees = dict(subgraph.degree())
+                top_nodes = sorted(
+                    node_degrees.items(), key=lambda x: x[1], reverse=True
+                )[:max_graph_nodes]
+                top_node_ids = [node[0] for node in top_nodes]
+                # Create new subgraph with only top nodes
+                subgraph = subgraph.subgraph(top_node_ids)
+                logger.info(
+                    f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
+                )
+
+        # Add nodes to result
+        for node in subgraph.nodes():
+            if str(node) in seen_nodes:
+                continue
+
+            node_data = dict(subgraph.nodes[node])
+            # Get entity_type as labels
+            labels = []
+            if "entity_type" in node_data:
+                if isinstance(node_data["entity_type"], list):
+                    labels.extend(node_data["entity_type"])
+                else:
+                    labels.append(node_data["entity_type"])
+
+            # Create node with properties
+            node_properties = {k: v for k, v in node_data.items()}
+
+            result.nodes.append(
+                KnowledgeGraphNode(
+                    id=str(node), labels=[str(node)], properties=node_properties
+                )
+            )
+            seen_nodes.add(str(node))
+
+        # Add edges to result
+        for edge in subgraph.edges():
+            source, target = edge
+            edge_id = f"{source}-{target}"
+            if edge_id in seen_edges:
+                continue
+
+            edge_data = dict(subgraph.edges[edge])
+
+            # Create edge with complete information
+            result.edges.append(
+                KnowledgeGraphEdge(
+                    id=edge_id,
+                    type="DIRECTED",
+                    source=str(source),
+                    target=str(target),
+                    properties=edge_data,
+                )
+            )
+            seen_edges.add(edge_id)

         logger.info(
             f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
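`get_knowledge_graph` above combines two NetworkX building blocks: `nx.ego_graph` for a bounded-depth neighborhood and a degree sort to cap the subgraph at `max_graph_nodes`. The same idea on a toy graph (the cap of 2 is arbitrary):

```python
import networkx as nx

g = nx.Graph()
g.add_edges_from([("hub", f"n{i}") for i in range(5)] + [("n0", "far")])

# all nodes within 1 hop of "hub" ("far" is 2 hops away and excluded)
sub = nx.ego_graph(g, "hub", radius=1)

# keep only the highest-degree nodes if the subgraph is too large
max_graph_nodes = 2
if sub.number_of_nodes() > max_graph_nodes:
    top = sorted(dict(sub.degree()).items(), key=lambda x: x[1], reverse=True)
    sub = sub.subgraph([n for n, _ in top[:max_graph_nodes]])

print(sorted(sub.nodes()))  # ['hub', ...] — the hub survives, low-degree leaves are trimmed
```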
lightrag/kg/shared_storage.py
CHANGED
@@ -20,15 +20,12 @@ LockType = Union[ProcessLock, ThreadLock]
 _manager = None
 _initialized = None
 is_multiprocess = None
+_global_lock: Optional[LockType] = None

 # shared data for storage across processes
 _shared_dicts: Optional[Dict[str, Any]] = None
-_share_objects: Optional[Dict[str, Any]] = None
 _init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized

-_global_lock: Optional[LockType] = None
-

 def initialize_share_data(workers: int = 1):
     """
     Initialize shared storage data for single or multi-process mode.
@@ -53,7 +50,6 @@ def initialize_share_data(workers: int = 1):
         is_multiprocess, \
         _global_lock, \
         _shared_dicts, \
-        _share_objects, \
         _init_flags, \
         _initialized
@@ -72,7 +68,6 @@ def initialize_share_data(workers: int = 1):
         _global_lock = _manager.Lock()
         # Create shared dictionaries with manager
         _shared_dicts = _manager.dict()
-        _share_objects = _manager.dict()
         _init_flags = (
             _manager.dict()
         )  # Use shared dictionary to store initialization flags
@@ -83,7 +78,6 @@ def initialize_share_data(workers: int = 1):
         is_multiprocess = False
         _global_lock = ThreadLock()
         _shared_dicts = {}
-        _share_objects = {}
         _init_flags = {}
         direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
@@ -99,11 +93,7 @@ def try_initialize_namespace(namespace: str) -> bool:
     global _init_flags, _manager

     if _init_flags is None:
-        direct_log(
-            f"Error: try to create namespace before Shared-Data is initialized, pid={os.getpid()}",
-            level="ERROR",
-        )
-        raise ValueError("Shared dictionaries not initialized")
+        raise ValueError("Try to create namespace before Shared-Data is initialized")

     if namespace not in _init_flags:
         _init_flags[namespace] = True
@@ -113,43 +103,9 @@ def try_initialize_namespace(namespace: str) -> bool:
     return False


-def _get_global_lock() -> LockType:
-    return _global_lock
-
-
 def get_storage_lock() -> LockType:
     """return storage lock for data consistency"""
-    return _get_global_lock()
+    return _global_lock
-
-
-def get_scan_lock() -> LockType:
-    """return scan_progress lock for data consistency"""
-    return get_storage_lock()
-
-
-def get_namespace_object(namespace: str) -> Any:
-    """Get an object for specific namespace"""
-
-    if _share_objects is None:
-        direct_log(
-            f"Error: try to get namespace before Shared-Data is initialized, pid={os.getpid()}",
-            level="ERROR",
-        )
-        raise ValueError("Shared dictionaries not initialized")
-
-    lock = _get_global_lock()
-    with lock:
-        if namespace not in _share_objects:
-            if namespace not in _share_objects:
-                if is_multiprocess:
-                    _share_objects[namespace] = _manager.Value("O", None)
-                else:
-                    _share_objects[namespace] = None
-            direct_log(
-                f"Created namespace: {namespace}(type={type(_share_objects[namespace])})"
-            )
-
-    return _share_objects[namespace]
@@ -161,7 +117,7 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]:
         )
         raise ValueError("Shared dictionaries not initialized")

-    lock = _get_global_lock()
+    lock = get_storage_lock()
     with lock:
         if namespace not in _shared_dicts:
             if is_multiprocess and _manager is not None:
@@ -175,11 +131,6 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]:
     return _shared_dicts[namespace]


-def get_scan_progress() -> Dict[str, Any]:
-    """get storage space for document scanning progress data"""
-    return get_namespace_data("scan_progress")
-
-
 def finalize_share_data():
     """
     Release shared resources and clean up.
@@ -195,7 +146,6 @@ def finalize_share_data():
         is_multiprocess, \
         _global_lock, \
         _shared_dicts, \
-        _share_objects, \
         _init_flags, \
         _initialized
@@ -216,8 +166,6 @@ def finalize_share_data():
         # Clear shared dictionaries first
         if _shared_dicts is not None:
             _shared_dicts.clear()
-        if _share_objects is not None:
-            _share_objects.clear()
         if _init_flags is not None:
             _init_flags.clear()
@@ -234,7 +182,6 @@ def finalize_share_data():
     _initialized = None
     is_multiprocess = None
     _shared_dicts = None
-    _share_objects = None
     _init_flags = None
     _global_lock = None
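After this cleanup, `shared_storage` only hands out dictionaries plus a single lock whose type depends on the worker count: a multiprocessing Manager lock when `workers > 1`, a `threading.Lock` otherwise. A condensed sketch of that selection logic (a simplified stand-in, not the module itself):

```python
from multiprocessing import Manager
from threading import Lock as ThreadLock


def make_storage(workers: int = 1):
    """Return (lock, shared_dict) appropriate for the worker count."""
    if workers > 1:
        manager = Manager()
        return manager.Lock(), manager.dict()  # usable across processes
    return ThreadLock(), {}                    # single process: plain objects


lock, shared = make_storage(workers=1)
with lock:
    shared.setdefault("scan_progress", {})
print(type(lock).__name__)
```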