yangdx committed
Commit · 25287b8 · Parent(s): 738c425

Fix linting
lightrag/kg/faiss_impl.py
CHANGED
@@ -17,6 +17,7 @@ if not pm.is_installed("faiss"):
 import faiss  # type: ignore
 from threading import Lock as ThreadLock
 
+
 @final
 @dataclass
 class FaissVectorDBStorage(BaseVectorStorage):
@@ -59,7 +60,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
         with self._storage_lock:
             self._load_faiss_index()
 
-
     def _get_index(self):
         """Check if the shtorage should be reloaded"""
         return self._index
@@ -224,10 +224,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
         logger.debug(f"Searching relations for entity {entity_name}")
         relations = []
         for fid, meta in self._id_to_meta.items():
-            if (
-                meta.get("src_id") == entity_name
-                or meta.get("tgt_id") == entity_name
-            ):
+            if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
                 relations.append(fid)
 
         logger.debug(f"Found {len(relations)} relations for {entity_name}")
@@ -265,7 +262,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
             new_id_to_meta[new_fid] = vec_meta
 
         with self._storage_lock:
-
+            # Re-init index
             self._index = faiss.IndexFlatIP(self._dim)
             if vectors_to_keep:
                 arr = np.array(vectors_to_keep, dtype=np.float32)
@@ -273,7 +270,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
 
             self._id_to_meta = new_id_to_meta
 
-
     def _save_faiss_index(self):
         """
         Save the current Faiss index + metadata to disk so it can persist across runs.
@@ -290,7 +286,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
         with open(self._meta_file, "w", encoding="utf-8") as f:
             json.dump(serializable_dict, f)
 
-
     def _load_faiss_index(self):
         """
         Load the Faiss index + metadata from disk if it exists,
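For context, the "# Re-init index" hunk sits in FaissVectorDBStorage's rebuild-on-delete path: the class recreates a flat inner-product index and re-adds the surviving vectors under the storage lock rather than removing vectors in place. A minimal standalone sketch of that pattern follows; the dimension and sample vectors are made up, while IndexFlatIP, add, and ntotal are the real faiss API.

    # Sketch of the rebuild-on-delete pattern shown in the diff above.
    # `dim` and `vectors_to_keep` are invented sample data.
    import numpy as np
    import faiss
    from threading import Lock as ThreadLock

    dim = 4
    storage_lock = ThreadLock()
    vectors_to_keep = [np.random.rand(dim).astype(np.float32) for _ in range(3)]

    with storage_lock:
        # Re-init index: a fresh flat index using the inner-product metric
        index = faiss.IndexFlatIP(dim)
        if vectors_to_keep:
            arr = np.array(vectors_to_keep, dtype=np.float32)
            index.add(arr)  # re-add all surviving vectors in one batch

    print(index.ntotal)  # -> 3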
lightrag/kg/json_doc_status_impl.py
CHANGED
@@ -84,7 +84,9 @@ class JsonDocStatusStorage(DocStatusStorage):
 
     async def index_done_callback(self) -> None:
         with self._storage_lock:
-            data_dict = dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+            data_dict = (
+                dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+            )
             write_json(data_dict, self._file_name)
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
lightrag/kg/json_kv_impl.py
CHANGED
@@ -36,7 +36,9 @@ class JsonKVStorage(BaseKVStorage):
 
     async def index_done_callback(self) -> None:
         with self._storage_lock:
-            data_dict = dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+            data_dict = (
+                dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+            )
             write_json(data_dict, self._file_name)
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
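The hasattr(self._data, "_getvalue") guard reflowed in both JSON hunks exists because, in multi-process mode, self._data can be a multiprocessing Manager proxy rather than a plain dict, and json.dump cannot serialize a proxy. A minimal sketch of the conversion; the sample data is made up, while DictProxy and its _getvalue method are real multiprocessing behavior.

    # Manager dicts are DictProxy objects: they expose _getvalue() and
    # are not JSON-serializable, so they must be copied to a plain dict.
    import json
    from multiprocessing import Manager

    manager = Manager()
    shared = manager.dict({"doc-1": {"status": "processed"}})  # DictProxy

    data_dict = dict(shared) if hasattr(shared, "_getvalue") else shared
    print(json.dumps(data_dict))  # works: data_dict is now a plain dict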
lightrag/kg/nano_vector_db_impl.py
CHANGED
@@ -18,6 +18,7 @@ if not pm.is_installed("nano-vectordb"):
 from nano_vectordb import NanoVectorDB
 from threading import Lock as ThreadLock
 
+
 @final
 @dataclass
 class NanoVectorDBStorage(BaseVectorStorage):
@@ -148,9 +149,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
                 for dp in storage["data"]
                 if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
             ]
-            logger.debug(
-                f"Found {len(relations)} relations for entity {entity_name}"
-            )
+            logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
             ids_to_delete = [relation["__id__"] for relation in relations]
 
             if ids_to_delete:
lightrag/kg/networkx_impl.py
CHANGED
@@ -19,6 +19,7 @@ import networkx as nx
 from graspologic import embed
 from threading import Lock as ThreadLock
 
+
 @final
 @dataclass
 class NetworkXStorage(BaseGraphStorage):
@@ -231,9 +232,9 @@ class NetworkXStorage(BaseGraphStorage):
         if len(subgraph.nodes()) > max_graph_nodes:
             origin_nodes = len(subgraph.nodes())
             node_degrees = dict(subgraph.degree())
-            top_nodes = sorted(
-                node_degrees.items(), key=lambda x: x[1], reverse=True
-            )[:max_graph_nodes]
+            top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
+                :max_graph_nodes
+            ]
             top_node_ids = [node[0] for node in top_nodes]
             # Create new subgraph with only top nodes
             subgraph = subgraph.subgraph(top_node_ids)
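The reflowed top_nodes expression is a plain top-N selection: sort nodes by degree, descending, then slice off the first max_graph_nodes. A minimal sketch of the same idiom on a toy graph; the star graph and the cap of 2 are invented sample inputs.

    # Keep only the max_graph_nodes highest-degree nodes of a graph.
    import networkx as nx

    max_graph_nodes = 2
    graph = nx.star_graph(4)  # node 0 is the hub with degree 4

    node_degrees = dict(graph.degree())
    top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
        :max_graph_nodes
    ]
    top_node_ids = [node[0] for node in top_nodes]
    subgraph = graph.subgraph(top_node_ids)
    print(top_node_ids)  # hub node 0 first, then one degree-1 leaf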
lightrag/kg/shared_storage.py
CHANGED
@@ -26,6 +26,7 @@ _global_lock: Optional[LockType] = None
 _shared_dicts: Optional[Dict[str, Any]] = None
 _init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized
 
+
 def initialize_share_data(workers: int = 1):
     """
     Initialize shared storage data for single or multi-process mode.
@@ -66,9 +67,7 @@ def initialize_share_data(workers: int = 1):
         is_multiprocess = True
         _global_lock = _manager.Lock()
         _shared_dicts = _manager.dict()
-        _init_flags = (
-            _manager.dict()
-        )
+        _init_flags = _manager.dict()
         direct_log(
             f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})"
         )
@@ -95,9 +94,13 @@ def try_initialize_namespace(namespace: str) -> bool:
 
     if namespace not in _init_flags:
         _init_flags[namespace] = True
-        direct_log(f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]")
+        direct_log(
+            f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
+        )
         return True
-    direct_log(f"Process {os.getpid()} storage namespace already to initialized: [{namespace}]")
+    direct_log(
+        f"Process {os.getpid()} storage namespace already to initialized: [{namespace}]"
+    )
     return False
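Both shared_storage hunks are pure reflows, but the surrounding code is worth a note: in multi-process mode the module swaps its plain module-level globals for multiprocessing.Manager proxies so every worker process sees the same lock, storage dicts, and namespace flags. A minimal sketch of that setup; the names mirror the diff, and the "kv_store" namespace is an invented example.

    # One Manager serves proxy objects that all worker processes share.
    from multiprocessing import Manager

    _manager = Manager()
    _global_lock = _manager.Lock()   # one lock shared across processes
    _shared_dicts = _manager.dict()  # proxy-backed storage dict
    _init_flags = _manager.dict()    # namespace -> initialized flag

    # First caller wins, mirroring try_initialize_namespace's contract.
    namespace = "kv_store"
    if namespace not in _init_flags:
        _init_flags[namespace] = True
        initialized_here = True
    else:
        initialized_here = False
    print(initialized_here)  # True only for the first call per namespace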