cleanup code
- lightrag/base.py +7 -4
- lightrag/kg/json_kv_impl.py +5 -2
- lightrag/kg/mongo_impl.py +4 -2
- lightrag/kg/oracle_impl.py +2 -1
- lightrag/kg/postgres_impl.py +6 -4
- lightrag/kg/redis_impl.py +5 -4
- lightrag/kg/tidb_impl.py +6 -3
- lightrag/lightrag.py +28 -17
lightrag/base.py
CHANGED
@@ -91,7 +91,7 @@ class BaseKVStorage(StorageNameSpace):
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         raise NotImplementedError
 
-    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
+    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         raise NotImplementedError
 
     async def filter_keys(self, data: list[str]) -> set[str]:
@@ -103,10 +103,13 @@ class BaseKVStorage(StorageNameSpace):
 
     async def drop(self) -> None:
         raise NotImplementedError
-
-    async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
+
+    async def get_by_status_and_ids(
+        self, status: str
+    ) -> Union[list[dict[str, Any]], None]:
         raise NotImplementedError
-
+
+
 @dataclass
 class BaseGraphStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc = None
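For orientation, a minimal usage sketch of the contract this reflowed signature defines. `collect_todo_ids` is a hypothetical helper, the `DocStatus` enum is a stand-in for LightRAG's own (its PENDING/FAILED values appear in the lightrag.py hunks below), and the `doc["id"]` access mirrors that code:

```python
# Minimal sketch of how a caller consumes the get_by_status_and_ids contract.
# DocStatus is a stand-in for LightRAG's enum; collect_todo_ids is hypothetical.
from enum import Enum
from typing import Any, Union


class DocStatus(str, Enum):  # stand-in; values mirror the lightrag.py usage below
    PENDING = "pending"
    FAILED = "failed"
    PROCESSED = "processed"


async def collect_todo_ids(storage) -> list[str]:
    """Gather ids of documents that still need work, as lightrag.py does."""
    todo: list[str] = []
    for status in (DocStatus.FAILED, DocStatus.PENDING):
        docs: Union[list[dict[str, Any]], None] = await storage.get_by_status_and_ids(
            status=status
        )
        if docs:  # contract: None when nothing matches
            todo.extend(doc["id"] for doc in docs)
    return todo
```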
lightrag/kg/json_kv_impl.py
CHANGED
@@ -12,6 +12,7 @@ from lightrag.base import (
     BaseKVStorage,
 )
 
+
 @dataclass
 class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
@@ -30,7 +31,7 @@ class JsonKVStorage(BaseKVStorage):
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         return self._data.get(id, None)
 
-    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
+    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         return [
             (
                 {k: v for k, v in self._data[id].items()}
@@ -50,6 +51,8 @@ class JsonKVStorage(BaseKVStorage):
     async def drop(self) -> None:
         self._data = {}
 
-    async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status_and_ids(
+        self, status: str
+    ) -> Union[list[dict[str, Any]], None]:
         result = [v for _, v in self._data.items() if v["status"] == status]
         return result if result else None
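The in-memory scan indexes `v["status"]` unconditionally, so every stored value is assumed to carry a status field. A self-contained check of the return contract (list of matches, or None), mirroring the comprehension above:

```python
# Self-contained check of the JsonKVStorage status-filter contract: the same
# comprehension as above, run against a toy _data dict.
data = {
    "doc-1": {"content": "a", "status": "pending"},
    "doc-2": {"content": "b", "status": "failed"},
}

result = [v for _, v in data.items() if v["status"] == "pending"]
print(result if result else None)
# -> [{'content': 'a', 'status': 'pending'}]
```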
lightrag/kg/mongo_impl.py
CHANGED
@@ -35,7 +35,7 @@ class MongoKVStorage(BaseKVStorage):
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         return self._data.find_one({"_id": id})
 
-    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
+    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         return list(self._data.find({"_id": {"$in": ids}}))
 
     async def filter_keys(self, data: list[str]) -> set[str]:
@@ -77,7 +77,9 @@ class MongoKVStorage(BaseKVStorage):
         """Drop the collection"""
         await self._data.drop()
 
-    async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status_and_ids(
+        self, status: str
+    ) -> Union[list[dict[str, Any]], None]:
         """Get documents by status and ids"""
         return self._data.find({"status": status})
 
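Note that `get_by_ids` materializes its cursor with `list(...)` while the new method returns the raw `find()` cursor. A caller-side sketch (hypothetical helper) that normalizes the result to the list-or-None shape of the other backends:

```python
# Hypothetical caller-side helper: the method above hands back a pymongo
# cursor (find() is neither awaited nor materialized), so normalize it to
# list-or-None to match the other backends.
async def status_docs_as_list(storage, status):
    cursor = await storage.get_by_status_and_ids(status=status)
    docs = list(cursor)  # pymongo cursors iterate synchronously
    return docs or None
```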
lightrag/kg/oracle_impl.py
CHANGED
@@ -326,7 +326,8 @@ class OracleKVStorage(BaseKVStorage):
             (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
         ):
             logger.info("full doc and chunk data had been saved into oracle db!")
-
+
+
 @dataclass
 class OracleVectorDBStorage(BaseVectorStorage):
     # should pass db object to self.db
lightrag/kg/postgres_impl.py
CHANGED
@@ -213,7 +213,7 @@ class PGKVStorage(BaseKVStorage):
             return None
 
     # Query by id
-    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
+    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         """Get doc_chunks data by id"""
         sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
@@ -237,12 +237,14 @@ class PGKVStorage(BaseKVStorage):
             return res
         else:
             return None
-
-    async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
+
+    async def get_by_status_and_ids(
+        self, status: str
+    ) -> Union[list[dict[str, Any]], None]:
         """Specifically for llm_response_cache."""
         SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
         params = {"workspace": self.db.workspace, "status": status}
-        return await self.db.query(SQL, params, multirows=True)
+        return await self.db.query(SQL, params, multirows=True)
 
     async def all_keys(self) -> list[dict]:
         if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
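The contents of `SQL_TEMPLATES["get_by_status_" + namespace]` are not part of this diff. A hypothetical entry consistent with the params dict above (workspace and status bind parameters; the table name and placeholder syntax are assumptions):

```python
# Hypothetical SQL_TEMPLATES entry, inferred from the params dict above.
# Table name and named-placeholder syntax are assumptions, not shown in
# this diff; the real template lives elsewhere in postgres_impl.py.
SQL_TEMPLATES = {
    "get_by_status_llm_response_cache": """
        SELECT * FROM LIGHTRAG_LLM_CACHE
        WHERE workspace = :workspace AND status = :status
    """,
}
```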
lightrag/kg/redis_impl.py
CHANGED
@@ -29,7 +29,7 @@ class RedisKVStorage(BaseKVStorage):
         data = await self._redis.get(f"{self.namespace}:{id}")
         return json.loads(data) if data else None
 
-    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
+    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         pipe = self._redis.pipeline()
         for id in ids:
             pipe.get(f"{self.namespace}:{id}")
@@ -58,11 +58,12 @@ class RedisKVStorage(BaseKVStorage):
         keys = await self._redis.keys(f"{self.namespace}:*")
         if keys:
             await self._redis.delete(*keys)
-
-    async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
+
+    async def get_by_status_and_ids(
+        self, status: str
+    ) -> Union[list[dict[str, Any]], None]:
         pipe = self._redis.pipeline()
         for key in await self._redis.keys(f"{self.namespace}:*"):
             pipe.hgetall(key)
         results = await pipe.execute()
         return [data for data in results if data.get("status") == status] or None
-
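One behavioral wrinkle worth flagging: `get_by_id`/`get_by_ids` read values with GET and `json.loads`, i.e. these keys hold JSON strings, while the new scan issues HGETALL, which only returns data for hash keys. A hypothetical GET-based variant consistent with the string encoding:

```python
# Hypothetical GET-based variant. get_by_id/get_by_ids above read values as
# JSON strings, while HGETALL only yields data for Redis hashes; this sketch
# scans the same keyspace using the string encoding instead.
import json


async def get_by_status_string_keys(self, status):
    pipe = self._redis.pipeline()
    for key in await self._redis.keys(f"{self.namespace}:*"):
        pipe.get(key)
    results = await pipe.execute()
    docs = [json.loads(d) for d in results if d]
    return [d for d in docs if d.get("status") == status] or None
```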
lightrag/kg/tidb_impl.py
CHANGED
@@ -122,7 +122,7 @@ class TiDBKVStorage(BaseKVStorage):
             return None
 
     # Query by id
-    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
+    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         """Get doc_chunks data by id"""
         SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
@@ -333,10 +333,13 @@ class TiDBVectorDBStorage(BaseVectorStorage):
             merge_sql = SQL_TEMPLATES["insert_relationship"]
             await self.db.execute(merge_sql, data)
 
-    async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status_and_ids(
+        self, status: str
+    ) -> Union[list[dict[str, Any]], None]:
         SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
         params = {"workspace": self.db.workspace, "status": status}
-        return await self.db.query(SQL, params, multirows=True)
+        return await self.db.query(SQL, params, multirows=True)
+
 
 @dataclass
 class TiDBGraphStorage(BaseGraphStorage):
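Both SQL backends funnel through a `db.query(sql, params, multirows=True)` helper that is assumed to return a list of row dicts, or nothing, when multirows is set. A small caller sketch of that contract; the table name is a stand-in and the helper itself is not shown in this diff:

```python
# Caller sketch of the db.query contract assumed by both SQL backends:
# multirows=True yields a list of row dicts, and the result may be None
# when nothing matches. The table name is a stand-in, not from this diff.
async def todo_ids_by_status(db, workspace: str, status: str) -> list[str]:
    sql = (
        "SELECT id FROM some_table "
        "WHERE workspace = :workspace AND status = :status"
    )
    rows = await db.query(
        sql, {"workspace": workspace, "status": status}, multirows=True
    )
    return [row["id"] for row in rows or []]  # guard against a None result
```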
lightrag/lightrag.py
CHANGED
@@ -629,12 +629,7 @@ class LightRAG:
         # 4. Store original document
         for doc_id, doc in new_docs.items():
             await self.full_docs.upsert(
-                {
-                    doc_id: {
-                        "content": doc["content"],
-                        "status": DocStatus.PENDING
-                    }
-                }
+                {doc_id: {"content": doc["content"], "status": DocStatus.PENDING}}
             )
         logger.info(f"Stored {len(new_docs)} new unique documents")
 
@@ -642,10 +637,14 @@ class LightRAG:
         """Get pendding documents, split into chunks,insert chunks"""
         # 1. get all pending and failed documents
         _todo_doc_keys = []
-
-        _failed_doc = await self.full_docs.get_by_status_and_ids(status=DocStatus.FAILED)
-        _pendding_doc = await self.full_docs.get_by_status_and_ids(status=DocStatus.PENDING)
-
+
+        _failed_doc = await self.full_docs.get_by_status_and_ids(
+            status=DocStatus.FAILED
+        )
+        _pendding_doc = await self.full_docs.get_by_status_and_ids(
+            status=DocStatus.PENDING
+        )
+
         if _failed_doc:
             _todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
         if _pendding_doc:
@@ -685,15 +684,19 @@ class LightRAG:
                     )
                 }
                 chunk_cnt += len(chunks)
-
+
                 try:
                     # Store chunks in vector database
                     await self.chunks_vdb.upsert(chunks)
                     # Update doc status
-                    await self.text_chunks.upsert({**chunks, "status": DocStatus.PENDING})
+                    await self.text_chunks.upsert(
+                        {**chunks, "status": DocStatus.PENDING}
+                    )
                 except Exception as e:
                     # Mark as failed if any step fails
-                    await self.text_chunks.upsert({**chunks, "status": DocStatus.FAILED})
+                    await self.text_chunks.upsert(
+                        {**chunks, "status": DocStatus.FAILED}
+                    )
                     raise e
         except Exception as e:
             import traceback
@@ -707,8 +710,12 @@ class LightRAG:
         """Get pendding or failed chunks, extract entities and relationships from each chunk"""
         # 1. get all pending and failed chunks
        _todo_chunk_keys = []
-        _failed_chunks = await self.text_chunks.get_by_status_and_ids(status=DocStatus.FAILED)
-        _pendding_chunks = await self.text_chunks.get_by_status_and_ids(status=DocStatus.PENDING)
+        _failed_chunks = await self.text_chunks.get_by_status_and_ids(
+            status=DocStatus.FAILED
+        )
+        _pendding_chunks = await self.text_chunks.get_by_status_and_ids(
+            status=DocStatus.PENDING
+        )
         if _failed_chunks:
             _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks])
         if _pendding_chunks:
@@ -742,11 +749,15 @@ class LightRAG:
                     if maybe_new_kg is None:
                         logger.info("No entities or relationships extracted!")
                     # Update status to processed
-                    await self.text_chunks.upsert({chunk_id: {"status": DocStatus.PROCESSED}})
+                    await self.text_chunks.upsert(
+                        {chunk_id: {"status": DocStatus.PROCESSED}}
+                    )
                 except Exception as e:
                     logger.error("Failed to extract entities and relationships")
                     # Mark as failed if any step fails
-                    await self.text_chunks.upsert({chunk_id: {"status": DocStatus.FAILED}})
+                    await self.text_chunks.upsert(
+                        {chunk_id: {"status": DocStatus.FAILED}}
+                    )
                     raise e
 
         with tqdm_async(
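Taken together, the lightrag.py hunks drive a small status machine over `full_docs` and `text_chunks`: everything starts PENDING, failed work is retried, and success lands on PROCESSED. A condensed sketch of that lifecycle; the storage objects and the `split`/`extract` steps are hypothetical stand-ins:

```python
# Condensed sketch of the status lifecycle these hunks implement. The storage
# objects and the split/extract steps are hypothetical stand-ins; DocStatus
# mirrors the values used above.
from types import SimpleNamespace

DocStatus = SimpleNamespace(PENDING="pending", FAILED="failed", PROCESSED="processed")


async def process_documents(full_docs, text_chunks, split, extract):
    # 1. retry failed work first, then pick up pending documents
    todo = []
    for status in (DocStatus.FAILED, DocStatus.PENDING):
        docs = await full_docs.get_by_status_and_ids(status=status)
        if docs:
            todo.extend(doc["id"] for doc in docs)

    # 2. chunk each document, then extract entities/relationships per chunk
    for doc_id in todo:
        chunk_ids = await split(doc_id)
        try:
            await extract(chunk_ids)
            await text_chunks.upsert(
                {cid: {"status": DocStatus.PROCESSED} for cid in chunk_ids}
            )
        except Exception:
            await text_chunks.upsert(
                {cid: {"status": DocStatus.FAILED} for cid in chunk_ids}
            )
            raise
```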