Commit 050eb99 by YanSte
Parent: c5d7abf

implemented method and cleaned the mess

lightrag/kg/json_kv_impl.py CHANGED
@@ -1,53 +1,3 @@
-"""
-JsonDocStatus Storage Module
-=======================
-
-This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
-
-The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
-
-Author: lightrag team
-Created: 2024-01-25
-License: MIT
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-Version: 1.0.0
-
-Dependencies:
-    - NetworkX
-    - NumPy
-    - LightRAG
-    - graspologic
-
-Features:
-    - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
-    - Query graph nodes and edges
-    - Calculate node and edge degrees
-    - Embed nodes using various algorithms (e.g., Node2Vec)
-    - Remove nodes and edges from the graph
-
-Usage:
-    from lightrag.storage.networkx_storage import NetworkXStorage
-
-"""
-
 import asyncio
 import os
 from dataclasses import dataclass
@@ -58,12 +8,10 @@ from lightrag.utils import (
     load_json,
     write_json,
 )
-
 from lightrag.base import (
     BaseKVStorage,
 )
 
-
 @dataclass
 class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
@@ -79,13 +27,13 @@ class JsonKVStorage(BaseKVStorage):
     async def index_done_callback(self):
         write_json(self._data, self._file_name)
 
-    async def get_by_id(self, id: str):
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         return self._data.get(id, None)
 
-    async def get_by_ids(self, ids: list[str]):
+    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         return [
             (
-                {k: v for k, v in self._data[id].items() }
+                {k: v for k, v in self._data[id].items()}
                 if self._data.get(id, None)
                 else None
             )
@@ -95,12 +43,11 @@ class JsonKVStorage(BaseKVStorage):
     async def filter_keys(self, data: list[str]) -> set[str]:
         return set([s for s in data if s not in self._data])
 
-    async def upsert(self, data: dict[str, dict[str, Any]]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        left_data = {k: v for k, v in data.items() if k not in self._data}
        self._data.update(left_data)
-        return left_data
 
-    async def drop(self):
+    async def drop(self) -> None:
        self._data = {}
 
     async def get_by_status_and_ids(
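Taken together, the json_kv_impl.py changes pin JsonKVStorage to a typed key-value contract: the getters return dict[str, Any] or None, while upsert and drop now return None instead of echoing data back. A minimal, self-contained sketch of that contract (a toy stand-in for illustration only, not the LightRAG class itself; JsonKVStorage's own constructor arguments such as namespace and global_config are omitted here):

import asyncio
from typing import Any, Union


class InMemoryKV:
    """Toy stand-in mirroring the signatures introduced in this commit."""

    def __init__(self) -> None:
        self._data: dict[str, dict[str, Any]] = {}

    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
        return self._data.get(id, None)

    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
        return [self._data.get(id, None) for id in ids]

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        # Like JsonKVStorage.upsert: only keys not already present are inserted.
        self._data.update({k: v for k, v in data.items() if k not in self._data})

    async def drop(self) -> None:
        self._data = {}


async def main() -> None:
    kv = InMemoryKV()
    await kv.upsert({"doc-1": {"content": "hello"}})
    print(await kv.get_by_id("doc-1"))          # {'content': 'hello'}
    print(await kv.get_by_ids(["doc-1", "x"]))  # [{'content': 'hello'}, None]


asyncio.run(main())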
lightrag/kg/jsondocstatus_impl.py CHANGED
@@ -50,7 +50,7 @@ Usage:
 
 import os
 from dataclasses import dataclass
-from typing import Union, Dict
+from typing import Any, Union, Dict
 
 from lightrag.utils import (
     logger,
@@ -104,7 +104,7 @@ class JsonDocStatusStorage(DocStatusStorage):
         """Save data to file after indexing"""
         write_json(self._data, self._file_name)
 
-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, Any]) -> None:
         """Update or insert document status
 
         Args:
@@ -114,7 +114,7 @@ class JsonDocStatusStorage(DocStatusStorage):
         await self.index_done_callback()
         return data
 
-    async def get_by_id(self, id: str):
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         return self._data.get(id)
 
     async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]:
lightrag/kg/mongo_impl.py CHANGED
@@ -12,7 +12,7 @@ if not pm.is_installed("motor"):
 
 from pymongo import MongoClient
 from motor.motor_asyncio import AsyncIOMotorClient
-from typing import Union, List, Tuple
+from typing import Any, TypeVar, Union, List, Tuple
 
 from ..utils import logger
 from ..base import BaseKVStorage, BaseGraphStorage
@@ -32,18 +32,11 @@ class MongoKVStorage(BaseKVStorage):
     async def all_keys(self) -> list[str]:
         return [x["_id"] for x in self._data.find({}, {"_id": 1})]
 
-    async def get_by_id(self, id):
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         return self._data.find_one({"_id": id})
 
-    async def get_by_ids(self, ids, fields=None):
-        if fields is None:
-            return list(self._data.find({"_id": {"$in": ids}}))
-        return list(
-            self._data.find(
-                {"_id": {"$in": ids}},
-                {field: 1 for field in fields},
-            )
-        )
+    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
+        return list(self._data.find({"_id": {"$in": ids}}))
 
     async def filter_keys(self, data: list[str]) -> set[str]:
         existing_ids = [
@@ -51,7 +44,7 @@ class MongoKVStorage(BaseKVStorage):
         ]
         return set([s for s in data if s not in existing_ids])
 
-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
             for mode, items in data.items():
                 for k, v in tqdm_async(items.items(), desc="Upserting"):
@@ -66,7 +59,6 @@ class MongoKVStorage(BaseKVStorage):
             for k, v in tqdm_async(data.items(), desc="Upserting"):
                 self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
                 data[k]["_id"] = k
-        return data
 
     async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
         if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
@@ -81,9 +73,15 @@ class MongoKVStorage(BaseKVStorage):
         else:
             return None
 
-    async def drop(self):
-        """ """
-        pass
+    async def drop(self) -> None:
+        """Drop the collection"""
+        await self._data.drop()
+
+    async def get_by_status_and_ids(
+        self, status: str
+    ) -> Union[list[dict[str, Any]], None]:
+        """Get documents by status and ids"""
+        return self._data.find({"status": status})
 
 
 @dataclass
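One caveat worth noting for the new MongoKVStorage.get_by_status_and_ids: with a motor AsyncIOMotorCollection, find() returns an async cursor rather than a list, so callers still have to materialize the results themselves. A hedged caller-side sketch (assuming storage is an already-initialized MongoKVStorage whose _data is a motor collection; collect_by_status is an illustrative helper, not part of this commit):

from typing import Any


async def collect_by_status(storage, status: str) -> list[dict[str, Any]]:
    # get_by_status_and_ids returns a motor cursor here; drain it with to_list().
    cursor = await storage.get_by_status_and_ids(status)
    if cursor is None:
        return []
    return await cursor.to_list(length=None)

Iterating with "async for doc in cursor" would work equally well when streaming the documents is preferred over building a full list.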
lightrag/kg/oracle_impl.py CHANGED
@@ -4,7 +4,7 @@ import asyncio
 # import html
 # import os
 from dataclasses import dataclass
-from typing import Union
+from typing import Any, TypeVar, Union
 import numpy as np
 import array
 import pipmaster as pm
@@ -181,7 +181,7 @@ class OracleKVStorage(BaseKVStorage):
 
     ################ QUERY METHODS ################
 
-    async def get_by_id(self, id: str) -> Union[dict, None]:
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         """get doc_full data based on id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
@@ -211,7 +211,7 @@ class OracleKVStorage(BaseKVStorage):
         else:
             return None
 
-    async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
+    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         """get doc_chunks data based on id"""
         SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
@@ -238,15 +238,10 @@ class OracleKVStorage(BaseKVStorage):
             return None
 
     async def get_by_status_and_ids(
-        self, status: str, ids: list[str]
-    ) -> Union[list[dict], None]:
+        self, status: str
+    ) -> Union[list[dict[str, Any]], None]:
         """Specifically for llm_response_cache."""
-        if ids is not None:
-            SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format(
-                ids=",".join([f"'{id}'" for id in ids])
-            )
-        else:
-            SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
+        SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
         params = {"workspace": self.db.workspace, "status": status}
         res = await self.db.query(SQL, params, multirows=True)
         if res:
@@ -270,7 +265,7 @@ class OracleKVStorage(BaseKVStorage):
         return set(keys)
 
     ################ INSERT METHODS ################
-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, Any]) -> None:
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             list_data = [
                 {
@@ -328,14 +323,6 @@ class OracleKVStorage(BaseKVStorage):
         }
 
         await self.db.execute(upsert_sql, _data)
-        return None
-
-    async def change_status(self, id: str, status: str):
-        SQL = SQL_TEMPLATES["change_status"].format(
-            table_name=namespace_to_table_name(self.namespace)
-        )
-        params = {"workspace": self.db.workspace, "id": id, "status": status}
-        await self.db.execute(SQL, params)
 
     async def index_done_callback(self):
         if is_namespace(
@@ -343,8 +330,7 @@ class OracleKVStorage(BaseKVStorage):
             (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
         ):
             logger.info("full doc and chunk data had been saved into oracle db!")
-
-
+
 @dataclass
 class OracleVectorDBStorage(BaseVectorStorage):
     # should pass db object to self.db
@@ -745,7 +731,6 @@ SQL_TEMPLATES = {
     "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
     "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
     "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
-    "change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id",
     "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
         USING DUAL
         ON (a.id = :id and a.workspace = :workspace)
lightrag/kg/postgres_impl.py CHANGED
@@ -30,7 +30,6 @@ from ..base import (
     DocStatus,
     DocProcessingStatus,
     BaseGraphStorage,
-    T,
 )
 from ..namespace import NameSpace, is_namespace
 
@@ -184,7 +183,7 @@ class PGKVStorage(BaseKVStorage):
 
     ################ QUERY METHODS ################
 
-    async def get_by_id(self, id: str) -> Union[dict, None]:
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         """Get doc_full data by id."""
         sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
@@ -214,7 +213,7 @@ class PGKVStorage(BaseKVStorage):
         return None
 
     # Query by id
-    async def get_by_ids(self, ids: List[str], fields=None) -> Union[List[dict], None]:
+    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         """Get doc_chunks data by id"""
         sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
@@ -238,6 +237,14 @@ class PGKVStorage(BaseKVStorage):
             return res
         else:
             return None
+
+    async def get_by_status_and_ids(
+        self, status: str
+    ) -> Union[list[dict[str, Any]], None]:
+        """Specifically for llm_response_cache."""
+        SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
+        params = {"workspace": self.db.workspace, "status": status}
+        return await self.db.query(SQL, params, multirows=True)
 
     async def all_keys(self) -> list[dict]:
         if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
@@ -270,7 +277,7 @@ class PGKVStorage(BaseKVStorage):
         print(params)
 
     ################ INSERT METHODS ################
-    async def upsert(self, data: Dict[str, dict]):
+    async def upsert(self, data: dict[str, Any]) -> None:
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             pass
         elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
@@ -447,7 +454,7 @@ class PGDocStatusStorage(DocStatusStorage):
         existed = set([element["id"] for element in result])
         return set(data) - existed
 
-    async def get_by_id(self, id: str) -> Union[T, None]:
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
         params = {"workspace": self.db.workspace, "id": id}
         result = await self.db.query(sql, params, True)
lightrag/kg/redis_impl.py CHANGED
@@ -1,4 +1,5 @@
 import os
+from typing import Any, TypeVar, Union
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import pipmaster as pm
@@ -28,21 +29,11 @@ class RedisKVStorage(BaseKVStorage):
         data = await self._redis.get(f"{self.namespace}:{id}")
         return json.loads(data) if data else None
 
-    async def get_by_ids(self, ids, fields=None):
+    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         pipe = self._redis.pipeline()
         for id in ids:
             pipe.get(f"{self.namespace}:{id}")
         results = await pipe.execute()
-
-        if fields:
-            # Filter fields if specified
-            return [
-                {field: value.get(field) for field in fields if field in value}
-                if (value := json.loads(result))
-                else None
-                for result in results
-            ]
-
         return [json.loads(result) if result else None for result in results]
 
     async def filter_keys(self, data: list[str]) -> set[str]:
@@ -54,7 +45,7 @@ class RedisKVStorage(BaseKVStorage):
         existing_ids = {data[i] for i, exists in enumerate(results) if exists}
         return set(data) - existing_ids
 
-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, Any]) -> None:
         pipe = self._redis.pipeline()
         for k, v in tqdm_async(data.items(), desc="Upserting"):
             pipe.set(f"{self.namespace}:{k}", json.dumps(v))
@@ -62,9 +53,18 @@ class RedisKVStorage(BaseKVStorage):
 
         for k in data:
             data[k]["_id"] = k
-        return data
 
-    async def drop(self):
+    async def drop(self) -> None:
         keys = await self._redis.keys(f"{self.namespace}:*")
         if keys:
             await self._redis.delete(*keys)
+
+    async def get_by_status_and_ids(
+        self, status: str,
+    ) -> Union[list[dict[str, Any]], None]:
+        pipe = self._redis.pipeline()
+        for key in await self._redis.keys(f"{self.namespace}:*"):
+            pipe.hgetall(key)
+        results = await pipe.execute()
+        return [data for data in results if data.get("status") == status] or None
+
lightrag/kg/tidb_impl.py CHANGED
@@ -1,7 +1,7 @@
 import asyncio
 import os
 from dataclasses import dataclass
-from typing import Union
+from typing import Any, TypeVar, Union
 
 import numpy as np
 import pipmaster as pm
@@ -108,7 +108,7 @@ class TiDBKVStorage(BaseKVStorage):
 
     ################ QUERY METHODS ################
 
-    async def get_by_id(self, id: str) -> Union[dict, None]:
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         """Get doc_full data by id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"id": id}
@@ -122,16 +122,14 @@ class TiDBKVStorage(BaseKVStorage):
         return None
 
     # Query by id
-    async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
+    async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         """Get doc_chunks data by id"""
         SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
         )
-        # print("get_by_ids:"+SQL)
         res = await self.db.query(SQL, multirows=True)
         if res:
             data = res  # [{"data":i} for i in res]
-            # print(data)
             return data
         else:
             return None
@@ -158,7 +156,7 @@ class TiDBKVStorage(BaseKVStorage):
         return data
 
     ################ INSERT full_doc AND chunks ################
-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, Any]) -> None:
         left_data = {k: v for k, v in data.items() if k not in self._data}
         self._data.update(left_data)
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -335,6 +333,12 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         merge_sql = SQL_TEMPLATES["insert_relationship"]
         await self.db.execute(merge_sql, data)
 
+    async def get_by_status_and_ids(
+        self, status: str
+    ) -> Union[list[dict], None]:
+        SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
+        params = {"workspace": self.db.workspace, "status": status}
+        return await self.db.query(SQL, params, multirows=True)
 
 @dataclass
 class TiDBGraphStorage(BaseGraphStorage):