YanSte commited on
Commit
160abf9
·
1 Parent(s): 44241d0

Updated and cleaned up what is implemented on BaseKVStorage

Browse files
lightrag/base.py CHANGED
@@ -121,11 +121,11 @@ class BaseKVStorage(StorageNameSpace):
121
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
122
  raise NotImplementedError
123
 
124
- async def filter_keys(self, data: set[str]) -> set[str]:
125
  """Return un-exist keys"""
126
  raise NotImplementedError
127
 
128
- async def upsert(self, data: dict[str, Any]) -> None:
129
  raise NotImplementedError
130
 
131
  async def drop(self) -> None:
 
121
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
122
  raise NotImplementedError
123
 
124
+ async def filter_keys(self, keys: set[str]) -> set[str]:
125
  """Return un-exist keys"""
126
  raise NotImplementedError
127
 
128
+ async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
129
  raise NotImplementedError
130
 
131
  async def drop(self) -> None:
lightrag/kg/json_kv_impl.py CHANGED
@@ -1,7 +1,7 @@
1
  import asyncio
2
  import os
3
  from dataclasses import dataclass
4
- from typing import Any, Union
5
 
6
  from lightrag.base import (
7
  BaseKVStorage,
@@ -25,7 +25,7 @@ class JsonKVStorage(BaseKVStorage):
25
  async def index_done_callback(self):
26
  write_json(self._data, self._file_name)
27
 
28
- async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
29
  return self._data.get(id)
30
 
31
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -38,7 +38,7 @@ class JsonKVStorage(BaseKVStorage):
38
  for id in ids
39
  ]
40
 
41
- async def filter_keys(self, data: set[str]) -> set[str]:
42
  return set(data) - set(self._data.keys())
43
 
44
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
 
1
  import asyncio
2
  import os
3
  from dataclasses import dataclass
4
+ from typing import Any
5
 
6
  from lightrag.base import (
7
  BaseKVStorage,
 
25
  async def index_done_callback(self):
26
  write_json(self._data, self._file_name)
27
 
28
+ async def get_by_id(self, id: str) -> dict[str, Any] | None:
29
  return self._data.get(id)
30
 
31
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
 
38
  for id in ids
39
  ]
40
 
41
async def filter_keys(self, keys: set[str]) -> set[str]:
    """Return the subset of *keys* that do not exist in this KV store.

    Args:
        keys: Candidate keys to check.

    Returns:
        The keys from *keys* that are absent from the in-memory data dict.
    """
    # Bug fix: the parameter was renamed data -> keys, but the body still
    # referenced the old name `data`, which raised NameError at runtime.
    return set(keys) - set(self._data.keys())
43
 
44
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
lightrag/kg/mongo_impl.py CHANGED
@@ -60,14 +60,14 @@ class MongoKVStorage(BaseKVStorage):
60
  # Ensure collection exists
61
  create_collection_if_not_exists(uri, database.name, self._collection_name)
62
 
63
- async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
64
  return await self._data.find_one({"_id": id})
65
 
66
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
67
  cursor = self._data.find({"_id": {"$in": ids}})
68
  return await cursor.to_list()
69
 
70
- async def filter_keys(self, data: set[str]) -> set[str]:
71
  cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
72
  existing_ids = {str(x["_id"]) async for x in cursor}
73
  return data - existing_ids
@@ -107,6 +107,9 @@ class MongoKVStorage(BaseKVStorage):
107
  else:
108
  return None
109
 
 
 
 
110
  async def drop(self) -> None:
111
  """Drop the collection"""
112
  await self._data.drop()
 
60
  # Ensure collection exists
61
  create_collection_if_not_exists(uri, database.name, self._collection_name)
62
 
63
+ async def get_by_id(self, id: str) -> dict[str, Any] | None:
64
  return await self._data.find_one({"_id": id})
65
 
66
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
67
  cursor = self._data.find({"_id": {"$in": ids}})
68
  return await cursor.to_list()
69
 
70
async def filter_keys(self, keys: set[str]) -> set[str]:
    """Return the subset of *keys* that do not exist in the Mongo collection.

    Args:
        keys: Candidate document ids to check.

    Returns:
        The ids from *keys* with no matching `_id` in the collection.
    """
    # Bug fix: the parameter was renamed data -> keys, but the body still
    # referenced the old name `data`, which raised NameError at runtime.
    # Only project `_id` since that is all we need for the set difference.
    cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
    existing_ids = {str(x["_id"]) async for x in cursor}
    return keys - existing_ids
 
107
  else:
108
  return None
109
 
110
+ async def index_done_callback(self) -> None:
111
+ pass
112
+
113
  async def drop(self) -> None:
114
  """Drop the collection"""
115
  await self._data.drop()
lightrag/kg/oracle_impl.py CHANGED
@@ -181,7 +181,7 @@ class OracleKVStorage(BaseKVStorage):
181
 
182
  ################ QUERY METHODS ################
183
 
184
- async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
185
  """Get doc_full data based on id."""
186
  SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
187
  params = {"workspace": self.db.workspace, "id": id}
@@ -232,7 +232,7 @@ class OracleKVStorage(BaseKVStorage):
232
  res = [{k: v} for k, v in dict_res.items()]
233
  return res
234
 
235
- async def filter_keys(self, keys: list[str]) -> set[str]:
236
  """Return keys that don't exist in storage"""
237
  SQL = SQL_TEMPLATES["filter_keys"].format(
238
  table_name=namespace_to_table_name(self.namespace),
@@ -248,7 +248,7 @@ class OracleKVStorage(BaseKVStorage):
248
  return set(keys)
249
 
250
  ################ INSERT METHODS ################
251
- async def upsert(self, data: dict[str, Any]) -> None:
252
  if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
253
  list_data = [
254
  {
@@ -314,6 +314,8 @@ class OracleKVStorage(BaseKVStorage):
314
  ):
315
  logger.info("full doc and chunk data had been saved into oracle db!")
316
 
 
 
317
 
318
  @dataclass
319
  class OracleVectorDBStorage(BaseVectorStorage):
 
181
 
182
  ################ QUERY METHODS ################
183
 
184
+ async def get_by_id(self, id: str) -> dict[str, Any] | None:
185
  """Get doc_full data based on id."""
186
  SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
187
  params = {"workspace": self.db.workspace, "id": id}
 
232
  res = [{k: v} for k, v in dict_res.items()]
233
  return res
234
 
235
+ async def filter_keys(self, keys: set[str]) -> set[str]:
236
  """Return keys that don't exist in storage"""
237
  SQL = SQL_TEMPLATES["filter_keys"].format(
238
  table_name=namespace_to_table_name(self.namespace),
 
248
  return set(keys)
249
 
250
  ################ INSERT METHODS ################
251
+ async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
252
  if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
253
  list_data = [
254
  {
 
314
  ):
315
  logger.info("full doc and chunk data had been saved into oracle db!")
316
 
317
+ async def drop(self) -> None:
318
+ raise NotImplementedError
319
 
320
  @dataclass
321
  class OracleVectorDBStorage(BaseVectorStorage):
lightrag/kg/postgres_impl.py CHANGED
@@ -4,7 +4,7 @@ import json
4
  import os
5
  import time
6
  from dataclasses import dataclass
7
- from typing import Any, Dict, List, Set, Tuple, Union
8
 
9
  import numpy as np
10
  import pipmaster as pm
@@ -185,7 +185,7 @@ class PGKVStorage(BaseKVStorage):
185
 
186
  ################ QUERY METHODS ################
187
 
188
- async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
189
  """Get doc_full data by id."""
190
  sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
191
  params = {"workspace": self.db.workspace, "id": id}
@@ -240,7 +240,7 @@ class PGKVStorage(BaseKVStorage):
240
  params = {"workspace": self.db.workspace, "status": status}
241
  return await self.db.query(SQL, params, multirows=True)
242
 
243
- async def filter_keys(self, keys: List[str]) -> Set[str]:
244
  """Filter out duplicated content"""
245
  sql = SQL_TEMPLATES["filter_keys"].format(
246
  table_name=namespace_to_table_name(self.namespace),
@@ -261,7 +261,7 @@ class PGKVStorage(BaseKVStorage):
261
  print(params)
262
 
263
  ################ INSERT METHODS ################
264
- async def upsert(self, data: dict[str, Any]) -> None:
265
  if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
266
  pass
267
  elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
@@ -294,6 +294,8 @@ class PGKVStorage(BaseKVStorage):
294
  ):
295
  logger.info("full doc and chunk data had been saved into postgresql db!")
296
 
 
 
297
 
298
  @dataclass
299
  class PGVectorStorage(BaseVectorStorage):
 
4
  import os
5
  import time
6
  from dataclasses import dataclass
7
+ from typing import Any, Dict, List, Tuple, Union
8
 
9
  import numpy as np
10
  import pipmaster as pm
 
185
 
186
  ################ QUERY METHODS ################
187
 
188
+ async def get_by_id(self, id: str) -> dict[str, Any] | None:
189
  """Get doc_full data by id."""
190
  sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
191
  params = {"workspace": self.db.workspace, "id": id}
 
240
  params = {"workspace": self.db.workspace, "status": status}
241
  return await self.db.query(SQL, params, multirows=True)
242
 
243
+ async def filter_keys(self, keys: set[str]) -> set[str]:
244
  """Filter out duplicated content"""
245
  sql = SQL_TEMPLATES["filter_keys"].format(
246
  table_name=namespace_to_table_name(self.namespace),
 
261
  print(params)
262
 
263
  ################ INSERT METHODS ################
264
+ async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
265
  if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
266
  pass
267
  elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
 
294
  ):
295
  logger.info("full doc and chunk data had been saved into postgresql db!")
296
 
297
+ async def drop(self) -> None:
298
+ raise NotImplementedError
299
 
300
  @dataclass
301
  class PGVectorStorage(BaseVectorStorage):
lightrag/kg/redis_impl.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from typing import Any, Union
3
  from tqdm.asyncio import tqdm as tqdm_async
4
  from dataclasses import dataclass
5
  import pipmaster as pm
@@ -28,7 +28,7 @@ class RedisKVStorage(BaseKVStorage):
28
  self._redis = Redis.from_url(redis_url, decode_responses=True)
29
  logger.info(f"Use Redis as KV {self.namespace}")
30
 
31
- async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
32
  data = await self._redis.get(f"{self.namespace}:{id}")
33
  return json.loads(data) if data else None
34
 
@@ -39,7 +39,7 @@ class RedisKVStorage(BaseKVStorage):
39
  results = await pipe.execute()
40
  return [json.loads(result) if result else None for result in results]
41
 
42
- async def filter_keys(self, data: set[str]) -> set[str]:
43
  pipe = self._redis.pipeline()
44
  for key in data:
45
  pipe.exists(f"{self.namespace}:{key}")
@@ -48,7 +48,7 @@ class RedisKVStorage(BaseKVStorage):
48
  existing_ids = {data[i] for i, exists in enumerate(results) if exists}
49
  return set(data) - existing_ids
50
 
51
- async def upsert(self, data: dict[str, Any]) -> None:
52
  pipe = self._redis.pipeline()
53
  for k, v in tqdm_async(data.items(), desc="Upserting"):
54
  pipe.set(f"{self.namespace}:{k}", json.dumps(v))
@@ -61,3 +61,6 @@ class RedisKVStorage(BaseKVStorage):
61
  keys = await self._redis.keys(f"{self.namespace}:*")
62
  if keys:
63
  await self._redis.delete(*keys)
 
 
 
 
1
  import os
2
+ from typing import Any
3
  from tqdm.asyncio import tqdm as tqdm_async
4
  from dataclasses import dataclass
5
  import pipmaster as pm
 
28
  self._redis = Redis.from_url(redis_url, decode_responses=True)
29
  logger.info(f"Use Redis as KV {self.namespace}")
30
 
31
+ async def get_by_id(self, id: str) -> dict[str, Any] | None:
32
  data = await self._redis.get(f"{self.namespace}:{id}")
33
  return json.loads(data) if data else None
34
 
 
39
  results = await pipe.execute()
40
  return [json.loads(result) if result else None for result in results]
41
 
42
async def filter_keys(self, keys: set[str]) -> set[str]:
    """Return the subset of *keys* that do not exist in Redis.

    Args:
        keys: Candidate keys (without the namespace prefix) to check.

    Returns:
        The keys from *keys* with no `{namespace}:{key}` entry in Redis.
    """
    # Bug fix: the parameter was renamed data -> keys, but the body still
    # referenced the old name `data` (NameError). Additionally, the old
    # code indexed the set with data[i], which is a TypeError — snapshot
    # the keys into a list so EXISTS results can be matched back by index.
    key_list = list(keys)
    pipe = self._redis.pipeline()
    for key in key_list:
        pipe.exists(f"{self.namespace}:{key}")
    results = await pipe.execute()
    existing_ids = {key_list[i] for i, exists in enumerate(results) if exists}
    return keys - existing_ids
50
 
51
+ async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
52
  pipe = self._redis.pipeline()
53
  for k, v in tqdm_async(data.items(), desc="Upserting"):
54
  pipe.set(f"{self.namespace}:{k}", json.dumps(v))
 
61
  keys = await self._redis.keys(f"{self.namespace}:*")
62
  if keys:
63
  await self._redis.delete(*keys)
64
+
65
+ async def index_done_callback(self) -> None:
66
+ pass
lightrag/kg/tidb_impl.py CHANGED
@@ -110,7 +110,7 @@ class TiDBKVStorage(BaseKVStorage):
110
 
111
  ################ QUERY METHODS ################
112
 
113
- async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
114
  """Fetch doc_full data by id."""
115
  SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
116
  params = {"id": id}
@@ -125,7 +125,7 @@ class TiDBKVStorage(BaseKVStorage):
125
  )
126
  return await self.db.query(SQL, multirows=True)
127
 
128
- async def filter_keys(self, keys: list[str]) -> set[str]:
129
  """过滤掉重复内容"""
130
  SQL = SQL_TEMPLATES["filter_keys"].format(
131
  table_name=namespace_to_table_name(self.namespace),
@@ -147,7 +147,7 @@ class TiDBKVStorage(BaseKVStorage):
147
  return data
148
 
149
  ################ INSERT full_doc AND chunks ################
150
- async def upsert(self, data: dict[str, Any]) -> None:
151
  left_data = {k: v for k, v in data.items() if k not in self._data}
152
  self._data.update(left_data)
153
  if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -207,6 +207,8 @@ class TiDBKVStorage(BaseKVStorage):
207
  ):
208
  logger.info("full doc and chunk data had been saved into TiDB db!")
209
 
 
 
210
 
211
  @dataclass
212
  class TiDBVectorDBStorage(BaseVectorStorage):
 
110
 
111
  ################ QUERY METHODS ################
112
 
113
+ async def get_by_id(self, id: str) -> dict[str, Any] | None:
114
  """Fetch doc_full data by id."""
115
  SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
116
  params = {"id": id}
 
125
  )
126
  return await self.db.query(SQL, multirows=True)
127
 
128
+ async def filter_keys(self, keys: set[str]) -> set[str]:
129
  """过滤掉重复内容"""
130
  SQL = SQL_TEMPLATES["filter_keys"].format(
131
  table_name=namespace_to_table_name(self.namespace),
 
147
  return data
148
 
149
  ################ INSERT full_doc AND chunks ################
150
+ async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
151
  left_data = {k: v for k, v in data.items() if k not in self._data}
152
  self._data.update(left_data)
153
  if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
 
207
  ):
208
  logger.info("full doc and chunk data had been saved into TiDB db!")
209
 
210
+ async def drop(self) -> None:
211
+ raise NotImplementedError
212
 
213
  @dataclass
214
  class TiDBVectorDBStorage(BaseVectorStorage):