cleaned code

Changed files:
- lightrag/base.py                  +15 -13
- lightrag/kg/json_kv_impl.py        +8  -8
- lightrag/kg/jsondocstatus_impl.py  +1  -1
- lightrag/kg/mongo_impl.py          +9  -8
- lightrag/kg/oracle_impl.py        +14 -10
- lightrag/kg/postgres_impl.py      +16 -14
- lightrag/kg/redis_impl.py          +2  -2
- lightrag/kg/tidb_impl.py          +10 -11
lightrag/base.py
CHANGED

@@ -1,24 +1,26 @@
-from enum import Enum
 import os
 from dataclasses import dataclass, field
+from enum import Enum
 from typing import (
+    Any,
+    Literal,
     Optional,
     TypedDict,
-    Union,
-    Literal,
     TypeVar,
+    Union,
 )
 
 import numpy as np
 
 from .utils import EmbeddingFunc
 
+
+class TextChunkSchema(TypedDict):
+    tokens: int
+    content: str
+    full_doc_id: str
+    chunk_order_index: int
+
 
 T = TypeVar("T")
 
@@ -57,11 +59,11 @@ class StorageNameSpace:
     global_config: dict[str, Any]
 
     async def index_done_callback(self):
+        """Commit the storage operations after indexing"""
         pass
 
     async def query_done_callback(self):
+        """Commit the storage operations after querying"""
         pass
 
 
@@ -84,14 +86,14 @@ class BaseVectorStorage(StorageNameSpace):
 class BaseKVStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc
 
-    async def get_by_id(self, id: str) -> dict[str, Any]:
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         raise NotImplementedError
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         raise NotImplementedError
 
     async def filter_keys(self, data: set[str]) -> set[str]:
+        """Return un-exist keys"""
         raise NotImplementedError
 
     async def upsert(self, data: dict[str, Any]) -> None:
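The core of the base.py change is that TextChunkSchema becomes a class-based TypedDict and get_by_id is annotated as returning Union[dict[str, Any], None]. A minimal sketch of how calling code reads under the new contracts; the describe_chunk helper and the local TextChunkSchema mirror are illustrative assumptions, not part of the diff:

# Sketch only: a local mirror of the new TypedDict and a hypothetical caller that
# honours the Union[dict[str, Any], None] return contract of BaseKVStorage.get_by_id.
from typing import Any, Optional, TypedDict


class TextChunkSchema(TypedDict):  # mirrors the class-based TypedDict added in lightrag/base.py
    tokens: int
    content: str
    full_doc_id: str
    chunk_order_index: int


async def describe_chunk(kv, chunk_id: str) -> str:
    """kv stands in for any BaseKVStorage implementation (hypothetical helper)."""
    chunk: Optional[dict[str, Any]] = await kv.get_by_id(chunk_id)
    if chunk is None:  # the annotation now makes the "missing key" branch explicit
        return f"{chunk_id}: not found"
    return f"{chunk_id}: {chunk['tokens']} tokens from {chunk['full_doc_id']}"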
lightrag/kg/json_kv_impl.py
CHANGED

@@ -1,16 +1,16 @@
 import asyncio
 import os
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Union
 
+from lightrag.base import (
+    BaseKVStorage,
+)
 from lightrag.utils import (
-    logger,
     load_json,
+    logger,
     write_json,
 )
-from lightrag.base import (
-    BaseKVStorage,
-)
 
 
 @dataclass
@@ -25,8 +25,8 @@ class JsonKVStorage(BaseKVStorage):
     async def index_done_callback(self):
         write_json(self._data, self._file_name)
 
-    async def get_by_id(self, id: str) -> dict[str, Any]:
-        return self._data.get(id, None)
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+        return self._data.get(id)
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         return [
@@ -39,7 +39,7 @@ class JsonKVStorage(BaseKVStorage):
         ]
 
     async def filter_keys(self, data: set[str]) -> set[str]:
+        return data - set(self._data.keys())
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         left_data = {k: v for k, v in data.items() if k not in self._data}
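The json_kv_impl.py change rewrites filter_keys as a plain set difference. A small standalone illustration of the semantics; the stored/incoming names are made up for the example:

# Sketch: the set-difference form of filter_keys, shown with sample data.
stored = {"doc-1": {"content": "a"}, "doc-2": {"content": "b"}}  # plays the role of self._data
incoming = {"doc-2", "doc-3"}

missing = incoming - set(stored.keys())  # keys that are not yet in storage
assert missing == {"doc-3"}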
lightrag/kg/jsondocstatus_impl.py
CHANGED

@@ -76,7 +76,7 @@ class JsonDocStatusStorage(DocStatusStorage):
 
     async def filter_keys(self, data: set[str]) -> set[str]:
         """Return keys that should be processed (not in storage or not successfully processed)"""
+        return set(k for k in data if k not in self._data)
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         result: list[dict[str, Any]] = []
lightrag/kg/mongo_impl.py
CHANGED

@@ -1,8 +1,9 @@
 import os
-from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
+
 import numpy as np
+import pipmaster as pm
+from tqdm.asyncio import tqdm as tqdm_async
 
 if not pm.is_installed("pymongo"):
     pm.install("pymongo")
@@ -10,13 +11,14 @@ if not pm.is_installed("pymongo"):
 if not pm.is_installed("motor"):
     pm.install("motor")
 
+from typing import Any, List, Tuple, Union
+
 from motor.motor_asyncio import AsyncIOMotorClient
+from pymongo import MongoClient
 
-from ..base import BaseKVStorage, BaseGraphStorage
+from ..base import BaseGraphStorage, BaseKVStorage
 from ..namespace import NameSpace, is_namespace
+from ..utils import logger
 
 
 @dataclass
@@ -29,7 +31,7 @@ class MongoKVStorage(BaseKVStorage):
         self._data = database.get_collection(self.namespace)
         logger.info(f"Use MongoDB as KV {self.namespace}")
 
-    async def get_by_id(self, id: str) -> dict[str, Any]:
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         return self._data.find_one({"_id": id})
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -170,7 +172,6 @@ class MongoGraphStorage(BaseGraphStorage):
         But typically for a direct edge, we might just do a find_one.
         Below is a demonstration approach.
         """
-
        # We can do a single-hop graphLookup (maxDepth=0 or 1).
        # Then check if the target_node appears among the edges array.
        pipeline = [
lightrag/kg/oracle_impl.py
CHANGED

@@ -1,27 +1,28 @@
+import array
 import asyncio
+import os
 
 # import html
 # import os
 from dataclasses import dataclass
 from typing import Any, Union
+
 import numpy as np
-import array
 import pipmaster as pm
 
 if not pm.is_installed("oracledb"):
     pm.install("oracledb")
 
 
+import oracledb
+
 from ..base import (
     BaseGraphStorage,
     BaseKVStorage,
     BaseVectorStorage,
 )
 from ..namespace import NameSpace, is_namespace
-import oracledb
+from ..utils import logger
 
 
 class OracleDB:
@@ -107,7 +108,7 @@ class OracleDB:
                         "SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only"
                     )
                 else:
-                    await self.query("SELECT 1 FROM {k}")
+                    await self.query(f"SELECT 1 FROM {k}")
            except Exception as e:
                logger.error(f"Failed to check table {k} in Oracle database")
                logger.error(f"Oracle database error: {e}")
@@ -181,8 +182,8 @@ class OracleKVStorage(BaseKVStorage):
 
     ################ QUERY METHODS ################
 
-    async def get_by_id(self, id: str) -> dict[str, Any]:
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+        """Get doc_full data based on id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
         # print("get_by_id:"+SQL)
@@ -191,7 +192,10 @@ class OracleKVStorage(BaseKVStorage):
             res = {}
             for row in array_res:
                 res[row["id"]] = row
+            if res:
+                return res
+            else:
+                return None
         else:
             return await self.db.query(SQL, params)
 
@@ -209,7 +213,7 @@ class OracleKVStorage(BaseKVStorage):
             return None
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get doc_chunks data based on id"""
         SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
         )
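Besides the import reshuffle, one substantive oracle_impl.py fix is the added f-prefix in check_tables; without it the literal text {k} was sent to the database instead of the table name. A two-line illustration, using an arbitrary table name for the example:

# Sketch: the difference the f-prefix makes.
k = "LIGHTRAG_DOC_FULL"
print("SELECT 1 FROM {k}")   # SELECT 1 FROM {k}   <- what the old code queried
print(f"SELECT 1 FROM {k}")  # SELECT 1 FROM LIGHTRAG_DOC_FULL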
lightrag/kg/postgres_impl.py
CHANGED

@@ -4,34 +4,35 @@ import json
 import os
 import time
 from dataclasses import dataclass
-import numpy as np
+from typing import Any, Dict, List, Set, Tuple, Union
 
+import numpy as np
 import pipmaster as pm
 
 if not pm.is_installed("asyncpg"):
     pm.install("asyncpg")
 
-import asyncpg
 import sys
+
+import asyncpg
 from tenacity import (
     retry,
     retry_if_exception_type,
     stop_after_attempt,
     wait_exponential,
 )
+from tqdm.asyncio import tqdm as tqdm_async
 
-from ..utils import logger
 from ..base import (
+    BaseGraphStorage,
     BaseKVStorage,
     BaseVectorStorage,
-    DocStatusStorage,
-    DocStatus,
     DocProcessingStatus,
+    DocStatus,
+    DocStatusStorage,
 )
 from ..namespace import NameSpace, is_namespace
+from ..utils import logger
 
 if sys.platform.startswith("win"):
     import asyncio.windows_events
@@ -82,7 +83,7 @@ class PostgreSQLDB:
     async def check_tables(self):
         for k, v in TABLES.items():
             try:
-                await self.query("SELECT 1 FROM {k} LIMIT 1")
+                await self.query(f"SELECT 1 FROM {k} LIMIT 1")
             except Exception as e:
                 logger.error(f"Failed to check table {k} in PostgreSQL database")
                 logger.error(f"PostgreSQL database error: {e}")
@@ -183,7 +184,7 @@ class PGKVStorage(BaseKVStorage):
 
     ################ QUERY METHODS ################
 
-    async def get_by_id(self, id: str) -> dict[str, Any]:
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         """Get doc_full data by id."""
         sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
@@ -192,9 +193,10 @@ class PGKVStorage(BaseKVStorage):
             res = {}
             for row in array_res:
                 res[row["id"]] = row
-            return res
+            return res if res else None
         else:
+            response = await self.db.query(sql, params)
+            return response if response else None
 
     async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
         """Specifically for llm_response_cache."""
@@ -435,12 +437,12 @@ class PGDocStatusStorage(DocStatusStorage):
         existed = set([element["id"] for element in result])
         return set(data) - existed
 
-    async def get_by_id(self, id: str) -> dict[str, Any]:
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
         params = {"workspace": self.db.workspace, "id": id}
         result = await self.db.query(sql, params, True)
         if result is None or result == []:
-            return
+            return None
         else:
             return DocProcessingStatus(
                 content=result[0]["content"],
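postgres_impl.py, like the Oracle and TiDB backends, now normalizes empty query results to None so that get_by_id matches its Union[dict[str, Any], None] annotation. The pattern in isolation; normalize is a hypothetical name used only for this sketch:

# Sketch: the "empty result becomes None" convention used by the reworked get_by_id methods.
def normalize(result):
    return result if result else None

assert normalize({}) is None                          # empty dict -> None
assert normalize([]) is None                          # empty list -> None
assert normalize({"id": "doc-1"}) == {"id": "doc-1"}  # non-empty results pass through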
lightrag/kg/redis_impl.py
CHANGED

@@ -1,5 +1,5 @@
 import os
-from typing import Any
+from typing import Any, Union
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import pipmaster as pm
@@ -21,7 +21,7 @@ class RedisKVStorage(BaseKVStorage):
         self._redis = Redis.from_url(redis_url, decode_responses=True)
         logger.info(f"Use Redis as KV {self.namespace}")
 
-    async def get_by_id(self, id):
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         data = await self._redis.get(f"{self.namespace}:{id}")
         return json.loads(data) if data else None
 
lightrag/kg/tidb_impl.py
CHANGED

@@ -14,12 +14,12 @@ if not pm.is_installed("sqlalchemy"):
 from sqlalchemy import create_engine, text
 from tqdm import tqdm
 
-from ..utils import logger
+from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
 from ..namespace import NameSpace, is_namespace
+from ..utils import logger
 
 
+class TiDB:
     def __init__(self, config, **kwargs):
         self.host = config.get("host", None)
         self.port = config.get("port", None)
@@ -108,12 +108,12 @@ class TiDBKVStorage(BaseKVStorage):
 
     ################ QUERY METHODS ################
 
-    async def get_by_id(self, id: str) -> dict[str, Any]:
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         """Fetch doc_full data by id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"id": id}
+        response = await self.db.query(SQL, params)
+        return response if response else None
 
     # Query by id
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -178,7 +178,7 @@ class TiDBKVStorage(BaseKVStorage):
                     "tokens": item["tokens"],
                     "chunk_order_index": item["chunk_order_index"],
                     "full_doc_id": item["full_doc_id"],
+                    "content_vector": f'{item["__vector__"].tolist()}',
                    "workspace": self.db.workspace,
                }
            )
@@ -222,8 +222,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
        )
 
    async def query(self, query: str, top_k: int) -> list[dict]:
+        """Search from tidb vector"""
        embeddings = await self.embedding_func([query])
        embedding = embeddings[0]
 
@@ -286,7 +285,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
                "id": item["id"],
                "name": item["entity_name"],
                "content": item["content"],
+                "content_vector": f'{item["content_vector"].tolist()}',
                "workspace": self.db.workspace,
            }
            # update entity_id if node inserted by graph_storage_instance before
@@ -308,7 +307,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
                "source_name": item["src_id"],
                "target_name": item["tgt_id"],
                "content": item["content"],
+                "content_vector": f'{item["content_vector"].tolist()}',
                "workspace": self.db.workspace,
            }
            # update relation_id if node inserted by graph_storage_instance before
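The content_vector fields in tidb_impl.py are built by formatting the embedding array's .tolist() output into a string for the SQL parameters. A minimal sketch, assuming the embedding is held in a NumPy array as in the surrounding code:

# Sketch: how a vector ends up as the content_vector parameter value.
import numpy as np

vec = np.array([0.12, 0.34, 0.56])
content_vector = f"{vec.tolist()}"  # "[0.12, 0.34, 0.56]" - a plain list literal as text
print(content_vector)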