yangdx
committed on
Commit
·
b89f76b
1
Parent(s):
bfe6274
Fix linting
Browse files
- lightrag/api/lightrag_server.py +22 -7
- lightrag/kg/chroma_impl.py +3 -2
- lightrag/kg/faiss_impl.py +3 -1
- lightrag/kg/milvus_impl.py +7 -2
- lightrag/kg/nano_vector_db_impl.py +3 -1
- lightrag/kg/oracle_impl.py +3 -2
- lightrag/kg/postgres_impl.py +4 -4
- lightrag/kg/qdrant_impl.py +9 -3
- lightrag/kg/tidb_impl.py +3 -1
- lightrag/lightrag.py +1 -3
- lightrag/operate.py +6 -4
lightrag/api/lightrag_server.py
CHANGED
@@ -66,12 +66,14 @@ load_dotenv(override=True)
|
|
66 |
config = configparser.ConfigParser()
|
67 |
config.read("config.ini")
|
68 |
|
|
|
69 |
class DefaultRAGStorageConfig:
|
70 |
KV_STORAGE = "JsonKVStorage"
|
71 |
VECTOR_STORAGE = "NanoVectorDBStorage"
|
72 |
GRAPH_STORAGE = "NetworkXStorage"
|
73 |
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
|
74 |
|
|
|
75 |
# Global progress tracker
|
76 |
scan_progress: Dict = {
|
77 |
"is_scanning": False,
|
@@ -317,22 +319,30 @@ def parse_args() -> argparse.Namespace:
|
|
317 |
|
318 |
parser.add_argument(
|
319 |
"--kv-storage",
|
320 |
-
default=get_env_value(
|
|
|
|
|
321 |
help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})",
|
322 |
)
|
323 |
parser.add_argument(
|
324 |
"--doc-status-storage",
|
325 |
-
default=get_env_value(
|
|
|
|
|
326 |
help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
|
327 |
)
|
328 |
parser.add_argument(
|
329 |
"--graph-storage",
|
330 |
-
default=get_env_value(
|
|
|
|
|
331 |
help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
|
332 |
)
|
333 |
parser.add_argument(
|
334 |
"--vector-storage",
|
335 |
-
default=get_env_value(
|
|
|
|
|
336 |
help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
|
337 |
)
|
338 |
|
@@ -725,7 +735,12 @@ def create_app(args):
|
|
725 |
for storage_name, storage_instance in storage_instances:
|
726 |
if isinstance(
|
727 |
storage_instance,
|
728 |
-
(
|
|
|
|
|
|
|
|
|
|
|
729 |
):
|
730 |
storage_instance.db = postgres_db
|
731 |
logger.info(f"Injected postgres_db to {storage_name}")
|
@@ -790,11 +805,11 @@ def create_app(args):
|
|
790 |
if postgres_db and hasattr(postgres_db, "pool"):
|
791 |
await postgres_db.pool.close()
|
792 |
logger.info("Closed PostgreSQL connection pool")
|
793 |
-
|
794 |
if oracle_db and hasattr(oracle_db, "pool"):
|
795 |
await oracle_db.pool.close()
|
796 |
logger.info("Closed Oracle connection pool")
|
797 |
-
|
798 |
if tidb_db and hasattr(tidb_db, "pool"):
|
799 |
await tidb_db.pool.close()
|
800 |
logger.info("Closed TiDB connection pool")
|
|
|
66 |
config = configparser.ConfigParser()
|
67 |
config.read("config.ini")
|
68 |
|
69 |
+
|
70 |
class DefaultRAGStorageConfig:
|
71 |
KV_STORAGE = "JsonKVStorage"
|
72 |
VECTOR_STORAGE = "NanoVectorDBStorage"
|
73 |
GRAPH_STORAGE = "NetworkXStorage"
|
74 |
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
|
75 |
|
76 |
+
|
77 |
# Global progress tracker
|
78 |
scan_progress: Dict = {
|
79 |
"is_scanning": False,
|
|
|
319 |
|
320 |
parser.add_argument(
|
321 |
"--kv-storage",
|
322 |
+
default=get_env_value(
|
323 |
+
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
|
324 |
+
),
|
325 |
help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})",
|
326 |
)
|
327 |
parser.add_argument(
|
328 |
"--doc-status-storage",
|
329 |
+
default=get_env_value(
|
330 |
+
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
|
331 |
+
),
|
332 |
help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
|
333 |
)
|
334 |
parser.add_argument(
|
335 |
"--graph-storage",
|
336 |
+
default=get_env_value(
|
337 |
+
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
|
338 |
+
),
|
339 |
help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
|
340 |
)
|
341 |
parser.add_argument(
|
342 |
"--vector-storage",
|
343 |
+
default=get_env_value(
|
344 |
+
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
|
345 |
+
),
|
346 |
help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
|
347 |
)
|
348 |
|
|
|
735 |
for storage_name, storage_instance in storage_instances:
|
736 |
if isinstance(
|
737 |
storage_instance,
|
738 |
+
(
|
739 |
+
PGKVStorage,
|
740 |
+
PGVectorStorage,
|
741 |
+
PGGraphStorage,
|
742 |
+
PGDocStatusStorage,
|
743 |
+
),
|
744 |
):
|
745 |
storage_instance.db = postgres_db
|
746 |
logger.info(f"Injected postgres_db to {storage_name}")
|
|
|
805 |
if postgres_db and hasattr(postgres_db, "pool"):
|
806 |
await postgres_db.pool.close()
|
807 |
logger.info("Closed PostgreSQL connection pool")
|
808 |
+
|
809 |
if oracle_db and hasattr(oracle_db, "pool"):
|
810 |
await oracle_db.pool.close()
|
811 |
logger.info("Closed Oracle connection pool")
|
812 |
+
|
813 |
if tidb_db and hasattr(tidb_db, "pool"):
|
814 |
await tidb_db.pool.close()
|
815 |
logger.info("Closed TiDB connection pool")
|
lightrag/kg/chroma_impl.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import os
|
2 |
import asyncio
|
3 |
from dataclasses import dataclass
|
4 |
from typing import Union
|
@@ -20,7 +19,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
|
20 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
21 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
22 |
if cosine_threshold is None:
|
23 |
-
raise ValueError(
|
|
|
|
|
24 |
self.cosine_better_than_threshold = cosine_threshold
|
25 |
|
26 |
user_collection_settings = config.get("collection_settings", {})
|
|
|
|
|
1 |
import asyncio
|
2 |
from dataclasses import dataclass
|
3 |
from typing import Union
|
|
|
19 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
20 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
21 |
if cosine_threshold is None:
|
22 |
+
raise ValueError(
|
23 |
+
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
24 |
+
)
|
25 |
self.cosine_better_than_threshold = cosine_threshold
|
26 |
|
27 |
user_collection_settings = config.get("collection_settings", {})
|
lightrag/kg/faiss_impl.py
CHANGED
@@ -30,7 +30,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|
30 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
31 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
32 |
if cosine_threshold is None:
|
33 |
-
raise ValueError(
|
|
|
|
|
34 |
self.cosine_better_than_threshold = cosine_threshold
|
35 |
|
36 |
# Where to save index file if you want persistent storage
|
|
|
30 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
31 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
32 |
if cosine_threshold is None:
|
33 |
+
raise ValueError(
|
34 |
+
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
35 |
+
)
|
36 |
self.cosine_better_than_threshold = cosine_threshold
|
37 |
|
38 |
# Where to save index file if you want persistent storage
|
lightrag/kg/milvus_impl.py
CHANGED
@@ -35,7 +35,9 @@ class MilvusVectorDBStorge(BaseVectorStorage):
|
|
35 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
36 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
37 |
if cosine_threshold is None:
|
38 |
-
raise ValueError(
|
|
|
|
|
39 |
self.cosine_better_than_threshold = cosine_threshold
|
40 |
|
41 |
self._client = MilvusClient(
|
@@ -111,7 +113,10 @@ class MilvusVectorDBStorge(BaseVectorStorage):
|
|
111 |
data=embedding,
|
112 |
limit=top_k,
|
113 |
output_fields=list(self.meta_fields),
|
114 |
-
search_params={
|
|
|
|
|
|
|
115 |
)
|
116 |
print(results)
|
117 |
return [
|
|
|
35 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
36 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
37 |
if cosine_threshold is None:
|
38 |
+
raise ValueError(
|
39 |
+
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
40 |
+
)
|
41 |
self.cosine_better_than_threshold = cosine_threshold
|
42 |
|
43 |
self._client = MilvusClient(
|
|
|
113 |
data=embedding,
|
114 |
limit=top_k,
|
115 |
output_fields=list(self.meta_fields),
|
116 |
+
search_params={
|
117 |
+
"metric_type": "COSINE",
|
118 |
+
"params": {"radius": self.cosine_better_than_threshold},
|
119 |
+
},
|
120 |
)
|
121 |
print(results)
|
122 |
return [
|
lightrag/kg/nano_vector_db_impl.py
CHANGED
@@ -82,7 +82,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|
82 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
83 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
84 |
if cosine_threshold is None:
|
85 |
-
raise ValueError(
|
|
|
|
|
86 |
self.cosine_better_than_threshold = cosine_threshold
|
87 |
|
88 |
self._client_file_name = os.path.join(
|
|
|
82 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
83 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
84 |
if cosine_threshold is None:
|
85 |
+
raise ValueError(
|
86 |
+
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
87 |
+
)
|
88 |
self.cosine_better_than_threshold = cosine_threshold
|
89 |
|
90 |
self._client_file_name = os.path.join(
|
lightrag/kg/oracle_impl.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
import array
|
2 |
import asyncio
|
3 |
-
import os
|
4 |
|
5 |
# import html
|
6 |
# import os
|
@@ -326,7 +325,9 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
|
326 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
327 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
328 |
if cosine_threshold is None:
|
329 |
-
raise ValueError(
|
|
|
|
|
330 |
self.cosine_better_than_threshold = cosine_threshold
|
331 |
|
332 |
async def upsert(self, data: dict[str, dict]):
|
|
|
1 |
import array
|
2 |
import asyncio
|
|
|
3 |
|
4 |
# import html
|
5 |
# import os
|
|
|
325 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
326 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
327 |
if cosine_threshold is None:
|
328 |
+
raise ValueError(
|
329 |
+
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
330 |
+
)
|
331 |
self.cosine_better_than_threshold = cosine_threshold
|
332 |
|
333 |
async def upsert(self, data: dict[str, dict]):
|
lightrag/kg/postgres_impl.py
CHANGED
@@ -306,7 +306,9 @@ class PGVectorStorage(BaseVectorStorage):
|
|
306 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
307 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
308 |
if cosine_threshold is None:
|
309 |
-
raise ValueError(
|
|
|
|
|
310 |
self.cosine_better_than_threshold = cosine_threshold
|
311 |
|
312 |
def _upsert_chunks(self, item: dict):
|
@@ -424,9 +426,7 @@ class PGDocStatusStorage(DocStatusStorage):
|
|
424 |
async def filter_keys(self, data: set[str]) -> set[str]:
|
425 |
"""Return keys that don't exist in storage"""
|
426 |
keys = ",".join([f"'{_id}'" for _id in data])
|
427 |
-
sql = (
|
428 |
-
f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
|
429 |
-
)
|
430 |
result = await self.db.query(sql, multirows=True)
|
431 |
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
|
432 |
if result is None:
|
|
|
306 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
307 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
308 |
if cosine_threshold is None:
|
309 |
+
raise ValueError(
|
310 |
+
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
311 |
+
)
|
312 |
self.cosine_better_than_threshold = cosine_threshold
|
313 |
|
314 |
def _upsert_chunks(self, item: dict):
|
|
|
426 |
async def filter_keys(self, data: set[str]) -> set[str]:
|
427 |
"""Return keys that don't exist in storage"""
|
428 |
keys = ",".join([f"'{_id}'" for _id in data])
|
429 |
+
sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
|
|
|
|
|
430 |
result = await self.db.query(sql, multirows=True)
|
431 |
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
|
432 |
if result is None:
|
lightrag/kg/qdrant_impl.py
CHANGED
@@ -64,7 +64,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
|
64 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
65 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
66 |
if cosine_threshold is None:
|
67 |
-
raise ValueError(
|
|
|
|
|
68 |
self.cosine_better_than_threshold = cosine_threshold
|
69 |
|
70 |
self._client = QdrantClient(
|
@@ -140,5 +142,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
|
140 |
)
|
141 |
logger.debug(f"query result: {results}")
|
142 |
# 添加余弦相似度过滤
|
143 |
-
filtered_results = [
|
144 |
-
|
|
|
|
|
|
|
|
|
|
64 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
65 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
66 |
if cosine_threshold is None:
|
67 |
+
raise ValueError(
|
68 |
+
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
69 |
+
)
|
70 |
self.cosine_better_than_threshold = cosine_threshold
|
71 |
|
72 |
self._client = QdrantClient(
|
|
|
142 |
)
|
143 |
logger.debug(f"query result: {results}")
|
144 |
# 添加余弦相似度过滤
|
145 |
+
filtered_results = [
|
146 |
+
dp for dp in results if dp.score >= self.cosine_better_than_threshold
|
147 |
+
]
|
148 |
+
return [
|
149 |
+
{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results
|
150 |
+
]
|
lightrag/kg/tidb_impl.py
CHANGED
@@ -222,7 +222,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
|
222 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
223 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
224 |
if cosine_threshold is None:
|
225 |
-
raise ValueError(
|
|
|
|
|
226 |
self.cosine_better_than_threshold = cosine_threshold
|
227 |
|
228 |
async def query(self, query: str, top_k: int) -> list[dict]:
|
|
|
222 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
223 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
224 |
if cosine_threshold is None:
|
225 |
+
raise ValueError(
|
226 |
+
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
227 |
+
)
|
228 |
self.cosine_better_than_threshold = cosine_threshold
|
229 |
|
230 |
async def query(self, query: str, top_k: int) -> list[dict]:
|
lightrag/lightrag.py
CHANGED
@@ -426,7 +426,7 @@ class LightRAG:
|
|
426 |
}
|
427 |
self.vector_db_storage_cls_kwargs = {
|
428 |
**default_vector_db_kwargs,
|
429 |
-
**self.vector_db_storage_cls_kwargs
|
430 |
}
|
431 |
|
432 |
# show config
|
@@ -532,8 +532,6 @@ class LightRAG:
|
|
532 |
embedding_func=self.embedding_func,
|
533 |
)
|
534 |
|
535 |
-
|
536 |
-
|
537 |
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
538 |
partial(
|
539 |
self.llm_model_func,
|
|
|
426 |
}
|
427 |
self.vector_db_storage_cls_kwargs = {
|
428 |
**default_vector_db_kwargs,
|
429 |
+
**self.vector_db_storage_cls_kwargs,
|
430 |
}
|
431 |
|
432 |
# show config
|
|
|
532 |
embedding_func=self.embedding_func,
|
533 |
)
|
534 |
|
|
|
|
|
535 |
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
536 |
partial(
|
537 |
self.llm_model_func,
|
lightrag/operate.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
import asyncio
|
2 |
import json
|
3 |
-
import os
|
4 |
import re
|
5 |
from tqdm.asyncio import tqdm as tqdm_async
|
6 |
from typing import Any, Union
|
@@ -35,7 +34,6 @@ from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
|
35 |
import time
|
36 |
|
37 |
|
38 |
-
|
39 |
def chunking_by_token_size(
|
40 |
content: str,
|
41 |
split_by_character: Union[str, None] = None,
|
@@ -1057,7 +1055,9 @@ async def _get_node_data(
|
|
1057 |
query_param: QueryParam,
|
1058 |
):
|
1059 |
# get similar entities
|
1060 |
-
logger.info(
|
|
|
|
|
1061 |
results = await entities_vdb.query(query, top_k=query_param.top_k)
|
1062 |
if not len(results):
|
1063 |
return "", "", ""
|
@@ -1273,7 +1273,9 @@ async def _get_edge_data(
|
|
1273 |
text_chunks_db: BaseKVStorage,
|
1274 |
query_param: QueryParam,
|
1275 |
):
|
1276 |
-
logger.info(
|
|
|
|
|
1277 |
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
1278 |
|
1279 |
if not len(results):
|
|
|
1 |
import asyncio
|
2 |
import json
|
|
|
3 |
import re
|
4 |
from tqdm.asyncio import tqdm as tqdm_async
|
5 |
from typing import Any, Union
|
|
|
34 |
import time
|
35 |
|
36 |
|
|
|
37 |
def chunking_by_token_size(
|
38 |
content: str,
|
39 |
split_by_character: Union[str, None] = None,
|
|
|
1055 |
query_param: QueryParam,
|
1056 |
):
|
1057 |
# get similar entities
|
1058 |
+
logger.info(
|
1059 |
+
f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
|
1060 |
+
)
|
1061 |
results = await entities_vdb.query(query, top_k=query_param.top_k)
|
1062 |
if not len(results):
|
1063 |
return "", "", ""
|
|
|
1273 |
text_chunks_db: BaseKVStorage,
|
1274 |
query_param: QueryParam,
|
1275 |
):
|
1276 |
+
logger.info(
|
1277 |
+
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
|
1278 |
+
)
|
1279 |
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
1280 |
|
1281 |
if not len(results):
|