ParisNeo committed
Commit 714e271 · unverified · 1 Parent(s): b31c2f1

Update storage.py

Files changed (1)
  1. lightrag/storage.py +1 -460
lightrag/storage.py CHANGED
@@ -1,460 +1 @@
- import asyncio
- import html
- import os
- from tqdm.asyncio import tqdm as tqdm_async
- from dataclasses import dataclass
- from typing import Any, Union, cast, Dict
- import networkx as nx
- import numpy as np
-
- from nano_vectordb import NanoVectorDB
- import time
-
- from .utils import (
-     logger,
-     load_json,
-     write_json,
-     compute_mdhash_id,
- )
-
- from .base import (
-     BaseGraphStorage,
-     BaseKVStorage,
-     BaseVectorStorage,
-     DocStatus,
-     DocProcessingStatus,
-     DocStatusStorage,
- )
-
-
- @dataclass
- class JsonKVStorage(BaseKVStorage):
-     def __post_init__(self):
-         working_dir = self.global_config["working_dir"]
-         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-         self._data = load_json(self._file_name) or {}
-         self._lock = asyncio.Lock()
-         logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
-
-     async def all_keys(self) -> list[str]:
-         return list(self._data.keys())
-
-     async def index_done_callback(self):
-         write_json(self._data, self._file_name)
-
-     async def get_by_id(self, id):
-         return self._data.get(id, None)
-
-     async def get_by_ids(self, ids, fields=None):
-         if fields is None:
-             return [self._data.get(id, None) for id in ids]
-         return [
-             (
-                 {k: v for k, v in self._data[id].items() if k in fields}
-                 if self._data.get(id, None)
-                 else None
-             )
-             for id in ids
-         ]
-
-     async def filter_keys(self, data: list[str]) -> set[str]:
-         return set([s for s in data if s not in self._data])
-
-     async def upsert(self, data: dict[str, dict]):
-         left_data = {k: v for k, v in data.items() if k not in self._data}
-         self._data.update(left_data)
-         return left_data
-
-     async def drop(self):
-         self._data = {}
-
-     async def filter(self, filter_func):
-         """Filter key-value pairs based on a filter function
-
-         Args:
-             filter_func: The filter function, which takes a value as an argument and returns a boolean value
-
-         Returns:
-             Dict: Key-value pairs that meet the condition
-         """
-         result = {}
-         async with self._lock:
-             for key, value in self._data.items():
-                 if filter_func(value):
-                     result[key] = value
-         return result
-
-     async def delete(self, ids: list[str]):
-         """Delete data with specified IDs
-
-         Args:
-             ids: List of IDs to delete
-         """
-         async with self._lock:
-             for id in ids:
-                 if id in self._data:
-                     del self._data[id]
-             await self.index_done_callback()
-             logger.info(f"Successfully deleted {len(ids)} items from {self.namespace}")
-
-
- @dataclass
- class NanoVectorDBStorage(BaseVectorStorage):
-     cosine_better_than_threshold: float = 0.2
-
-     def __post_init__(self):
-         self._client_file_name = os.path.join(
-             self.global_config["working_dir"], f"vdb_{self.namespace}.json"
-         )
-         self._max_batch_size = self.global_config["embedding_batch_num"]
-         self._client = NanoVectorDB(
-             self.embedding_func.embedding_dim, storage_file=self._client_file_name
-         )
-         self.cosine_better_than_threshold = self.global_config.get(
-             "cosine_better_than_threshold", self.cosine_better_than_threshold
-         )
-
-     async def upsert(self, data: dict[str, dict]):
-         logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
-         if not len(data):
-             logger.warning("You insert an empty data to vector DB")
-             return []
-
-         current_time = time.time()
-         list_data = [
-             {
-                 "__id__": k,
-                 "__created_at__": current_time,
-                 **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
-             }
-             for k, v in data.items()
-         ]
-         contents = [v["content"] for v in data.values()]
-         batches = [
-             contents[i : i + self._max_batch_size]
-             for i in range(0, len(contents), self._max_batch_size)
-         ]
-
-         async def wrapped_task(batch):
-             result = await self.embedding_func(batch)
-             pbar.update(1)
-             return result
-
-         embedding_tasks = [wrapped_task(batch) for batch in batches]
-         pbar = tqdm_async(
-             total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
-         )
-         embeddings_list = await asyncio.gather(*embedding_tasks)
-
-         embeddings = np.concatenate(embeddings_list)
-         if len(embeddings) == len(list_data):
-             for i, d in enumerate(list_data):
-                 d["__vector__"] = embeddings[i]
-             results = self._client.upsert(datas=list_data)
-             return results
-         else:
-             # sometimes the embedding is not returned correctly. just log it.
-             logger.error(
-                 f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
-             )
-
-     async def query(self, query: str, top_k=5):
-         embedding = await self.embedding_func([query])
-         embedding = embedding[0]
-         results = self._client.query(
-             query=embedding,
-             top_k=top_k,
-             better_than_threshold=self.cosine_better_than_threshold,
-         )
-         results = [
-             {
-                 **dp,
-                 "id": dp["__id__"],
-                 "distance": dp["__metrics__"],
-                 "created_at": dp.get("__created_at__"),
-             }
-             for dp in results
-         ]
-         return results
-
-     @property
-     def client_storage(self):
-         return getattr(self._client, "_NanoVectorDB__storage")
-
-     async def delete(self, ids: list[str]):
-         """Delete vectors with specified IDs
-
-         Args:
-             ids: List of vector IDs to be deleted
-         """
-         try:
-             self._client.delete(ids)
-             logger.info(
-                 f"Successfully deleted {len(ids)} vectors from {self.namespace}"
-             )
-         except Exception as e:
-             logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
-
-     async def delete_entity(self, entity_name: str):
-         try:
-             entity_id = compute_mdhash_id(entity_name, prefix="ent-")
-             logger.debug(
-                 f"Attempting to delete entity {entity_name} with ID {entity_id}"
-             )
-             # Check if the entity exists
-             if self._client.get([entity_id]):
-                 await self.delete([entity_id])
-                 logger.debug(f"Successfully deleted entity {entity_name}")
-             else:
-                 logger.debug(f"Entity {entity_name} not found in storage")
-         except Exception as e:
-             logger.error(f"Error deleting entity {entity_name}: {e}")
-
-     async def delete_entity_relation(self, entity_name: str):
-         try:
-             relations = [
-                 dp
-                 for dp in self.client_storage["data"]
-                 if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
-             ]
-             logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
-             ids_to_delete = [relation["__id__"] for relation in relations]
-
-             if ids_to_delete:
-                 await self.delete(ids_to_delete)
-                 logger.debug(
-                     f"Deleted {len(ids_to_delete)} relations for {entity_name}"
-                 )
-             else:
-                 logger.debug(f"No relations found for entity {entity_name}")
-         except Exception as e:
-             logger.error(f"Error deleting relations for {entity_name}: {e}")
-
-     async def index_done_callback(self):
-         self._client.save()
-
-
- @dataclass
- class NetworkXStorage(BaseGraphStorage):
-     @staticmethod
-     def load_nx_graph(file_name) -> nx.Graph:
-         if os.path.exists(file_name):
-             return nx.read_graphml(file_name)
-         return None
-
-     @staticmethod
-     def write_nx_graph(graph: nx.Graph, file_name):
-         logger.info(
-             f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
-         )
-         nx.write_graphml(graph, file_name)
-
-     @staticmethod
-     def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
-         """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
-         Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
-         """
-         from graspologic.utils import largest_connected_component
-
-         graph = graph.copy()
-         graph = cast(nx.Graph, largest_connected_component(graph))
-         node_mapping = {
-             node: html.unescape(node.upper().strip()) for node in graph.nodes()
-         }  # type: ignore
-         graph = nx.relabel_nodes(graph, node_mapping)
-         return NetworkXStorage._stabilize_graph(graph)
-
-     @staticmethod
-     def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
-         """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
-         Ensure an undirected graph with the same relationships will always be read the same way.
-         """
-         fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
-
-         sorted_nodes = graph.nodes(data=True)
-         sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
-
-         fixed_graph.add_nodes_from(sorted_nodes)
-         edges = list(graph.edges(data=True))
-
-         if not graph.is_directed():
-
-             def _sort_source_target(edge):
-                 source, target, edge_data = edge
-                 if source > target:
-                     temp = source
-                     source = target
-                     target = temp
-                 return source, target, edge_data
-
-             edges = [_sort_source_target(edge) for edge in edges]
-
-         def _get_edge_key(source: Any, target: Any) -> str:
-             return f"{source} -> {target}"
-
-         edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
-
-         fixed_graph.add_edges_from(edges)
-         return fixed_graph
-
-     def __post_init__(self):
-         self._graphml_xml_file = os.path.join(
-             self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
-         )
-         preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
-         if preloaded_graph is not None:
-             logger.info(
-                 f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
-             )
-         self._graph = preloaded_graph or nx.Graph()
-         self._node_embed_algorithms = {
-             "node2vec": self._node2vec_embed,
-         }
-
-     async def index_done_callback(self):
-         NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
-
-     async def has_node(self, node_id: str) -> bool:
-         return self._graph.has_node(node_id)
-
-     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-         return self._graph.has_edge(source_node_id, target_node_id)
-
-     async def get_node(self, node_id: str) -> Union[dict, None]:
-         return self._graph.nodes.get(node_id)
-
-     async def node_degree(self, node_id: str) -> int:
-         return self._graph.degree(node_id)
-
-     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-         return self._graph.degree(src_id) + self._graph.degree(tgt_id)
-
-     async def get_edge(
-         self, source_node_id: str, target_node_id: str
-     ) -> Union[dict, None]:
-         return self._graph.edges.get((source_node_id, target_node_id))
-
-     async def get_node_edges(self, source_node_id: str):
-         if self._graph.has_node(source_node_id):
-             return list(self._graph.edges(source_node_id))
-         return None
-
-     async def upsert_node(self, node_id: str, node_data: dict[str, str]):
-         self._graph.add_node(node_id, **node_data)
-
-     async def upsert_edge(
-         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
-     ):
-         self._graph.add_edge(source_node_id, target_node_id, **edge_data)
-
-     async def delete_node(self, node_id: str):
-         """
-         Delete a node from the graph based on the specified node_id.
-
-         :param node_id: The node_id to delete
-         """
-         if self._graph.has_node(node_id):
-             self._graph.remove_node(node_id)
-             logger.info(f"Node {node_id} deleted from the graph.")
-         else:
-             logger.warning(f"Node {node_id} not found in the graph for deletion.")
-
-     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
-         if algorithm not in self._node_embed_algorithms:
-             raise ValueError(f"Node embedding algorithm {algorithm} not supported")
-         return await self._node_embed_algorithms[algorithm]()
-
-     # @TODO: NOT USED
-     async def _node2vec_embed(self):
-         from graspologic import embed
-
-         embeddings, nodes = embed.node2vec_embed(
-             self._graph,
-             **self.global_config["node2vec_params"],
-         )
-
-         nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
-         return embeddings, nodes_ids
-
-     def remove_nodes(self, nodes: list[str]):
-         """Delete multiple nodes
-
-         Args:
-             nodes: List of node IDs to be deleted
-         """
-         for node in nodes:
-             if self._graph.has_node(node):
-                 self._graph.remove_node(node)
-
-     def remove_edges(self, edges: list[tuple[str, str]]):
-         """Delete multiple edges
-
-         Args:
-             edges: List of edges to be deleted, each edge is a (source, target) tuple
-         """
-         for source, target in edges:
-             if self._graph.has_edge(source, target):
-                 self._graph.remove_edge(source, target)
-
-
- @dataclass
- class JsonDocStatusStorage(DocStatusStorage):
-     """JSON implementation of document status storage"""
-
-     def __post_init__(self):
-         working_dir = self.global_config["working_dir"]
-         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-         self._data = load_json(self._file_name) or {}
-         logger.info(f"Loaded document status storage with {len(self._data)} records")
-
-     async def filter_keys(self, data: list[str]) -> set[str]:
-         """Return keys that should be processed (not in storage or not successfully processed)"""
-         return set(
-             [
-                 k
-                 for k in data
-                 if k not in self._data or self._data[k]["status"] != DocStatus.PROCESSED
-             ]
-         )
-
-     async def get_status_counts(self) -> Dict[str, int]:
-         """Get counts of documents in each status"""
-         counts = {status: 0 for status in DocStatus}
-         for doc in self._data.values():
-             counts[doc["status"]] += 1
-         return counts
-
-     async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
-         """Get all failed documents"""
-         return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED}
-
-     async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
-         """Get all pending documents"""
-         return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING}
-
-     async def index_done_callback(self):
-         """Save data to file after indexing"""
-         write_json(self._data, self._file_name)
-
-     async def upsert(self, data: dict[str, dict]):
-         """Update or insert document status
-
-         Args:
-             data: Dictionary of document IDs and their status data
-         """
-         self._data.update(data)
-         await self.index_done_callback()
-         return data
-
-     async def get_by_id(self, id: str):
-         return self._data.get(id)
-
-     async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]:
-         """Get document status by ID"""
-         return self._data.get(doc_id)
-
-     async def delete(self, doc_ids: list[str]):
-         """Delete document status by IDs"""
-         for doc_id in doc_ids:
-             self._data.pop(doc_id, None)
-         await self.index_done_callback()
+ # This file is not needed anymore (TODO: remove)