YanSte committed on
Commit
e19b29b
·
1 Parent(s): df3dcd1

cleanup code

Browse files
lightrag/base.py CHANGED
@@ -91,7 +91,7 @@ class BaseKVStorage(StorageNameSpace):
91
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
92
  raise NotImplementedError
93
 
94
- async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
95
  raise NotImplementedError
96
 
97
  async def filter_keys(self, data: list[str]) -> set[str]:
@@ -103,10 +103,13 @@ class BaseKVStorage(StorageNameSpace):
103
 
104
  async def drop(self) -> None:
105
  raise NotImplementedError
106
-
107
- async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
 
 
108
  raise NotImplementedError
109
-
 
110
  @dataclass
111
  class BaseGraphStorage(StorageNameSpace):
112
  embedding_func: EmbeddingFunc = None
 
91
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
92
  raise NotImplementedError
93
 
94
+ async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
95
  raise NotImplementedError
96
 
97
  async def filter_keys(self, data: list[str]) -> set[str]:
 
103
 
104
  async def drop(self) -> None:
105
  raise NotImplementedError
106
+
107
+ async def get_by_status_and_ids(
108
+ self, status: str
109
+ ) -> Union[list[dict[str, Any]], None]:
110
  raise NotImplementedError
111
+
112
+
113
  @dataclass
114
  class BaseGraphStorage(StorageNameSpace):
115
  embedding_func: EmbeddingFunc = None
lightrag/kg/json_kv_impl.py CHANGED
@@ -12,6 +12,7 @@ from lightrag.base import (
12
  BaseKVStorage,
13
  )
14
 
 
15
  @dataclass
16
  class JsonKVStorage(BaseKVStorage):
17
  def __post_init__(self):
@@ -30,7 +31,7 @@ class JsonKVStorage(BaseKVStorage):
30
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
31
  return self._data.get(id, None)
32
 
33
- async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
34
  return [
35
  (
36
  {k: v for k, v in self._data[id].items()}
@@ -50,6 +51,8 @@ class JsonKVStorage(BaseKVStorage):
50
  async def drop(self) -> None:
51
  self._data = {}
52
 
53
- async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
 
 
54
  result = [v for _, v in self._data.items() if v["status"] == status]
55
  return result if result else None
 
12
  BaseKVStorage,
13
  )
14
 
15
+
16
  @dataclass
17
  class JsonKVStorage(BaseKVStorage):
18
  def __post_init__(self):
 
31
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
32
  return self._data.get(id, None)
33
 
34
+ async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
35
  return [
36
  (
37
  {k: v for k, v in self._data[id].items()}
 
51
  async def drop(self) -> None:
52
  self._data = {}
53
 
54
+ async def get_by_status_and_ids(
55
+ self, status: str
56
+ ) -> Union[list[dict[str, Any]], None]:
57
  result = [v for _, v in self._data.items() if v["status"] == status]
58
  return result if result else None
lightrag/kg/mongo_impl.py CHANGED
@@ -35,7 +35,7 @@ class MongoKVStorage(BaseKVStorage):
35
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
36
  return self._data.find_one({"_id": id})
37
 
38
- async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
39
  return list(self._data.find({"_id": {"$in": ids}}))
40
 
41
  async def filter_keys(self, data: list[str]) -> set[str]:
@@ -77,7 +77,9 @@ class MongoKVStorage(BaseKVStorage):
77
  """Drop the collection"""
78
  await self._data.drop()
79
 
80
- async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
 
 
81
  """Get documents by status and ids"""
82
  return self._data.find({"status": status})
83
 
 
35
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
36
  return self._data.find_one({"_id": id})
37
 
38
+ async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
39
  return list(self._data.find({"_id": {"$in": ids}}))
40
 
41
  async def filter_keys(self, data: list[str]) -> set[str]:
 
77
  """Drop the collection"""
78
  await self._data.drop()
79
 
80
+ async def get_by_status_and_ids(
81
+ self, status: str
82
+ ) -> Union[list[dict[str, Any]], None]:
83
  """Get documents by status and ids"""
84
  return self._data.find({"status": status})
85
 
lightrag/kg/oracle_impl.py CHANGED
@@ -326,7 +326,8 @@ class OracleKVStorage(BaseKVStorage):
326
  (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
327
  ):
328
  logger.info("full doc and chunk data had been saved into oracle db!")
329
-
 
330
  @dataclass
331
  class OracleVectorDBStorage(BaseVectorStorage):
332
  # should pass db object to self.db
 
326
  (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
327
  ):
328
  logger.info("full doc and chunk data had been saved into oracle db!")
329
+
330
+
331
  @dataclass
332
  class OracleVectorDBStorage(BaseVectorStorage):
333
  # should pass db object to self.db
lightrag/kg/postgres_impl.py CHANGED
@@ -213,7 +213,7 @@ class PGKVStorage(BaseKVStorage):
213
  return None
214
 
215
  # Query by id
216
- async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
217
  """Get doc_chunks data by id"""
218
  sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
219
  ids=",".join([f"'{id}'" for id in ids])
@@ -237,12 +237,14 @@ class PGKVStorage(BaseKVStorage):
237
  return res
238
  else:
239
  return None
240
-
241
- async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
 
 
242
  """Specifically for llm_response_cache."""
243
  SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
244
  params = {"workspace": self.db.workspace, "status": status}
245
- return await self.db.query(SQL, params, multirows=True)
246
 
247
  async def all_keys(self) -> list[dict]:
248
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
 
213
  return None
214
 
215
  # Query by id
216
+ async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
217
  """Get doc_chunks data by id"""
218
  sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
219
  ids=",".join([f"'{id}'" for id in ids])
 
237
  return res
238
  else:
239
  return None
240
+
241
+ async def get_by_status_and_ids(
242
+ self, status: str
243
+ ) -> Union[list[dict[str, Any]], None]:
244
  """Specifically for llm_response_cache."""
245
  SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
246
  params = {"workspace": self.db.workspace, "status": status}
247
+ return await self.db.query(SQL, params, multirows=True)
248
 
249
  async def all_keys(self) -> list[dict]:
250
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
lightrag/kg/redis_impl.py CHANGED
@@ -29,7 +29,7 @@ class RedisKVStorage(BaseKVStorage):
29
  data = await self._redis.get(f"{self.namespace}:{id}")
30
  return json.loads(data) if data else None
31
 
32
- async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
33
  pipe = self._redis.pipeline()
34
  for id in ids:
35
  pipe.get(f"{self.namespace}:{id}")
@@ -58,11 +58,12 @@ class RedisKVStorage(BaseKVStorage):
58
  keys = await self._redis.keys(f"{self.namespace}:*")
59
  if keys:
60
  await self._redis.delete(*keys)
61
-
62
- async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
 
 
63
  pipe = self._redis.pipeline()
64
  for key in await self._redis.keys(f"{self.namespace}:*"):
65
  pipe.hgetall(key)
66
  results = await pipe.execute()
67
  return [data for data in results if data.get("status") == status] or None
68
-
 
29
  data = await self._redis.get(f"{self.namespace}:{id}")
30
  return json.loads(data) if data else None
31
 
32
+ async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
33
  pipe = self._redis.pipeline()
34
  for id in ids:
35
  pipe.get(f"{self.namespace}:{id}")
 
58
  keys = await self._redis.keys(f"{self.namespace}:*")
59
  if keys:
60
  await self._redis.delete(*keys)
61
+
62
+ async def get_by_status_and_ids(
63
+ self, status: str
64
+ ) -> Union[list[dict[str, Any]], None]:
65
  pipe = self._redis.pipeline()
66
  for key in await self._redis.keys(f"{self.namespace}:*"):
67
  pipe.hgetall(key)
68
  results = await pipe.execute()
69
  return [data for data in results if data.get("status") == status] or None
 
lightrag/kg/tidb_impl.py CHANGED
@@ -122,7 +122,7 @@ class TiDBKVStorage(BaseKVStorage):
122
  return None
123
 
124
  # Query by id
125
- async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
126
  """根据 id 获取 doc_chunks 数据"""
127
  SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
128
  ids=",".join([f"'{id}'" for id in ids])
@@ -333,10 +333,13 @@ class TiDBVectorDBStorage(BaseVectorStorage):
333
  merge_sql = SQL_TEMPLATES["insert_relationship"]
334
  await self.db.execute(merge_sql, data)
335
 
336
- async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
 
 
337
  SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
338
  params = {"workspace": self.db.workspace, "status": status}
339
- return await self.db.query(SQL, params, multirows=True)
 
340
 
341
  @dataclass
342
  class TiDBGraphStorage(BaseGraphStorage):
 
122
  return None
123
 
124
  # Query by id
125
+ async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
126
  """根据 id 获取 doc_chunks 数据"""
127
  SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
128
  ids=",".join([f"'{id}'" for id in ids])
 
333
  merge_sql = SQL_TEMPLATES["insert_relationship"]
334
  await self.db.execute(merge_sql, data)
335
 
336
+ async def get_by_status_and_ids(
337
+ self, status: str
338
+ ) -> Union[list[dict[str, Any]], None]:
339
  SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
340
  params = {"workspace": self.db.workspace, "status": status}
341
+ return await self.db.query(SQL, params, multirows=True)
342
+
343
 
344
  @dataclass
345
  class TiDBGraphStorage(BaseGraphStorage):
lightrag/lightrag.py CHANGED
@@ -629,12 +629,7 @@ class LightRAG:
629
  # 4. Store original document
630
  for doc_id, doc in new_docs.items():
631
  await self.full_docs.upsert(
632
- {
633
- doc_id: {
634
- "content": doc["content"],
635
- "status": DocStatus.PENDING
636
- }
637
- }
638
  )
639
  logger.info(f"Stored {len(new_docs)} new unique documents")
640
 
@@ -642,10 +637,14 @@ class LightRAG:
642
  """Get pendding documents, split into chunks,insert chunks"""
643
  # 1. get all pending and failed documents
644
  _todo_doc_keys = []
645
-
646
- _failed_doc = await self.full_docs.get_by_status_and_ids(status=DocStatus.FAILED)
647
- _pendding_doc = await self.full_docs.get_by_status_and_ids(status=DocStatus.PENDING)
648
-
 
 
 
 
649
  if _failed_doc:
650
  _todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
651
  if _pendding_doc:
@@ -685,15 +684,19 @@ class LightRAG:
685
  )
686
  }
687
  chunk_cnt += len(chunks)
688
-
689
  try:
690
  # Store chunks in vector database
691
  await self.chunks_vdb.upsert(chunks)
692
  # Update doc status
693
- await self.text_chunks.upsert({**chunks, "status": DocStatus.PENDING})
 
 
694
  except Exception as e:
695
  # Mark as failed if any step fails
696
- await self.text_chunks.upsert({**chunks, "status": DocStatus.FAILED})
 
 
697
  raise e
698
  except Exception as e:
699
  import traceback
@@ -707,8 +710,12 @@ class LightRAG:
707
  """Get pendding or failed chunks, extract entities and relationships from each chunk"""
708
  # 1. get all pending and failed chunks
709
  _todo_chunk_keys = []
710
- _failed_chunks = await self.text_chunks.get_by_status_and_ids(status=DocStatus.FAILED)
711
- _pendding_chunks = await self.text_chunks.get_by_status_and_ids(status=DocStatus.PENDING)
 
 
 
 
712
  if _failed_chunks:
713
  _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks])
714
  if _pendding_chunks:
@@ -742,11 +749,15 @@ class LightRAG:
742
  if maybe_new_kg is None:
743
  logger.info("No entities or relationships extracted!")
744
  # Update status to processed
745
- await self.text_chunks.upsert({chunk_id: {"status": DocStatus.PROCESSED}})
 
 
746
  except Exception as e:
747
  logger.error("Failed to extract entities and relationships")
748
  # Mark as failed if any step fails
749
- await self.text_chunks.upsert({chunk_id: {"status": DocStatus.FAILED}})
 
 
750
  raise e
751
 
752
  with tqdm_async(
 
629
  # 4. Store original document
630
  for doc_id, doc in new_docs.items():
631
  await self.full_docs.upsert(
632
+ {doc_id: {"content": doc["content"], "status": DocStatus.PENDING}}
 
 
 
 
 
633
  )
634
  logger.info(f"Stored {len(new_docs)} new unique documents")
635
 
 
637
  """Get pendding documents, split into chunks,insert chunks"""
638
  # 1. get all pending and failed documents
639
  _todo_doc_keys = []
640
+
641
+ _failed_doc = await self.full_docs.get_by_status_and_ids(
642
+ status=DocStatus.FAILED
643
+ )
644
+ _pendding_doc = await self.full_docs.get_by_status_and_ids(
645
+ status=DocStatus.PENDING
646
+ )
647
+
648
  if _failed_doc:
649
  _todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
650
  if _pendding_doc:
 
684
  )
685
  }
686
  chunk_cnt += len(chunks)
687
+
688
  try:
689
  # Store chunks in vector database
690
  await self.chunks_vdb.upsert(chunks)
691
  # Update doc status
692
+ await self.text_chunks.upsert(
693
+ {**chunks, "status": DocStatus.PENDING}
694
+ )
695
  except Exception as e:
696
  # Mark as failed if any step fails
697
+ await self.text_chunks.upsert(
698
+ {**chunks, "status": DocStatus.FAILED}
699
+ )
700
  raise e
701
  except Exception as e:
702
  import traceback
 
710
  """Get pendding or failed chunks, extract entities and relationships from each chunk"""
711
  # 1. get all pending and failed chunks
712
  _todo_chunk_keys = []
713
+ _failed_chunks = await self.text_chunks.get_by_status_and_ids(
714
+ status=DocStatus.FAILED
715
+ )
716
+ _pendding_chunks = await self.text_chunks.get_by_status_and_ids(
717
+ status=DocStatus.PENDING
718
+ )
719
  if _failed_chunks:
720
  _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks])
721
  if _pendding_chunks:
 
749
  if maybe_new_kg is None:
750
  logger.info("No entities or relationships extracted!")
751
  # Update status to processed
752
+ await self.text_chunks.upsert(
753
+ {chunk_id: {"status": DocStatus.PROCESSED}}
754
+ )
755
  except Exception as e:
756
  logger.error("Failed to extract entities and relationships")
757
  # Mark as failed if any step fails
758
+ await self.text_chunks.upsert(
759
+ {chunk_id: {"status": DocStatus.FAILED}}
760
+ )
761
  raise e
762
 
763
  with tqdm_async(