Merge branch 'main' into main
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- README-zh.md +52 -4
- README.md +54 -5
- config.ini.example +0 -17
- env.example +20 -27
- examples/lightrag_api_ollama_demo.py +0 -188
- examples/lightrag_api_openai_compatible_demo.py +0 -204
- examples/lightrag_api_oracle_demo.py +0 -267
- examples/lightrag_ollama_gremlin_demo.py +4 -0
- examples/lightrag_oracle_demo.py +0 -141
- examples/lightrag_tidb_demo.py +4 -0
- lightrag/api/README-zh.md +4 -12
- lightrag/api/README.md +5 -13
- lightrag/api/__init__.py +1 -1
- lightrag/api/auth.py +9 -8
- lightrag/api/config.py +335 -0
- lightrag/api/lightrag_server.py +39 -19
- lightrag/api/routers/document_routes.py +465 -52
- lightrag/api/routers/graph_routes.py +10 -14
- lightrag/api/run_with_gunicorn.py +31 -25
- lightrag/api/utils_api.py +19 -354
- lightrag/api/webui/assets/index-CD5HxTy1.css +0 -0
- lightrag/api/webui/assets/{index-raheqJeu.js → index-Cma7xY0-.js} +0 -0
- lightrag/api/webui/assets/index-QU59h9JG.css +0 -0
- lightrag/api/webui/index.html +0 -0
- lightrag/base.py +122 -9
- lightrag/kg/__init__.py +15 -40
- lightrag/kg/age_impl.py +21 -3
- lightrag/kg/chroma_impl.py +28 -2
- lightrag/kg/faiss_impl.py +64 -5
- lightrag/kg/gremlin_impl.py +24 -3
- lightrag/kg/json_doc_status_impl.py +49 -10
- lightrag/kg/json_kv_impl.py +73 -3
- lightrag/kg/milvus_impl.py +31 -1
- lightrag/kg/mongo_impl.py +130 -3
- lightrag/kg/nano_vector_db_impl.py +66 -0
- lightrag/kg/neo4j_impl.py +373 -246
- lightrag/kg/networkx_impl.py +122 -97
- lightrag/kg/oracle_impl.py +0 -1346
- lightrag/kg/postgres_impl.py +382 -319
- lightrag/kg/qdrant_impl.py +93 -6
- lightrag/kg/redis_impl.py +46 -60
- lightrag/kg/tidb_impl.py +172 -3
- lightrag/lightrag.py +25 -42
- lightrag/llm/openai.py +101 -26
- lightrag/operate.py +105 -98
- lightrag/types.py +1 -0
- lightrag_webui/src/App.tsx +89 -36
- lightrag_webui/src/AppRouter.tsx +6 -1
- lightrag_webui/src/api/lightrag.ts +16 -2
- lightrag_webui/src/components/documents/ClearDocumentsDialog.tsx +119 -15
README-zh.md
CHANGED

@@ -11,7 +11,6 @@
One entry was removed from the news list (the surrounding 2024.12.31–2024.11.04 entries are unchanged context; the file is the Chinese README, translated here):

- - [X] [2024.11.12]🎯📢LightRAG now supports [all storage types (KV, vector, and graph) on Oracle Database 23ai](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py).

@@ -410,6 +409,54 @@ if __name__ == "__main__":
A new "### Token统计功能" (token usage tracking) section was added after the closing </details> of the preceding section. It is the Chinese counterpart of the "Token Usage Tracking" section added to README.md below: the same TokenTracker overview, the same Python example (context-manager mode and manual reset() mode), the same usage tips, and pointers to `examples/lightrag_gemini_track_token_demo.py` and `examples/lightrag_siliconcloud_track_token_demo.py`.

@@ -1037,9 +1084,10 @@ rag.clear_cache(modes=["local"])
The parameter table was updated in the same way as in README.md: the previously empty **kv_storage**, **vector_storage**, and **graph_storage** rows now list the supported storage types and defaults, and a **doc_status_storage** row was added. The new rows are in English and identical to those shown under README.md below; the **chunk_token_size**, **chunk_overlap_token_size**, and **tiktoken_model_name** rows are unchanged.
README.md
CHANGED

@@ -41,7 +41,6 @@
One entry was removed from the news list (the surrounding 2024.12.31–2024.11.04 entries are unchanged context):

- - [X] [2024.11.12]🎯📢LightRAG now supports [Oracle Database 23ai for all storage types (KV, vector, and graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py).

@@ -443,6 +442,55 @@ if __name__ == "__main__":
A new "Token Usage Tracking" section was added after the closing </details> of the preceding section:

### Token Usage Tracking

<details>
<summary> <b>Overview and Usage</b> </summary>

LightRAG provides a TokenTracker tool to monitor and manage token consumption by large language models. This feature is particularly useful for controlling API costs and optimizing performance.

#### Usage

```python
from lightrag.utils import TokenTracker

# Create TokenTracker instance
token_tracker = TokenTracker()

# Method 1: Using context manager (recommended)
# Suitable for scenarios requiring automatic token usage tracking
with token_tracker:
    result1 = await llm_model_func("your question 1")
    result2 = await llm_model_func("your question 2")

# Method 2: Manually adding token usage records
# Suitable for scenarios requiring more granular control over token statistics
token_tracker.reset()

rag.insert()

rag.query("your question 1", param=QueryParam(mode="naive"))
rag.query("your question 2", param=QueryParam(mode="mix"))

# Display total token usage (including insert and query operations)
print("Token usage:", token_tracker.get_usage())
```

#### Usage Tips

- Use the context manager for long sessions or batch operations to automatically track all token consumption
- For scenarios requiring segmented statistics, use manual mode and call reset() when appropriate
- Regular checking of token usage helps detect abnormal consumption early
- Actively use this feature during development and testing to optimize production costs

#### Practical Examples

You can refer to these examples for implementing token tracking:

- `examples/lightrag_gemini_track_token_demo.py`: token tracking example using the Google Gemini model
- `examples/lightrag_siliconcloud_track_token_demo.py`: token tracking example using a SiliconCloud model

These examples demonstrate how to effectively use the TokenTracker feature with different models and scenarios.

</details>
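To complement the section above, here is a hedged sketch of the "segmented statistics" tip. It only reuses the calls shown in the diff (`TokenTracker()`, `reset()`, `get_usage()`, `rag.insert()`, `rag.query()` with `QueryParam`); that the tracker is already wired to the LLM calls made by `rag`, and the exact shape of the value returned by `get_usage()`, are assumptions:

```python
# A minimal sketch of the "segmented statistics" tip: reset the tracker between
# phases so insertion and query costs are reported separately. Assumptions (not
# part of the diff): `rag` is an already-initialized LightRAG instance whose LLM
# calls are wired to this tracker, and get_usage() returns a printable summary.
from lightrag import QueryParam
from lightrag.utils import TokenTracker

token_tracker = TokenTracker()


def run_segmented(rag, documents, questions):
    # Segment 1: token consumption of document insertion only.
    token_tracker.reset()
    for doc in documents:
        rag.insert(doc)
    print("Indexing token usage:", token_tracker.get_usage())

    # Segment 2: token consumption of queries only.
    token_tracker.reset()
    for question in questions:
        rag.query(question, param=QueryParam(mode="mix"))
    print("Query token usage:", token_tracker.get_usage())
```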
@@ -607,7 +655,7 @@ The `apipeline_enqueue_documents` and `apipeline_process_enqueue_documents` functions
In the section on the asynchronous document pipeline ("This is useful for scenarios where you want to process documents in the background while still allowing the main thread to continue executing."), the dangling sentence "And using a routine to process" was completed to "And using a routine to process new documents." The following `rag = LightRAG(..)` code block is unchanged.
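The background-processing workflow referenced in this hunk can be sketched as follows. The two `apipeline_*` coroutines are taken from the demo code elsewhere in this merge; wrapping the processing step in `asyncio.create_task` is an assumption about one way to keep the main coroutine free:

```python
# A minimal sketch of staging documents and processing them with a background
# routine. The two apipeline_* coroutines appear in the demo scripts touched by
# this merge; the task wrapper around the processing step is an assumption.
import asyncio


async def ingest_in_background(rag, texts: list[str]) -> asyncio.Task:
    # Stage the documents for later processing.
    await rag.apipeline_enqueue_documents(texts)
    # Start processing the queue without blocking the caller.
    return asyncio.create_task(rag.apipeline_process_enqueue_documents())


async def main(rag):
    task = await ingest_in_background(rag, ["document one", "document two"])
    # ...the main coroutine can keep serving queries here...
    await task  # wait for indexing to finish before shutting down
```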
@@ -1096,9 +1144,10 @@ Valid modes are:
The parameter table was updated: the previously truncated **kv_storage**, **vector_storage**, and **graph_storage** rows are now complete, and a **doc_status_storage** row was added:

| **Parameter** | **Type** | **Explanation** | **Default** |
|--------------|----------|-----------------|-------------|
| **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `PGKVStorage`, `RedisKVStorage`, `MongoKVStorage` | `JsonKVStorage` |
| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `PGVectorStorage`, `MilvusVectorDBStorage`, `ChromaVectorDBStorage`, `FaissVectorDBStorage`, `MongoVectorDBStorage`, `QdrantVectorDBStorage` | `NanoVectorDBStorage` |
| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `PGGraphStorage`, `AGEStorage` | `NetworkXStorage` |
| **doc_status_storage** | `str` | Storage type for document processing status. Supported types: `JsonDocStatusStorage`, `PGDocStatusStorage`, `MongoDocStatusStorage` | `JsonDocStatusStorage` |
| **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
| **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
| **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
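A hedged sketch of wiring the storage parameters from the table above into a LightRAG instance follows. The `kv_storage`, `vector_storage`, and `graph_storage` keywords are used exactly as in the demo scripts removed by this merge; passing `doc_status_storage` the same way, the stub model functions, and the assumption that Postgres/Neo4j connection details come from the environment (see env.example) are extrapolations, not part of the diff:

```python
# A hedged sketch of constructing LightRAG with explicit storage backends taken
# from the table above. kv_storage / vector_storage / graph_storage are used as
# in the demo scripts removed by this merge; doc_status_storage and the stub
# model functions are assumptions. Backend connection details (Postgres, Neo4j)
# are expected to come from the environment, as laid out in env.example.
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.kg.shared_storage import initialize_pipeline_status


async def build_rag(llm_model_func, embedding_func, embedding_dim: int) -> LightRAG:
    rag = LightRAG(
        working_dir="./rag_storage",
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=embedding_dim,
            max_token_size=8192,
            func=embedding_func,
        ),
        kv_storage="PGKVStorage",
        vector_storage="PGVectorStorage",
        graph_storage="Neo4JStorage",
        doc_status_storage="PGDocStatusStorage",  # assumed keyword, per the table
    )
    await rag.initialize_storages()
    await initialize_pipeline_status()
    return rag
```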
config.ini.example
CHANGED

@@ -13,23 +13,6 @@ uri=redis://localhost:6379/1
Removed the [oracle] and [tidb] sections that previously sat between [qdrant] and [postgres] (comments translated from Chinese):

- [oracle]
- dsn = localhost:1521/XEPDB1
- user = your_username
- password = your_password
- config_dir = /path/to/oracle/config
- wallet_location = /path/to/wallet  # optional
- wallet_password = your_wallet_password  # optional
- workspace = default  # optional, defaults to "default"
-
- [tidb]
- host = localhost
- port = 4000
- user = your_username
- password = your_password
- database = your_database
- workspace = default  # optional, defaults to "default"

The surrounding [qdrant] and [postgres] sections are unchanged.
env.example
CHANGED

@@ -4,11 +4,9 @@
The NAMESPACE_PREFIX and MAX_GRAPH_NODES entries were removed from the server block and the Web UI branding variables were added after CORS_ORIGINS:

- ### separating data from difference Lightrag instances
- # NAMESPACE_PREFIX=lightrag
- ### Max nodes return from grap retrieval
- # MAX_GRAPH_NODES=1000
+ WEBUI_TITLE='Graph RAG Engine'
+ WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"

@@ -22,6 +20,9 @@
MAX_GRAPH_NODES was re-added below the Ollama emulating model tag setting:

+ ### Max nodes return from grap retrieval
+ # MAX_GRAPH_NODES=1000

@@ -110,24 +111,14 @@ LIGHTRAG_VECTOR_STORAGE=NanoVectorDBStorage
The Oracle configuration block (only "#ORACLE_WORKSPACE=default" is still visible in this truncated view) was removed, and the TiDB block was commented out and marked deprecated:

+ ### TiDB Configuration (Deprecated)
+ # TIDB_HOST=localhost
+ # TIDB_PORT=4000
+ # TIDB_USER=your_username
+ # TIDB_PASSWORD='your_password'
+ # TIDB_DATABASE=your_database
+ ### separating all data from difference Lightrag instances(deprecating)
+ # TIDB_WORKSPACE=default

@@ -135,8 +126,8 @@ POSTGRES_PORT=5432
The PostgreSQL workspace comment was shortened to "### separating all data from difference Lightrag instances(deprecating)" and the commented setting is now written as "# POSTGRES_WORKSPACE=default".

@@ -145,8 +136,8 @@ AGE_POSTGRES_PASSWORD=
In the AGE section, the workspace-separation comment was dropped and a note that AGE_GRAPH_NAME is deprecated (written "### AGE_GRAPH_NAME is precated" in the source) was added above "# AGE_GRAPH_NAME=lightrag".

@@ -157,7 +148,7 @@ NEO4J_PASSWORD='your_password'
In the MongoDB section, the workspace-separation comment above "# MONGODB_GRAPH=false" was likewise shortened to end with "(deprecating)".

@@ -177,7 +168,9 @@ REDIS_URI=redis://localhost:6379
In the JWT auth block, TOKEN_EXPIRE_HOURS now shows a default of 48 and two new commented settings were added; the API-Key setting below is unchanged:

+ # TOKEN_EXPIRE_HOURS=48
+ # GUEST_TOKEN_EXPIRE_HOURS=24
+ # JWT_ALGORITHM=HS256
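The JWT settings above are ordinary environment variables. As a hedged illustration (not part of the diff), a minimal sketch of reading them with the standard library; the variable names come from env.example, while the fallback defaults and the AUTH_ACCOUNTS parsing shown here are assumptions about how the API server consumes them:

```python
# Illustration only: reading the JWT settings above from the environment.
# Variable names come from env.example; the fallback defaults and the
# AUTH_ACCOUNTS parsing are assumptions about how the server consumes them.
import os

token_secret = os.getenv("TOKEN_SECRET", "")
token_expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", "48"))
guest_token_expire_hours = int(os.getenv("GUEST_TOKEN_EXPIRE_HOURS", "24"))
jwt_algorithm = os.getenv("JWT_ALGORITHM", "HS256")

# AUTH_ACCOUNTS is a comma-separated list of user:password pairs,
# e.g. 'admin:admin123,user1:pass456'.
auth_accounts = dict(
    pair.split(":", 1)
    for pair in os.getenv("AUTH_ACCOUNTS", "").split(",")
    if ":" in pair
)
```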
examples/lightrag_api_ollama_demo.py
DELETED
Removed the 188-line standalone FastAPI demo. It built a LightRAG instance backed by Ollama (`ollama_model_complete` with `gemma2:9b` and `nomic-embed-text` embeddings via `ollama_embed`, host `http://localhost:11434`, embedding_dim 768), initialized storages and the pipeline status in a lifespan handler, read RAG_DIR (default `index_default`) and INPUT_FILE (default `book.txt`) from the environment, and exposed `/query`, `/insert`, `/insert_file`, `/insert_default_file`, and `/health` endpoints via uvicorn on port 8020, with example curl commands in the trailing comments.
examples/lightrag_api_openai_compatible_demo.py
DELETED
Removed the 204-line FastAPI demo for OpenAI-compatible endpoints. It wrapped `openai_complete_if_cache` / `openai_embed` (LLM_MODEL default `gpt-4o-mini`, EMBEDDING_MODEL default `text-embedding-3-large`, BASE_URL default `https://api.openai.com/v1`, API_KEY from the environment), detected the embedding dimension at startup, initialized storages and the pipeline status in a lifespan handler, and served `/query`, `/insert`, `/insert_file`, and `/health` routes via uvicorn on port 8020, with example curl commands in the trailing comments.
examples/lightrag_api_oracle_demo.py
DELETED
Removed the 267-line FastAPI demo that used Oracle Database as the KV, vector, and graph storage (`OracleKVStorage`, `OracleVectorDBStorage`, `OracleGraphStorage`) behind an OpenAI-compatible gateway on Oracle Cloud (Cohere `command-r-plus-08-2024` and `embed-multilingual-v3.0` models via `openai_complete_if_cache` / `openai_embed`). It set the ORACLE_* connection environment variables in code and exposed `/query`, `/data` (nodes/edges/statistics), `/insert`, `/insert_file`, and `/health` endpoints via uvicorn on 127.0.0.1:8020.
examples/lightrag_ollama_gremlin_demo.py
CHANGED

@@ -1,3 +1,7 @@
A deprecation banner was added above the existing imports:

+ ##############################################
+ # Gremlin storage implementation is deprecated
+ ##############################################
+
  import asyncio
  import inspect
  import os
examples/lightrag_oracle_demo.py
DELETED
|
@@ -1,141 +0,0 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
import os
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
import asyncio
|
| 5 |
-
from lightrag import LightRAG, QueryParam
|
| 6 |
-
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
|
| 7 |
-
from lightrag.utils import EmbeddingFunc
|
| 8 |
-
import numpy as np
|
| 9 |
-
from lightrag.kg.shared_storage import initialize_pipeline_status
|
| 10 |
-
|
| 11 |
-
print(os.getcwd())
|
| 12 |
-
script_directory = Path(__file__).resolve().parent.parent
|
| 13 |
-
sys.path.append(os.path.abspath(script_directory))
|
| 14 |
-
|
| 15 |
-
WORKING_DIR = "./dickens"
|
| 16 |
-
|
| 17 |
-
# We use OpenAI compatible API to call LLM on Oracle Cloud
|
| 18 |
-
# More docs here https://github.com/jin38324/OCI_GenAI_access_gateway
|
| 19 |
-
BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
|
| 20 |
-
APIKEY = "ocigenerativeai"
|
| 21 |
-
CHATMODEL = "cohere.command-r-plus"
|
| 22 |
-
EMBEDMODEL = "cohere.embed-multilingual-v3.0"
|
| 23 |
-
CHUNK_TOKEN_SIZE = 1024
|
| 24 |
-
MAX_TOKENS = 4000
|
| 25 |
-
|
| 26 |
-
if not os.path.exists(WORKING_DIR):
|
| 27 |
-
os.mkdir(WORKING_DIR)
|
| 28 |
-
|
| 29 |
-
os.environ["ORACLE_USER"] = "username"
|
| 30 |
-
os.environ["ORACLE_PASSWORD"] = "xxxxxxxxx"
|
| 31 |
-
os.environ["ORACLE_DSN"] = "xxxxxxx_medium"
|
| 32 |
-
os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
|
| 33 |
-
os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
|
| 34 |
-
os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
|
| 35 |
-
os.environ["ORACLE_WORKSPACE"] = "company"
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
async def llm_model_func(
|
| 39 |
-
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
| 40 |
-
) -> str:
|
| 41 |
-
return await openai_complete_if_cache(
|
| 42 |
-
CHATMODEL,
|
| 43 |
-
prompt,
|
| 44 |
-
system_prompt=system_prompt,
|
| 45 |
-
history_messages=history_messages,
|
| 46 |
-
api_key=APIKEY,
|
| 47 |
-
base_url=BASE_URL,
|
| 48 |
-
**kwargs,
|
| 49 |
-
)
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
async def embedding_func(texts: list[str]) -> np.ndarray:
|
| 53 |
-
return await openai_embed(
|
| 54 |
-
texts,
|
| 55 |
-
model=EMBEDMODEL,
|
| 56 |
-
api_key=APIKEY,
|
| 57 |
-
base_url=BASE_URL,
|
| 58 |
-
)
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
async def get_embedding_dim():
|
| 62 |
-
test_text = ["This is a test sentence."]
|
| 63 |
-
embedding = await embedding_func(test_text)
|
| 64 |
-
embedding_dim = embedding.shape[1]
|
| 65 |
-
return embedding_dim
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
async def initialize_rag():
|
| 69 |
-
# Detect embedding dimension
|
| 70 |
-
embedding_dimension = await get_embedding_dim()
|
| 71 |
-
print(f"Detected embedding dimension: {embedding_dimension}")
|
| 72 |
-
|
| 73 |
-
# Initialize LightRAG
|
| 74 |
-
# We use Oracle DB as the KV/vector/graph storage
|
| 75 |
-
# You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
|
| 76 |
-
rag = LightRAG(
|
| 77 |
-
# log_level="DEBUG",
|
| 78 |
-
working_dir=WORKING_DIR,
|
| 79 |
-
entity_extract_max_gleaning=1,
|
| 80 |
-
enable_llm_cache=True,
|
| 81 |
-
enable_llm_cache_for_entity_extract=True,
|
| 82 |
-
embedding_cache_config=None, # {"enabled": True,"similarity_threshold": 0.90},
|
| 83 |
-
chunk_token_size=CHUNK_TOKEN_SIZE,
|
| 84 |
-
llm_model_max_token_size=MAX_TOKENS,
|
| 85 |
-
llm_model_func=llm_model_func,
|
| 86 |
-
embedding_func=EmbeddingFunc(
|
| 87 |
-
embedding_dim=embedding_dimension,
|
| 88 |
-
max_token_size=500,
|
| 89 |
-
func=embedding_func,
|
| 90 |
-
),
|
| 91 |
-
graph_storage="OracleGraphStorage",
|
| 92 |
-
kv_storage="OracleKVStorage",
|
| 93 |
-
vector_storage="OracleVectorDBStorage",
|
| 94 |
-
addon_params={
|
| 95 |
-
"example_number": 1,
|
| 96 |
-
"language": "Simplfied Chinese",
|
| 97 |
-
"entity_types": ["organization", "person", "geo", "event"],
|
| 98 |
-
"insert_batch_size": 2,
|
| 99 |
-
},
|
| 100 |
-
)
|
| 101 |
-
await rag.initialize_storages()
|
| 102 |
-
await initialize_pipeline_status()
|
| 103 |
-
|
| 104 |
-
return rag
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
async def main():
|
| 108 |
-
try:
|
| 109 |
-
# Initialize RAG instance
|
| 110 |
-
rag = await initialize_rag()
|
| 111 |
-
|
| 112 |
-
# Extract and Insert into LightRAG storage
|
| 113 |
-
with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
|
| 114 |
-
all_text = f.read()
|
| 115 |
-
texts = [x for x in all_text.split("\n") if x]
|
| 116 |
-
|
| 117 |
-
# New mode use pipeline
|
| 118 |
-
await rag.apipeline_enqueue_documents(texts)
|
| 119 |
-
await rag.apipeline_process_enqueue_documents()
|
| 120 |
-
|
| 121 |
-
# Old method use ainsert
|
| 122 |
-
# await rag.ainsert(texts)
|
| 123 |
-
|
| 124 |
-
# Perform search in different modes
|
| 125 |
-
modes = ["naive", "local", "global", "hybrid"]
|
| 126 |
-
for mode in modes:
|
| 127 |
-
print("=" * 20, mode, "=" * 20)
|
| 128 |
-
print(
|
| 129 |
-
await rag.aquery(
|
| 130 |
-
"What are the top themes in this story?",
|
| 131 |
-
param=QueryParam(mode=mode),
|
| 132 |
-
)
|
| 133 |
-
)
|
| 134 |
-
print("-" * 100, "\n")
|
| 135 |
-
|
| 136 |
-
except Exception as e:
|
| 137 |
-
print(f"An error occurred: {e}")
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
if __name__ == "__main__":
|
| 141 |
-
asyncio.run(main())
|
examples/lightrag_tidb_demo.py
CHANGED

@@ -1,3 +1,7 @@
+###########################################
+# TiDB storage implementation is deprecated
+###########################################
+
 import asyncio
 import os
 
lightrag/api/README-zh.md
CHANGED

@@ -291,11 +291,9 @@ LightRAG 使用 4 种类型的存储用于不同目的:
 
 ```
 JsonKVStorage    JsonFile(默认)
-MongoKVStorage   MogonDB
-RedisKVStorage   Redis
-TiDBKVStorage    TiDB
 PGKVStorage      Postgres
-OracleKVStorage  Oracle
+RedisKVStorage   Redis
+MongoKVStorage   MogonDB
 ```
 
 * GRAPH_STORAGE 支持的实现名称
@@ -303,25 +301,19 @@ OracleKVStorage Oracle
 ```
 NetworkXStorage       NetworkX(默认)
 Neo4JStorage          Neo4J
-MongoGraphStorage     MongoDB
-TiDBGraphStorage      TiDB
-AGEStorage            AGE
-GremlinStorage        Gremlin
 PGGraphStorage        Postgres
-OracleGraphStorage    Oracle
+AGEStorage            AGE
 ```
 
 * VECTOR_STORAGE 支持的实现名称
 
 ```
 NanoVectorDBStorage     NanoVector(默认)
+PGVectorStorage         Postgres
 MilvusVectorDBStorge    Milvus
 ChromaVectorDBStorage   Chroma
-TiDBVectorDBStorage     TiDB
-PGVectorStorage         Postgres
 FaissVectorDBStorage    Faiss
 QdrantVectorDBStorage   Qdrant
-OracleVectorDBStorage   Oracle
 MongoVectorDBStorage    MongoDB
 ```
 
lightrag/api/README.md
CHANGED

@@ -302,11 +302,9 @@ Each storage type have servals implementations:
 
 ```
 JsonKVStorage    JsonFile(default)
-MongoKVStorage   MogonDB
-RedisKVStorage   Redis
-TiDBKVStorage    TiDB
 PGKVStorage      Postgres
-OracleKVStorage  Oracle
+RedisKVStorage   Redis
+MongoKVStorage   MogonDB
 ```
 
 * GRAPH_STORAGE supported implement-name
@@ -314,25 +312,19 @@ OracleKVStorage Oracle
 ```
 NetworkXStorage       NetworkX(defualt)
 Neo4JStorage          Neo4J
-MongoGraphStorage     MongoDB
-TiDBGraphStorage      TiDB
-AGEStorage            AGE
-GremlinStorage        Gremlin
 PGGraphStorage        Postgres
-OracleGraphStorage    Oracle
+AGEStorage            AGE
 ```
 
 * VECTOR_STORAGE supported implement-name
 
 ```
 NanoVectorDBStorage     NanoVector(default)
-MilvusVectorDBStorage   Milvus
-ChromaVectorDBStorage   Chroma
-TiDBVectorDBStorage     TiDB
 PGVectorStorage         Postgres
+MilvusVectorDBStorge    Milvus
+ChromaVectorDBStorage   Chroma
 FaissVectorDBStorage    Faiss
 QdrantVectorDBStorage   Qdrant
-OracleVectorDBStorage   Oracle
 MongoVectorDBStorage    MongoDB
 ```
 
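The implementation names listed above are the values the API server reads from `LIGHTRAG_KV_STORAGE`, `LIGHTRAG_GRAPH_STORAGE`, `LIGHTRAG_VECTOR_STORAGE` and `LIGHTRAG_DOC_STATUS_STORAGE` (see the `lightrag/api/config.py` diff further down). A minimal, hypothetical sketch of selecting the PostgreSQL-backed implementations before starting the server; the chosen values are only an example, not a recommendation:

```python
import os

# Hedged sketch: pick storage backends by implementation name.
# The variable names come from lightrag/api/config.py in this diff.
os.environ["LIGHTRAG_KV_STORAGE"] = "PGKVStorage"
os.environ["LIGHTRAG_GRAPH_STORAGE"] = "PGGraphStorage"
os.environ["LIGHTRAG_VECTOR_STORAGE"] = "PGVectorStorage"
os.environ["LIGHTRAG_DOC_STATUS_STORAGE"] = "JsonDocStatusStorage"
```
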
lightrag/api/__init__.py
CHANGED

@@ -1 +1 @@
-__api_version__ = "
+__api_version__ = "0136"
 
lightrag/api/auth.py
CHANGED

@@ -1,9 +1,11 @@
-import os
 from datetime import datetime, timedelta
+
 import jwt
+from dotenv import load_dotenv
 from fastapi import HTTPException, status
 from pydantic import BaseModel
-
+
+from .config import global_args
 
 # use the .env that is inside the current folder
 # allows to use different .env file for each lightrag instance
@@ -20,13 +22,12 @@ class TokenPayload(BaseModel):
 
 class AuthHandler:
     def __init__(self):
-        self.secret =
-        self.algorithm =
-        self.expire_hours =
-        self.guest_expire_hours =
-
+        self.secret = global_args.token_secret
+        self.algorithm = global_args.jwt_algorithm
+        self.expire_hours = global_args.token_expire_hours
+        self.guest_expire_hours = global_args.guest_token_expire_hours
         self.accounts = {}
-        auth_accounts =
+        auth_accounts = global_args.auth_accounts
         if auth_accounts:
             for account in auth_accounts.split(","):
                 username, password = account.split(":", 1)
 
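For context, the `AUTH_ACCOUNTS` value that `AuthHandler.__init__` now takes from `global_args.auth_accounts` is a comma-separated list of `username:password` pairs. A small standalone sketch of the parsing loop shown above; the assignment into the accounts dict is an assumption, since the diff context ends at the `split`:

```python
# Hedged sketch of the AUTH_ACCOUNTS parsing used by AuthHandler above.
def parse_auth_accounts(auth_accounts: str) -> dict[str, str]:
    accounts: dict[str, str] = {}
    if auth_accounts:
        for account in auth_accounts.split(","):
            username, password = account.split(":", 1)
            accounts[username] = password  # assumed destination, not shown in the diff
    return accounts


print(parse_auth_accounts("admin:secret,guest:guest"))
# -> {'admin': 'secret', 'guest': 'guest'}
```
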
lightrag/api/config.py
ADDED

@@ -0,0 +1,335 @@
+"""
+Configs for the LightRAG API.
+"""
+
+import os
+import argparse
+import logging
+from dotenv import load_dotenv
+
+# use the .env that is inside the current folder
+# allows to use different .env file for each lightrag instance
+# the OS environment variables take precedence over the .env file
+load_dotenv(dotenv_path=".env", override=False)
+
+
+class OllamaServerInfos:
+    # Constants for emulated Ollama model information
+    LIGHTRAG_NAME = "lightrag"
+    LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
+    LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
+    LIGHTRAG_SIZE = 7365960935  # it's a dummy value
+    LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
+    LIGHTRAG_DIGEST = "sha256:lightrag"
+
+
+ollama_server_infos = OllamaServerInfos()
+
+
+class DefaultRAGStorageConfig:
+    KV_STORAGE = "JsonKVStorage"
+    VECTOR_STORAGE = "NanoVectorDBStorage"
+    GRAPH_STORAGE = "NetworkXStorage"
+    DOC_STATUS_STORAGE = "JsonDocStatusStorage"
+
+
+def get_default_host(binding_type: str) -> str:
+    default_hosts = {
+        "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
+        "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
+        "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
+        "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
+    }
+    return default_hosts.get(
+        binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
+    )  # fallback to ollama if unknown
+
+
+def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
+    """
+    Get value from environment variable with type conversion
+
+    Args:
+        env_key (str): Environment variable key
+        default (any): Default value if env variable is not set
+        value_type (type): Type to convert the value to
+
+    Returns:
+        any: Converted value from environment or default
+    """
+    value = os.getenv(env_key)
+    if value is None:
+        return default
+
+    if value_type is bool:
+        return value.lower() in ("true", "1", "yes", "t", "on")
+    try:
+        return value_type(value)
+    except ValueError:
+        return default
+
+
+def parse_args() -> argparse.Namespace:
+    """
+    Parse command line arguments with environment variable fallback
+
+    Args:
+        is_uvicorn_mode: Whether running under uvicorn mode
+
+    Returns:
+        argparse.Namespace: Parsed arguments
+    """
+
+    parser = argparse.ArgumentParser(
+        description="LightRAG FastAPI Server with separate working and input directories"
+    )
+
+    # Server configuration
+    parser.add_argument(
+        "--host",
+        default=get_env_value("HOST", "0.0.0.0"),
+        help="Server host (default: from env or 0.0.0.0)",
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=get_env_value("PORT", 9621, int),
+        help="Server port (default: from env or 9621)",
+    )
+
+    # Directory configuration
+    parser.add_argument(
+        "--working-dir",
+        default=get_env_value("WORKING_DIR", "./rag_storage"),
+        help="Working directory for RAG storage (default: from env or ./rag_storage)",
+    )
+    parser.add_argument(
+        "--input-dir",
+        default=get_env_value("INPUT_DIR", "./inputs"),
+        help="Directory containing input documents (default: from env or ./inputs)",
+    )
+
+    def timeout_type(value):
+        if value is None:
+            return 150
+        if value is None or value == "None":
+            return None
+        return int(value)
+
+    parser.add_argument(
+        "--timeout",
+        default=get_env_value("TIMEOUT", None, timeout_type),
+        type=timeout_type,
+        help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
+    )
+
+    # RAG configuration
+    parser.add_argument(
+        "--max-async",
+        type=int,
+        default=get_env_value("MAX_ASYNC", 4, int),
+        help="Maximum async operations (default: from env or 4)",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        type=int,
+        default=get_env_value("MAX_TOKENS", 32768, int),
+        help="Maximum token size (default: from env or 32768)",
+    )
+
+    # Logging configuration
+    parser.add_argument(
+        "--log-level",
+        default=get_env_value("LOG_LEVEL", "INFO"),
+        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
+        help="Logging level (default: from env or INFO)",
+    )
+    parser.add_argument(
+        "--verbose",
+        action="store_true",
+        default=get_env_value("VERBOSE", False, bool),
+        help="Enable verbose debug output(only valid for DEBUG log-level)",
+    )
+
+    parser.add_argument(
+        "--key",
+        type=str,
+        default=get_env_value("LIGHTRAG_API_KEY", None),
+        help="API key for authentication. This protects lightrag server against unauthorized access",
+    )
+
+    # Optional https parameters
+    parser.add_argument(
+        "--ssl",
+        action="store_true",
+        default=get_env_value("SSL", False, bool),
+        help="Enable HTTPS (default: from env or False)",
+    )
+    parser.add_argument(
+        "--ssl-certfile",
+        default=get_env_value("SSL_CERTFILE", None),
+        help="Path to SSL certificate file (required if --ssl is enabled)",
+    )
+    parser.add_argument(
+        "--ssl-keyfile",
+        default=get_env_value("SSL_KEYFILE", None),
+        help="Path to SSL private key file (required if --ssl is enabled)",
+    )
+
+    parser.add_argument(
+        "--history-turns",
+        type=int,
+        default=get_env_value("HISTORY_TURNS", 3, int),
+        help="Number of conversation history turns to include (default: from env or 3)",
+    )
+
+    # Search parameters
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        default=get_env_value("TOP_K", 60, int),
+        help="Number of most similar results to return (default: from env or 60)",
+    )
+    parser.add_argument(
+        "--cosine-threshold",
+        type=float,
+        default=get_env_value("COSINE_THRESHOLD", 0.2, float),
+        help="Cosine similarity threshold (default: from env or 0.4)",
+    )
+
+    # Ollama model name
+    parser.add_argument(
+        "--simulated-model-name",
+        type=str,
+        default=get_env_value(
+            "SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
+        ),
+        help="Number of conversation history turns to include (default: from env or 3)",
+    )
+
+    # Namespace
+    parser.add_argument(
+        "--namespace-prefix",
+        type=str,
+        default=get_env_value("NAMESPACE_PREFIX", ""),
+        help="Prefix of the namespace",
+    )
+
+    parser.add_argument(
+        "--auto-scan-at-startup",
+        action="store_true",
+        default=False,
+        help="Enable automatic scanning when the program starts",
+    )
+
+    # Server workers configuration
+    parser.add_argument(
+        "--workers",
+        type=int,
+        default=get_env_value("WORKERS", 1, int),
+        help="Number of worker processes (default: from env or 1)",
+    )
+
+    # LLM and embedding bindings
+    parser.add_argument(
+        "--llm-binding",
+        type=str,
+        default=get_env_value("LLM_BINDING", "ollama"),
+        choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
+        help="LLM binding type (default: from env or ollama)",
+    )
+    parser.add_argument(
+        "--embedding-binding",
+        type=str,
+        default=get_env_value("EMBEDDING_BINDING", "ollama"),
+        choices=["lollms", "ollama", "openai", "azure_openai"],
+        help="Embedding binding type (default: from env or ollama)",
+    )
+
+    args = parser.parse_args()
+
+    # convert relative path to absolute path
+    args.working_dir = os.path.abspath(args.working_dir)
+    args.input_dir = os.path.abspath(args.input_dir)
+
+    # Inject storage configuration from environment variables
+    args.kv_storage = get_env_value(
+        "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
+    )
+    args.doc_status_storage = get_env_value(
+        "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
+    )
+    args.graph_storage = get_env_value(
+        "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
+    )
+    args.vector_storage = get_env_value(
+        "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
+    )
+
+    # Get MAX_PARALLEL_INSERT from environment
+    args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
+
+    # Handle openai-ollama special case
+    if args.llm_binding == "openai-ollama":
+        args.llm_binding = "openai"
+        args.embedding_binding = "ollama"
+
+    args.llm_binding_host = get_env_value(
+        "LLM_BINDING_HOST", get_default_host(args.llm_binding)
+    )
+    args.embedding_binding_host = get_env_value(
+        "EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
+    )
+    args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
+    args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
+
+    # Inject model configuration
+    args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
+    args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
+    args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
+    args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int)
+
+    # Inject chunk configuration
+    args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
+    args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
+
+    # Inject LLM cache configuration
+    args.enable_llm_cache_for_extract = get_env_value(
+        "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
+    )
+
+    # Inject LLM temperature configuration
+    args.temperature = get_env_value("TEMPERATURE", 0.5, float)
+
+    # Select Document loading tool (DOCLING, DEFAULT)
+    args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
+
+    # Add environment variables that were previously read directly
+    args.cors_origins = get_env_value("CORS_ORIGINS", "*")
+    args.summary_language = get_env_value("SUMMARY_LANGUAGE", "en")
+    args.whitelist_paths = get_env_value("WHITELIST_PATHS", "/health,/api/*")
+
+    # For JWT Auth
+    args.auth_accounts = get_env_value("AUTH_ACCOUNTS", "")
+    args.token_secret = get_env_value("TOKEN_SECRET", "lightrag-jwt-default-secret")
+    args.token_expire_hours = get_env_value("TOKEN_EXPIRE_HOURS", 48, int)
+    args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
+    args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
+
+    ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
+
+    return args
+
+
+def update_uvicorn_mode_config():
+    # If in uvicorn mode and workers > 1, force it to 1 and log warning
+    if global_args.workers > 1:
+        original_workers = global_args.workers
+        global_args.workers = 1
+        # Log warning directly here
+        logging.warning(
+            f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
+        )


+global_args = parse_args()
 
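To illustrate the conversion rules of the `get_env_value` helper introduced above, here is a trimmed standalone copy (same logic, no new API): booleans accept `true/1/yes/t/on`, and a failed cast silently falls back to the default.

```python
import os


def get_env_value(env_key: str, default, value_type: type = str):
    # Trimmed copy of the helper above, for illustration only.
    value = os.getenv(env_key)
    if value is None:
        return default
    if value_type is bool:
        return value.lower() in ("true", "1", "yes", "t", "on")
    try:
        return value_type(value)
    except ValueError:
        return default


os.environ["ENABLE_LLM_CACHE_FOR_EXTRACT"] = "yes"
os.environ["TOP_K"] = "not-a-number"
print(get_env_value("ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool))  # True
print(get_env_value("TOP_K", 60, int))                             # 60 (falls back to default)
print(get_env_value("SUMMARY_LANGUAGE", "en"))                     # "en" (variable not set)
```
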
lightrag/api/lightrag_server.py
CHANGED

@@ -19,11 +19,14 @@ from contextlib import asynccontextmanager
 from dotenv import load_dotenv
 from lightrag.api.utils_api import (
     get_combined_auth_dependency,
-    parse_args,
-    get_default_host,
     display_splash_screen,
     check_env_file,
 )
+from .config import (
+    global_args,
+    update_uvicorn_mode_config,
+    get_default_host,
+)
 import sys
 from lightrag import LightRAG, __version__ as core_version
 from lightrag.api import __api_version__
@@ -52,6 +55,10 @@ from lightrag.api.auth import auth_handler
 # the OS environment variables take precedence over the .env file
 load_dotenv(dotenv_path=".env", override=False)
 
+
+webui_title = os.getenv("WEBUI_TITLE")
+webui_description = os.getenv("WEBUI_DESCRIPTION")
+
 # Initialize config parser
 config = configparser.ConfigParser()
 config.read("config.ini")
@@ -164,10 +171,10 @@ def create_app(args):
     app = FastAPI(**app_kwargs)
 
     def get_cors_origins():
-        """Get allowed origins from
+        """Get allowed origins from global_args
         Returns a list of allowed origins, defaults to ["*"] if not set
         """
-        origins_str =
+        origins_str = global_args.cors_origins
         if origins_str == "*":
             return ["*"]
         return [origin.strip() for origin in origins_str.split(",")]
@@ -315,9 +322,10 @@ def create_app(args):
                 "similarity_threshold": 0.95,
                 "use_llm_check": False,
             },
-            namespace_prefix=args.namespace_prefix,
+            # namespace_prefix=args.namespace_prefix,
             auto_manage_storages_states=False,
             max_parallel_insert=args.max_parallel_insert,
+            addon_params={"language": args.summary_language},
         )
     else:  # azure_openai
         rag = LightRAG(
@@ -345,9 +353,10 @@ def create_app(args):
                 "similarity_threshold": 0.95,
                 "use_llm_check": False,
             },
-            namespace_prefix=args.namespace_prefix,
+            # namespace_prefix=args.namespace_prefix,
             auto_manage_storages_states=False,
             max_parallel_insert=args.max_parallel_insert,
+            addon_params={"language": args.summary_language},
         )
 
     # Add routes
@@ -381,6 +390,8 @@ def create_app(args):
                 "message": "Authentication is disabled. Using guest access.",
                 "core_version": core_version,
                 "api_version": __api_version__,
+                "webui_title": webui_title,
+                "webui_description": webui_description,
             }
 
         return {
@@ -388,6 +399,8 @@ def create_app(args):
             "auth_mode": "enabled",
             "core_version": core_version,
             "api_version": __api_version__,
+            "webui_title": webui_title,
+            "webui_description": webui_description,
         }
 
     @app.post("/login")
@@ -404,6 +417,8 @@ def create_app(args):
                 "message": "Authentication is disabled. Using guest access.",
                 "core_version": core_version,
                 "api_version": __api_version__,
+                "webui_title": webui_title,
+                "webui_description": webui_description,
             }
         username = form_data.username
         if auth_handler.accounts.get(username) != form_data.password:
@@ -421,6 +436,8 @@ def create_app(args):
             "auth_mode": "enabled",
             "core_version": core_version,
             "api_version": __api_version__,
+            "webui_title": webui_title,
+            "webui_description": webui_description,
         }
 
     @app.get("/health", dependencies=[Depends(combined_auth)])
@@ -454,10 +471,12 @@ def create_app(args):
                     "vector_storage": args.vector_storage,
                     "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
                 },
-                "core_version": core_version,
-                "api_version": __api_version__,
                 "auth_mode": auth_mode,
                 "pipeline_busy": pipeline_status.get("busy", False),
+                "core_version": core_version,
+                "api_version": __api_version__,
+                "webui_title": webui_title,
+                "webui_description": webui_description,
             }
         except Exception as e:
             logger.error(f"Error getting health status: {str(e)}")
@@ -490,7 +509,7 @@
 def get_application(args=None):
     """Factory function for creating the FastAPI application"""
     if args is None:
-        args =
+        args = global_args
     return create_app(args)
 
 
@@ -611,30 +630,31 @@ def main():
 
     # Configure logging before parsing args
     configure_logging()
-
-
-    display_splash_screen(args)
+    update_uvicorn_mode_config()
+    display_splash_screen(global_args)
 
     # Create application instance directly instead of using factory function
-    app = create_app(
+    app = create_app(global_args)
 
     # Start Uvicorn in single process mode
     uvicorn_config = {
         "app": app,  # Pass application instance directly instead of string path
-        "host":
-        "port":
+        "host": global_args.host,
+        "port": global_args.port,
         "log_config": None,  # Disable default config
     }
 
-    if
+    if global_args.ssl:
         uvicorn_config.update(
             {
-                "ssl_certfile":
-                "ssl_keyfile":
+                "ssl_certfile": global_args.ssl_certfile,
+                "ssl_keyfile": global_args.ssl_keyfile,
             }
         )
 
-    print(
+    print(
+        f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}"
+    )
     uvicorn.run(**uvicorn_config)
 
 
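Since the server now pulls host, port and SSL settings from `global_args`, a programmatic start mirrors the `main()` shown above. A hedged sketch, assuming a fully configured environment (reachable LLM/embedding binding, valid `.env`); `get_application()` falls back to `global_args` when called without arguments:

```python
# Hedged sketch mirroring main() in the diff above; not an additional entry point.
import uvicorn

from lightrag.api.config import global_args
from lightrag.api.lightrag_server import get_application

app = get_application()  # uses global_args when no args are passed

uvicorn_config = {
    "app": app,
    "host": global_args.host,
    "port": global_args.port,
    "log_config": None,
}
if global_args.ssl:
    uvicorn_config.update(
        {"ssl_certfile": global_args.ssl_certfile, "ssl_keyfile": global_args.ssl_keyfile}
    )

if __name__ == "__main__":
    uvicorn.run(**uvicorn_config)
```
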
lightrag/api/routers/document_routes.py
CHANGED
|
@@ -10,16 +10,14 @@ import traceback
|
|
| 10 |
import pipmaster as pm
|
| 11 |
from datetime import datetime
|
| 12 |
from pathlib import Path
|
| 13 |
-
from typing import Dict, List, Optional, Any
|
| 14 |
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
|
| 15 |
from pydantic import BaseModel, Field, field_validator
|
| 16 |
|
| 17 |
from lightrag import LightRAG
|
| 18 |
from lightrag.base import DocProcessingStatus, DocStatus
|
| 19 |
-
from lightrag.api.utils_api import
|
| 20 |
-
|
| 21 |
-
global_args,
|
| 22 |
-
)
|
| 23 |
|
| 24 |
router = APIRouter(
|
| 25 |
prefix="/documents",
|
|
@@ -30,7 +28,37 @@ router = APIRouter(
|
|
| 30 |
temp_prefix = "__tmp__"
|
| 31 |
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
class InsertTextRequest(BaseModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
text: str = Field(
|
| 35 |
min_length=1,
|
| 36 |
description="The text to insert",
|
|
@@ -41,8 +69,21 @@ class InsertTextRequest(BaseModel):
|
|
| 41 |
def strip_after(cls, text: str) -> str:
|
| 42 |
return text.strip()
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
class InsertTextsRequest(BaseModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
texts: list[str] = Field(
|
| 47 |
min_length=1,
|
| 48 |
description="The texts to insert",
|
|
@@ -53,11 +94,116 @@ class InsertTextsRequest(BaseModel):
|
|
| 53 |
def strip_after(cls, texts: list[str]) -> list[str]:
|
| 54 |
return [text.strip() for text in texts]
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
class InsertResponse(BaseModel):
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
message: str = Field(description="Message describing the operation result")
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
class DocStatusResponse(BaseModel):
|
| 63 |
@staticmethod
|
|
@@ -68,34 +214,82 @@ class DocStatusResponse(BaseModel):
|
|
| 68 |
return dt
|
| 69 |
return dt.isoformat()
|
| 70 |
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
Attributes:
|
| 74 |
-
|
| 75 |
-
content_summary: Summary of document content
|
| 76 |
-
content_length: Length of document content
|
| 77 |
-
status: Current processing status
|
| 78 |
-
created_at: Creation timestamp (ISO format string)
|
| 79 |
-
updated_at: Last update timestamp (ISO format string)
|
| 80 |
-
chunks_count: Number of chunks (optional)
|
| 81 |
-
error: Error message if any (optional)
|
| 82 |
-
metadata: Additional metadata (optional)
|
| 83 |
"""
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
created_at: str
|
| 90 |
-
updated_at: str
|
| 91 |
-
chunks_count: Optional[int] = None
|
| 92 |
-
error: Optional[str] = None
|
| 93 |
-
metadata: Optional[dict[str, Any]] = None
|
| 94 |
-
file_path: str
|
| 95 |
-
|
| 96 |
|
| 97 |
-
class
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
|
| 101 |
class PipelineStatusResponse(BaseModel):
|
|
@@ -276,7 +470,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
|
|
| 276 |
)
|
| 277 |
return False
|
| 278 |
case ".pdf":
|
| 279 |
-
if global_args
|
| 280 |
if not pm.is_installed("docling"): # type: ignore
|
| 281 |
pm.install("docling")
|
| 282 |
from docling.document_converter import DocumentConverter # type: ignore
|
|
@@ -295,7 +489,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
|
|
| 295 |
for page in reader.pages:
|
| 296 |
content += page.extract_text() + "\n"
|
| 297 |
case ".docx":
|
| 298 |
-
if global_args
|
| 299 |
if not pm.is_installed("docling"): # type: ignore
|
| 300 |
pm.install("docling")
|
| 301 |
from docling.document_converter import DocumentConverter # type: ignore
|
|
@@ -315,7 +509,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
|
|
| 315 |
[paragraph.text for paragraph in doc.paragraphs]
|
| 316 |
)
|
| 317 |
case ".pptx":
|
| 318 |
-
if global_args
|
| 319 |
if not pm.is_installed("docling"): # type: ignore
|
| 320 |
pm.install("docling")
|
| 321 |
from docling.document_converter import DocumentConverter # type: ignore
|
|
@@ -336,7 +530,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
|
|
| 336 |
if hasattr(shape, "text"):
|
| 337 |
content += shape.text + "\n"
|
| 338 |
case ".xlsx":
|
| 339 |
-
if global_args
|
| 340 |
if not pm.is_installed("docling"): # type: ignore
|
| 341 |
pm.install("docling")
|
| 342 |
from docling.document_converter import DocumentConverter # type: ignore
|
|
@@ -443,6 +637,7 @@ async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
|
|
| 443 |
await rag.apipeline_process_enqueue_documents()
|
| 444 |
|
| 445 |
|
|
|
|
| 446 |
async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
|
| 447 |
"""Save the uploaded file to a temporary location
|
| 448 |
|
|
@@ -476,8 +671,8 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
|
|
| 476 |
if not new_files:
|
| 477 |
return
|
| 478 |
|
| 479 |
-
# Get MAX_PARALLEL_INSERT from global_args
|
| 480 |
-
max_parallel = global_args
|
| 481 |
# Calculate batch size as 2 * MAX_PARALLEL_INSERT
|
| 482 |
batch_size = 2 * max_parallel
|
| 483 |
|
|
@@ -509,7 +704,9 @@ def create_document_routes(
|
|
| 509 |
# Create combined auth dependency for document routes
|
| 510 |
combined_auth = get_combined_auth_dependency(api_key)
|
| 511 |
|
| 512 |
-
@router.post(
|
|
|
|
|
|
|
| 513 |
async def scan_for_new_documents(background_tasks: BackgroundTasks):
|
| 514 |
"""
|
| 515 |
Trigger the scanning process for new documents.
|
|
@@ -519,13 +716,18 @@ def create_document_routes(
|
|
| 519 |
that fact.
|
| 520 |
|
| 521 |
Returns:
|
| 522 |
-
|
| 523 |
"""
|
| 524 |
# Start the scanning process in the background
|
| 525 |
background_tasks.add_task(run_scanning_process, rag, doc_manager)
|
| 526 |
-
return
|
|
|
|
|
|
|
|
|
|
| 527 |
|
| 528 |
-
@router.post(
|
|
|
|
|
|
|
| 529 |
async def upload_to_input_dir(
|
| 530 |
background_tasks: BackgroundTasks, file: UploadFile = File(...)
|
| 531 |
):
|
|
@@ -645,6 +847,7 @@ def create_document_routes(
|
|
| 645 |
logger.error(traceback.format_exc())
|
| 646 |
raise HTTPException(status_code=500, detail=str(e))
|
| 647 |
|
|
|
|
| 648 |
@router.post(
|
| 649 |
"/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
| 650 |
)
|
|
@@ -688,6 +891,7 @@ def create_document_routes(
|
|
| 688 |
logger.error(traceback.format_exc())
|
| 689 |
raise HTTPException(status_code=500, detail=str(e))
|
| 690 |
|
|
|
|
| 691 |
@router.post(
|
| 692 |
"/file_batch",
|
| 693 |
response_model=InsertResponse,
|
|
@@ -752,32 +956,186 @@ def create_document_routes(
|
|
| 752 |
raise HTTPException(status_code=500, detail=str(e))
|
| 753 |
|
| 754 |
@router.delete(
|
| 755 |
-
"", response_model=
|
| 756 |
)
|
| 757 |
async def clear_documents():
|
| 758 |
"""
|
| 759 |
Clear all documents from the RAG system.
|
| 760 |
|
| 761 |
-
This endpoint deletes all
|
| 762 |
-
|
|
|
|
| 763 |
|
| 764 |
Returns:
|
| 765 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 766 |
|
| 767 |
Raises:
|
| 768 |
-
HTTPException:
|
|
|
|
| 769 |
"""
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 776 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 777 |
except Exception as e:
|
| 778 |
-
|
|
|
|
| 779 |
logger.error(traceback.format_exc())
|
|
|
|
|
|
|
| 780 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 781 |
|
| 782 |
@router.get(
|
| 783 |
"/pipeline_status",
|
|
@@ -850,7 +1208,9 @@ def create_document_routes(
|
|
| 850 |
logger.error(traceback.format_exc())
|
| 851 |
raise HTTPException(status_code=500, detail=str(e))
|
| 852 |
|
| 853 |
-
@router.get(
|
|
|
|
|
|
|
| 854 |
async def documents() -> DocsStatusesResponse:
|
| 855 |
"""
|
| 856 |
Get the status of all documents in the system.
|
|
@@ -908,4 +1268,57 @@ def create_document_routes(
|
|
| 908 |
logger.error(traceback.format_exc())
|
| 909 |
raise HTTPException(status_code=500, detail=str(e))
|
| 910 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 911 |
return router
|
|
|
|
| 10 |
import pipmaster as pm
|
| 11 |
from datetime import datetime
|
| 12 |
from pathlib import Path
|
| 13 |
+
from typing import Dict, List, Optional, Any, Literal
|
| 14 |
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
|
| 15 |
from pydantic import BaseModel, Field, field_validator
|
| 16 |
|
| 17 |
from lightrag import LightRAG
|
| 18 |
from lightrag.base import DocProcessingStatus, DocStatus
|
| 19 |
+
from lightrag.api.utils_api import get_combined_auth_dependency
|
| 20 |
+
from ..config import global_args
|
|
|
|
|
|
|
| 21 |
|
| 22 |
router = APIRouter(
|
| 23 |
prefix="/documents",
|
|
|
|
| 28 |
temp_prefix = "__tmp__"
|
| 29 |
|
| 30 |
|
| 31 |
+
class ScanResponse(BaseModel):
|
| 32 |
+
"""Response model for document scanning operation
|
| 33 |
+
|
| 34 |
+
Attributes:
|
| 35 |
+
status: Status of the scanning operation
|
| 36 |
+
message: Optional message with additional details
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
status: Literal["scanning_started"] = Field(
|
| 40 |
+
description="Status of the scanning operation"
|
| 41 |
+
)
|
| 42 |
+
message: Optional[str] = Field(
|
| 43 |
+
default=None, description="Additional details about the scanning operation"
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
class Config:
|
| 47 |
+
json_schema_extra = {
|
| 48 |
+
"example": {
|
| 49 |
+
"status": "scanning_started",
|
| 50 |
+
"message": "Scanning process has been initiated in the background",
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
class InsertTextRequest(BaseModel):
|
| 56 |
+
"""Request model for inserting a single text document
|
| 57 |
+
|
| 58 |
+
Attributes:
|
| 59 |
+
text: The text content to be inserted into the RAG system
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
text: str = Field(
|
| 63 |
min_length=1,
|
| 64 |
description="The text to insert",
|
|
|
|
| 69 |
def strip_after(cls, text: str) -> str:
|
| 70 |
return text.strip()
|
| 71 |
|
| 72 |
+
class Config:
|
| 73 |
+
json_schema_extra = {
|
| 74 |
+
"example": {
|
| 75 |
+
"text": "This is a sample text to be inserted into the RAG system."
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
|
| 80 |
class InsertTextsRequest(BaseModel):
|
| 81 |
+
"""Request model for inserting multiple text documents
|
| 82 |
+
|
| 83 |
+
Attributes:
|
| 84 |
+
texts: List of text contents to be inserted into the RAG system
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
texts: list[str] = Field(
|
| 88 |
min_length=1,
|
| 89 |
description="The texts to insert",
|
|
|
|
| 94 |
def strip_after(cls, texts: list[str]) -> list[str]:
|
| 95 |
return [text.strip() for text in texts]
|
| 96 |
|
| 97 |
+
class Config:
|
| 98 |
+
json_schema_extra = {
|
| 99 |
+
"example": {
|
| 100 |
+
"texts": [
|
| 101 |
+
"This is the first text to be inserted.",
|
| 102 |
+
"This is the second text to be inserted.",
|
| 103 |
+
]
|
| 104 |
+
}
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
|
| 108 |
class InsertResponse(BaseModel):
|
| 109 |
+
"""Response model for document insertion operations
|
| 110 |
+
|
| 111 |
+
Attributes:
|
| 112 |
+
status: Status of the operation (success, duplicated, partial_success, failure)
|
| 113 |
+
message: Detailed message describing the operation result
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
status: Literal["success", "duplicated", "partial_success", "failure"] = Field(
|
| 117 |
+
description="Status of the operation"
|
| 118 |
+
)
|
| 119 |
message: str = Field(description="Message describing the operation result")
|
| 120 |
|
| 121 |
+
class Config:
|
| 122 |
+
json_schema_extra = {
|
| 123 |
+
"example": {
|
| 124 |
+
"status": "success",
|
| 125 |
+
"message": "File 'document.pdf' uploaded successfully. Processing will continue in background.",
|
| 126 |
+
}
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class ClearDocumentsResponse(BaseModel):
|
| 131 |
+
"""Response model for document clearing operation
|
| 132 |
+
|
| 133 |
+
Attributes:
|
| 134 |
+
status: Status of the clear operation
|
| 135 |
+
message: Detailed message describing the operation result
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
status: Literal["success", "partial_success", "busy", "fail"] = Field(
|
| 139 |
+
description="Status of the clear operation"
|
| 140 |
+
)
|
| 141 |
+
message: str = Field(description="Message describing the operation result")
|
| 142 |
+
|
| 143 |
+
class Config:
|
| 144 |
+
json_schema_extra = {
|
| 145 |
+
"example": {
|
| 146 |
+
"status": "success",
|
| 147 |
+
"message": "All documents cleared successfully. Deleted 15 files.",
|
| 148 |
+
}
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class ClearCacheRequest(BaseModel):
|
| 153 |
+
"""Request model for clearing cache
|
| 154 |
+
|
| 155 |
+
Attributes:
|
| 156 |
+
modes: Optional list of cache modes to clear
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
modes: Optional[
|
| 160 |
+
List[Literal["default", "naive", "local", "global", "hybrid", "mix"]]
|
| 161 |
+
] = Field(
|
| 162 |
+
default=None,
|
| 163 |
+
description="Modes of cache to clear. If None, clears all cache.",
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
class Config:
|
| 167 |
+
json_schema_extra = {"example": {"modes": ["default", "naive"]}}
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class ClearCacheResponse(BaseModel):
|
| 171 |
+
"""Response model for cache clearing operation
|
| 172 |
+
|
| 173 |
+
Attributes:
|
| 174 |
+
status: Status of the clear operation
|
| 175 |
+
message: Detailed message describing the operation result
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
status: Literal["success", "fail"] = Field(
|
| 179 |
+
description="Status of the clear operation"
|
| 180 |
+
)
|
| 181 |
+
message: str = Field(description="Message describing the operation result")
|
| 182 |
+
|
| 183 |
+
class Config:
|
| 184 |
+
json_schema_extra = {
|
| 185 |
+
"example": {
|
| 186 |
+
"status": "success",
|
| 187 |
+
"message": "Successfully cleared cache for modes: ['default', 'naive']",
|
| 188 |
+
}
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
"""Response model for document status
|
| 193 |
+
|
| 194 |
+
Attributes:
|
| 195 |
+
id: Document identifier
|
| 196 |
+
content_summary: Summary of document content
|
| 197 |
+
content_length: Length of document content
|
| 198 |
+
status: Current processing status
|
| 199 |
+
created_at: Creation timestamp (ISO format string)
|
| 200 |
+
updated_at: Last update timestamp (ISO format string)
|
| 201 |
+
chunks_count: Number of chunks (optional)
|
| 202 |
+
error: Error message if any (optional)
|
| 203 |
+
metadata: Additional metadata (optional)
|
| 204 |
+
file_path: Path to the document file
|
| 205 |
+
"""
|
| 206 |
+
|
| 207 |
|
| 208 |
class DocStatusResponse(BaseModel):
|
| 209 |
@staticmethod
|
|
|
|
| 214 |
return dt
|
| 215 |
return dt.isoformat()
|
| 216 |
|
| 217 |
+
id: str = Field(description="Document identifier")
|
| 218 |
+
content_summary: str = Field(description="Summary of document content")
|
| 219 |
+
content_length: int = Field(description="Length of document content in characters")
|
| 220 |
+
status: DocStatus = Field(description="Current processing status")
|
| 221 |
+
created_at: str = Field(description="Creation timestamp (ISO format string)")
|
| 222 |
+
updated_at: str = Field(description="Last update timestamp (ISO format string)")
|
| 223 |
+
chunks_count: Optional[int] = Field(
|
| 224 |
+
default=None, description="Number of chunks the document was split into"
|
| 225 |
+
)
|
| 226 |
+
error: Optional[str] = Field(
|
| 227 |
+
default=None, description="Error message if processing failed"
|
| 228 |
+
)
|
| 229 |
+
metadata: Optional[dict[str, Any]] = Field(
|
| 230 |
+
default=None, description="Additional metadata about the document"
|
| 231 |
+
)
|
| 232 |
+
file_path: str = Field(description="Path to the document file")
|
| 233 |
+
|
| 234 |
+
class Config:
|
| 235 |
+
json_schema_extra = {
|
| 236 |
+
"example": {
|
| 237 |
+
"id": "doc_123456",
|
| 238 |
+
"content_summary": "Research paper on machine learning",
|
| 239 |
+
"content_length": 15240,
|
| 240 |
+
"status": "PROCESSED",
|
| 241 |
+
"created_at": "2025-03-31T12:34:56",
|
| 242 |
+
"updated_at": "2025-03-31T12:35:30",
|
| 243 |
+
"chunks_count": 12,
|
| 244 |
+
"error": None,
|
| 245 |
+
"metadata": {"author": "John Doe", "year": 2025},
|
| 246 |
+
"file_path": "research_paper.pdf",
|
| 247 |
+
}
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class DocsStatusesResponse(BaseModel):
|
| 252 |
+
"""Response model for document statuses
|
| 253 |
|
| 254 |
Attributes:
|
| 255 |
+
statuses: Dictionary mapping document status to lists of document status responses
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
"""
|
| 257 |
|
| 258 |
+
statuses: Dict[DocStatus, List[DocStatusResponse]] = Field(
|
| 259 |
+
default_factory=dict,
|
| 260 |
+
description="Dictionary mapping document status to lists of document status responses",
|
| 261 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
+
class Config:
|
| 264 |
+
json_schema_extra = {
|
| 265 |
+
"example": {
|
| 266 |
+
"statuses": {
|
| 267 |
+
"PENDING": [
|
| 268 |
+
{
|
| 269 |
+
"id": "doc_123",
|
| 270 |
+
"content_summary": "Pending document",
|
| 271 |
+
"content_length": 5000,
|
| 272 |
+
"status": "PENDING",
|
| 273 |
+
"created_at": "2025-03-31T10:00:00",
|
| 274 |
+
"updated_at": "2025-03-31T10:00:00",
|
| 275 |
+
"file_path": "pending_doc.pdf",
|
| 276 |
+
}
|
| 277 |
+
],
|
| 278 |
+
"PROCESSED": [
|
| 279 |
+
{
|
| 280 |
+
"id": "doc_456",
|
| 281 |
+
"content_summary": "Processed document",
|
| 282 |
+
"content_length": 8000,
|
| 283 |
+
"status": "PROCESSED",
|
| 284 |
+
"created_at": "2025-03-31T09:00:00",
|
| 285 |
+
"updated_at": "2025-03-31T09:05:00",
|
| 286 |
+
"chunks_count": 8,
|
| 287 |
+
"file_path": "processed_doc.pdf",
|
| 288 |
+
}
|
| 289 |
+
],
|
| 290 |
+
}
|
| 291 |
+
}
|
| 292 |
+
}
|
| 293 |
|
| 294 |
|
| 295 |
class PipelineStatusResponse(BaseModel):
|
|
|
|
| 470 |
)
|
| 471 |
return False
|
| 472 |
case ".pdf":
|
| 473 |
+
if global_args.document_loading_engine == "DOCLING":
|
| 474 |
if not pm.is_installed("docling"): # type: ignore
|
| 475 |
pm.install("docling")
|
| 476 |
from docling.document_converter import DocumentConverter # type: ignore
|
|
|
|
| 489 |
for page in reader.pages:
|
| 490 |
content += page.extract_text() + "\n"
|
| 491 |
case ".docx":
|
| 492 |
+
if global_args.document_loading_engine == "DOCLING":
|
| 493 |
if not pm.is_installed("docling"): # type: ignore
|
| 494 |
pm.install("docling")
|
| 495 |
from docling.document_converter import DocumentConverter # type: ignore
|
|
|
|
| 509 |
[paragraph.text for paragraph in doc.paragraphs]
|
| 510 |
)
|
| 511 |
case ".pptx":
|
| 512 |
+
if global_args.document_loading_engine == "DOCLING":
|
| 513 |
if not pm.is_installed("docling"): # type: ignore
|
| 514 |
pm.install("docling")
|
| 515 |
from docling.document_converter import DocumentConverter # type: ignore
|
|
|
|
| 530 |
if hasattr(shape, "text"):
|
| 531 |
content += shape.text + "\n"
|
| 532 |
case ".xlsx":
|
| 533 |
+
if global_args.document_loading_engine == "DOCLING":
|
| 534 |
if not pm.is_installed("docling"): # type: ignore
|
| 535 |
pm.install("docling")
|
| 536 |
from docling.document_converter import DocumentConverter # type: ignore
|
|
|
|
| 637 |
await rag.apipeline_process_enqueue_documents()
|
| 638 |
|
| 639 |
|
| 640 |
+
# TODO: deprecate after /insert_file is removed
|
| 641 |
async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
|
| 642 |
"""Save the uploaded file to a temporary location
|
| 643 |
|
|
|
|
| 671 |
if not new_files:
|
| 672 |
return
|
| 673 |
|
| 674 |
+
# Get MAX_PARALLEL_INSERT from global_args
|
| 675 |
+
max_parallel = global_args.max_parallel_insert
|
| 676 |
# Calculate batch size as 2 * MAX_PARALLEL_INSERT
|
| 677 |
batch_size = 2 * max_parallel
|
| 678 |
|
|
|
|
| 704 |
# Create combined auth dependency for document routes
|
| 705 |
combined_auth = get_combined_auth_dependency(api_key)
|
| 706 |
|
| 707 |
+
@router.post(
|
| 708 |
+
"/scan", response_model=ScanResponse, dependencies=[Depends(combined_auth)]
|
| 709 |
+
)
|
| 710 |
async def scan_for_new_documents(background_tasks: BackgroundTasks):
|
| 711 |
"""
|
| 712 |
Trigger the scanning process for new documents.
|
|
|
|
| 716 |
that fact.
|
| 717 |
|
| 718 |
Returns:
|
| 719 |
+
ScanResponse: A response object containing the scanning status
|
| 720 |
"""
|
| 721 |
# Start the scanning process in the background
|
| 722 |
background_tasks.add_task(run_scanning_process, rag, doc_manager)
|
| 723 |
+
return ScanResponse(
|
| 724 |
+
status="scanning_started",
|
| 725 |
+
message="Scanning process has been initiated in the background",
|
| 726 |
+
)
|
| 727 |
|
| 728 |
+
@router.post(
|
| 729 |
+
"/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
| 730 |
+
)
|
| 731 |
async def upload_to_input_dir(
|
| 732 |
background_tasks: BackgroundTasks, file: UploadFile = File(...)
|
| 733 |
):
|
|
|
|
| 847 |
logger.error(traceback.format_exc())
|
| 848 |
raise HTTPException(status_code=500, detail=str(e))
|
| 849 |
|
| 850 |
+
# TODO: deprecated, use /upload instead
|
| 851 |
@router.post(
|
| 852 |
"/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
|
| 853 |
)
|
|
|
|
| 891 |
logger.error(traceback.format_exc())
|
| 892 |
raise HTTPException(status_code=500, detail=str(e))
|
| 893 |
|
| 894 |
+
# TODO: deprecated, use /upload instead
|
| 895 |
@router.post(
|
| 896 |
"/file_batch",
|
| 897 |
response_model=InsertResponse,
|
|
|
|
| 956 |
raise HTTPException(status_code=500, detail=str(e))
|
| 957 |
|
| 958 |
@router.delete(
|
| 959 |
+
"", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)]
|
| 960 |
)
|
| 961 |
async def clear_documents():
|
| 962 |
"""
|
| 963 |
Clear all documents from the RAG system.
|
| 964 |
|
| 965 |
+
This endpoint deletes all documents, entities, relationships, and files from the system.
|
| 966 |
+
It uses the storage drop methods to properly clean up all data and removes all files
|
| 967 |
+
from the input directory.
|
| 968 |
|
| 969 |
Returns:
|
| 970 |
+
ClearDocumentsResponse: A response object containing the status and message.
|
| 971 |
+
- status="success": All documents and files were successfully cleared.
|
| 972 |
+
- status="partial_success": Document clear job exit with some errors.
|
| 973 |
+
- status="busy": Operation could not be completed because the pipeline is busy.
|
| 974 |
+
- status="fail": All storage drop operations failed, with message
|
| 975 |
+
- message: Detailed information about the operation results, including counts
|
| 976 |
+
of deleted files and any errors encountered.
|
| 977 |
|
| 978 |
Raises:
|
| 979 |
+
HTTPException: Raised when a serious error occurs during the clearing process,
|
| 980 |
+
with status code 500 and error details in the detail field.
|
| 981 |
"""
|
| 982 |
+
from lightrag.kg.shared_storage import (
|
| 983 |
+
get_namespace_data,
|
| 984 |
+
get_pipeline_status_lock,
|
| 985 |
+
)
|
| 986 |
+
|
| 987 |
+
# Get pipeline status and lock
|
| 988 |
+
pipeline_status = await get_namespace_data("pipeline_status")
|
| 989 |
+
pipeline_status_lock = get_pipeline_status_lock()
|
| 990 |
+
|
| 991 |
+
# Check and set status with lock
|
| 992 |
+
async with pipeline_status_lock:
|
| 993 |
+
if pipeline_status.get("busy", False):
|
| 994 |
+
return ClearDocumentsResponse(
|
| 995 |
+
status="busy",
|
| 996 |
+
message="Cannot clear documents while pipeline is busy",
|
| 997 |
+
)
|
| 998 |
+
# Set busy to true
|
| 999 |
+
pipeline_status.update(
|
| 1000 |
+
{
|
| 1001 |
+
"busy": True,
|
| 1002 |
+
"job_name": "Clearing Documents",
|
| 1003 |
+
"job_start": datetime.now().isoformat(),
|
| 1004 |
+
"docs": 0,
|
| 1005 |
+
"batchs": 0,
|
| 1006 |
+
"cur_batch": 0,
|
| 1007 |
+
"request_pending": False, # Clear any previous request
|
| 1008 |
+
"latest_message": "Starting document clearing process",
|
| 1009 |
+
}
|
| 1010 |
)
|
| 1011 |
+
# Clear history_messages in place so it remains a shared list object
|
| 1012 |
+
del pipeline_status["history_messages"][:]
|
| 1013 |
+
pipeline_status["history_messages"].append(
|
| 1014 |
+
"Starting document clearing process"
|
| 1015 |
+
)
|
| 1016 |
+
|
| 1017 |
+
try:
|
| 1018 |
+
# Use drop method to clear all data
|
| 1019 |
+
drop_tasks = []
|
| 1020 |
+
storages = [
|
| 1021 |
+
rag.text_chunks,
|
| 1022 |
+
rag.full_docs,
|
| 1023 |
+
rag.entities_vdb,
|
| 1024 |
+
rag.relationships_vdb,
|
| 1025 |
+
rag.chunks_vdb,
|
| 1026 |
+
rag.chunk_entity_relation_graph,
|
| 1027 |
+
rag.doc_status,
|
| 1028 |
+
]
|
| 1029 |
+
|
| 1030 |
+
# Log storage drop start
|
| 1031 |
+
if "history_messages" in pipeline_status:
|
| 1032 |
+
pipeline_status["history_messages"].append(
|
| 1033 |
+
"Starting to drop storage components"
|
| 1034 |
+
)
|
| 1035 |
+
|
| 1036 |
+
for storage in storages:
|
| 1037 |
+
if storage is not None:
|
| 1038 |
+
drop_tasks.append(storage.drop())
|
| 1039 |
+
|
| 1040 |
+
# Wait for all drop tasks to complete
|
| 1041 |
+
drop_results = await asyncio.gather(*drop_tasks, return_exceptions=True)
|
| 1042 |
+
|
| 1043 |
+
# Check for errors and log results
|
| 1044 |
+
errors = []
|
| 1045 |
+
storage_success_count = 0
|
| 1046 |
+
storage_error_count = 0
|
| 1047 |
+
|
| 1048 |
+
for i, result in enumerate(drop_results):
|
| 1049 |
+
storage_name = storages[i].__class__.__name__
|
| 1050 |
+
if isinstance(result, Exception):
|
| 1051 |
+
error_msg = f"Error dropping {storage_name}: {str(result)}"
|
| 1052 |
+
errors.append(error_msg)
|
| 1053 |
+
logger.error(error_msg)
|
| 1054 |
+
storage_error_count += 1
|
| 1055 |
+
else:
|
| 1056 |
+
logger.info(f"Successfully dropped {storage_name}")
|
| 1057 |
+
storage_success_count += 1
|
| 1058 |
+
|
| 1059 |
+
# Log storage drop results
|
| 1060 |
+
if "history_messages" in pipeline_status:
|
| 1061 |
+
if storage_error_count > 0:
|
| 1062 |
+
pipeline_status["history_messages"].append(
|
| 1063 |
+
f"Dropped {storage_success_count} storage components with {storage_error_count} errors"
|
| 1064 |
+
)
|
| 1065 |
+
else:
|
| 1066 |
+
pipeline_status["history_messages"].append(
|
| 1067 |
+
f"Successfully dropped all {storage_success_count} storage components"
|
| 1068 |
+
)
|
| 1069 |
+
|
| 1070 |
+
# If all storage operations failed, return error status and don't proceed with file deletion
|
| 1071 |
+
if storage_success_count == 0 and storage_error_count > 0:
|
| 1072 |
+
error_message = "All storage drop operations failed. Aborting document clearing process."
|
| 1073 |
+
logger.error(error_message)
|
| 1074 |
+
if "history_messages" in pipeline_status:
|
| 1075 |
+
pipeline_status["history_messages"].append(error_message)
|
| 1076 |
+
return ClearDocumentsResponse(status="fail", message=error_message)
|
| 1077 |
+
|
| 1078 |
+
# Log file deletion start
|
| 1079 |
+
if "history_messages" in pipeline_status:
|
| 1080 |
+
pipeline_status["history_messages"].append(
|
| 1081 |
+
"Starting to delete files in input directory"
|
| 1082 |
+
)
|
| 1083 |
+
|
| 1084 |
+
# Delete all files in input_dir
|
| 1085 |
+
deleted_files_count = 0
|
| 1086 |
+
file_errors_count = 0
|
| 1087 |
+
|
| 1088 |
+
for file_path in doc_manager.input_dir.glob("**/*"):
|
| 1089 |
+
if file_path.is_file():
|
| 1090 |
+
try:
|
| 1091 |
+
file_path.unlink()
|
| 1092 |
+
deleted_files_count += 1
|
| 1093 |
+
except Exception as e:
|
| 1094 |
+
logger.error(f"Error deleting file {file_path}: {str(e)}")
|
| 1095 |
+
file_errors_count += 1
|
| 1096 |
+
|
| 1097 |
+
# Log file deletion results
|
| 1098 |
+
if "history_messages" in pipeline_status:
|
| 1099 |
+
if file_errors_count > 0:
|
| 1100 |
+
pipeline_status["history_messages"].append(
|
| 1101 |
+
f"Deleted {deleted_files_count} files with {file_errors_count} errors"
|
| 1102 |
+
)
|
| 1103 |
+
errors.append(f"Failed to delete {file_errors_count} files")
|
| 1104 |
+
else:
|
| 1105 |
+
pipeline_status["history_messages"].append(
|
| 1106 |
+
f"Successfully deleted {deleted_files_count} files"
|
| 1107 |
+
)
|
| 1108 |
+
|
| 1109 |
+
# Prepare final result message
|
| 1110 |
+
final_message = ""
|
| 1111 |
+
if errors:
|
| 1112 |
+
final_message = f"Cleared documents with some errors. Deleted {deleted_files_count} files."
|
| 1113 |
+
status = "partial_success"
|
| 1114 |
+
else:
|
| 1115 |
+
final_message = f"All documents cleared successfully. Deleted {deleted_files_count} files."
|
| 1116 |
+
status = "success"
|
| 1117 |
+
|
| 1118 |
+
# Log final result
|
| 1119 |
+
if "history_messages" in pipeline_status:
|
| 1120 |
+
pipeline_status["history_messages"].append(final_message)
|
| 1121 |
+
|
| 1122 |
+
# Return response based on results
|
| 1123 |
+
return ClearDocumentsResponse(status=status, message=final_message)
|
| 1124 |
except Exception as e:
|
| 1125 |
+
error_msg = f"Error clearing documents: {str(e)}"
|
| 1126 |
+
logger.error(error_msg)
|
| 1127 |
logger.error(traceback.format_exc())
|
| 1128 |
+
if "history_messages" in pipeline_status:
|
| 1129 |
+
pipeline_status["history_messages"].append(error_msg)
|
| 1130 |
raise HTTPException(status_code=500, detail=str(e))
|
| 1131 |
+
finally:
|
| 1132 |
+
# Reset busy status after completion
|
| 1133 |
+
async with pipeline_status_lock:
|
| 1134 |
+
pipeline_status["busy"] = False
|
| 1135 |
+
completion_msg = "Document clearing process completed"
|
| 1136 |
+
pipeline_status["latest_message"] = completion_msg
|
| 1137 |
+
if "history_messages" in pipeline_status:
|
| 1138 |
+
pipeline_status["history_messages"].append(completion_msg)
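A minimal client-side sketch of calling this clear endpoint and branching on the documented statuses; the host, `/documents` prefix, and API key header are assumptions:

```python
import requests

resp = requests.delete(
    "http://localhost:9621/documents",  # assumed mount point of the documents router
    headers={"X-API-Key": "your-api-key"},
    timeout=60,
)
result = resp.json()
if result["status"] == "success":
    print("All documents cleared:", result["message"])
elif result["status"] == "partial_success":
    print("Cleared with some errors:", result["message"])
elif result["status"] == "busy":
    print("Pipeline busy, try again later")
else:  # "fail"
    print("Clearing failed:", result["message"])
```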
|
| 1139 |
|
| 1140 |
@router.get(
|
| 1141 |
"/pipeline_status",
|
|
|
|
| 1208 |
logger.error(traceback.format_exc())
|
| 1209 |
raise HTTPException(status_code=500, detail=str(e))
|
| 1210 |
|
| 1211 |
+
@router.get(
|
| 1212 |
+
"", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)]
|
| 1213 |
+
)
|
| 1214 |
async def documents() -> DocsStatusesResponse:
|
| 1215 |
"""
|
| 1216 |
Get the status of all documents in the system.
|
|
|
|
| 1268 |
logger.error(traceback.format_exc())
|
| 1269 |
raise HTTPException(status_code=500, detail=str(e))
|
| 1270 |
|
| 1271 |
+
@router.post(
|
| 1272 |
+
"/clear_cache",
|
| 1273 |
+
response_model=ClearCacheResponse,
|
| 1274 |
+
dependencies=[Depends(combined_auth)],
|
| 1275 |
+
)
|
| 1276 |
+
async def clear_cache(request: ClearCacheRequest):
|
| 1277 |
+
"""
|
| 1278 |
+
Clear cache data from the LLM response cache storage.
|
| 1279 |
+
|
| 1280 |
+
This endpoint allows clearing specific modes of cache or all cache if no modes are specified.
|
| 1281 |
+
Valid modes include: "default", "naive", "local", "global", "hybrid", "mix".
|
| 1282 |
+
- "default" represents extraction cache.
|
| 1283 |
+
- Other modes correspond to different query modes.
|
| 1284 |
+
|
| 1285 |
+
Args:
|
| 1286 |
+
request (ClearCacheRequest): The request body containing optional modes to clear.
|
| 1287 |
+
|
| 1288 |
+
Returns:
|
| 1289 |
+
ClearCacheResponse: A response object containing the status and message.
|
| 1290 |
+
|
| 1291 |
+
Raises:
|
| 1292 |
+
HTTPException: If an error occurs during cache clearing (400 for invalid modes, 500 for other errors).
|
| 1293 |
+
"""
|
| 1294 |
+
try:
|
| 1295 |
+
# Validate modes if provided
|
| 1296 |
+
valid_modes = ["default", "naive", "local", "global", "hybrid", "mix"]
|
| 1297 |
+
if request.modes and not all(mode in valid_modes for mode in request.modes):
|
| 1298 |
+
invalid_modes = [
|
| 1299 |
+
mode for mode in request.modes if mode not in valid_modes
|
| 1300 |
+
]
|
| 1301 |
+
raise HTTPException(
|
| 1302 |
+
status_code=400,
|
| 1303 |
+
detail=f"Invalid mode(s): {invalid_modes}. Valid modes are: {valid_modes}",
|
| 1304 |
+
)
|
| 1305 |
+
|
| 1306 |
+
# Call the aclear_cache method
|
| 1307 |
+
await rag.aclear_cache(request.modes)
|
| 1308 |
+
|
| 1309 |
+
# Prepare success message
|
| 1310 |
+
if request.modes:
|
| 1311 |
+
message = f"Successfully cleared cache for modes: {request.modes}"
|
| 1312 |
+
else:
|
| 1313 |
+
message = "Successfully cleared all cache"
|
| 1314 |
+
|
| 1315 |
+
return ClearCacheResponse(status="success", message=message)
|
| 1316 |
+
except HTTPException:
|
| 1317 |
+
# Re-raise HTTP exceptions
|
| 1318 |
+
raise
|
| 1319 |
+
except Exception as e:
|
| 1320 |
+
logger.error(f"Error clearing cache: {str(e)}")
|
| 1321 |
+
logger.error(traceback.format_exc())
|
| 1322 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1323 |
+
|
| 1324 |
return router
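A short sketch of calling the new cache-clearing endpoint with a subset of modes; the payload shape follows the request model above, while the host, `/documents` prefix, and API key are assumptions:

```python
import requests

resp = requests.post(
    "http://localhost:9621/documents/clear_cache",  # assumed mount point
    json={"modes": ["default", "naive"]},  # omit "modes" to clear all cache
    headers={"X-API-Key": "your-api-key"},
    timeout=30,
)
print(resp.json())  # e.g. {"status": "success", "message": "Successfully cleared cache for modes: ..."}
```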
|
lightrag/api/routers/graph_routes.py
CHANGED
|
@@ -3,7 +3,7 @@ This module contains all graph-related routes for the LightRAG API.
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
from typing import Optional
|
| 6 |
-
from fastapi import APIRouter, Depends
|
| 7 |
|
| 8 |
from ..utils_api import get_combined_auth_dependency
|
| 9 |
|
|
@@ -25,23 +25,20 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
|
| 25 |
|
| 26 |
@router.get("/graphs", dependencies=[Depends(combined_auth)])
|
| 27 |
async def get_knowledge_graph(
|
| 28 |
-
label: str
|
|
|
|
|
|
|
| 29 |
):
|
| 30 |
"""
|
| 31 |
Retrieve a connected subgraph of nodes where the label includes the specified label.
|
| 32 |
-
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
| 33 |
When reducing the number of nodes, the prioritization criteria are as follows:
|
| 34 |
-
1.
|
| 35 |
-
2.
|
| 36 |
-
3. Followed by nodes directly connected to the matching nodes
|
| 37 |
-
4. Finally, the degree of the nodes
|
| 38 |
-
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
|
| 39 |
|
| 40 |
Args:
|
| 41 |
-
label (str): Label
|
| 42 |
-
max_depth (int, optional): Maximum depth of
|
| 43 |
-
|
| 44 |
-
min_degree (int, optional): Minimum degree of nodes. Defaults to 0.
|
| 45 |
|
| 46 |
Returns:
|
| 47 |
Dict[str, List[str]]: Knowledge graph for label
|
|
@@ -49,8 +46,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
|
| 49 |
return await rag.get_knowledge_graph(
|
| 50 |
node_label=label,
|
| 51 |
max_depth=max_depth,
|
| 52 |
-
|
| 53 |
-
min_degree=min_degree,
|
| 54 |
)
|
| 55 |
|
| 56 |
return router
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
from typing import Optional
|
| 6 |
+
from fastapi import APIRouter, Depends, Query
|
| 7 |
|
| 8 |
from ..utils_api import get_combined_auth_dependency
|
| 9 |
|
|
|
|
| 25 |
|
| 26 |
@router.get("/graphs", dependencies=[Depends(combined_auth)])
|
| 27 |
async def get_knowledge_graph(
|
| 28 |
+
label: str = Query(..., description="Label to get knowledge graph for"),
|
| 29 |
+
max_depth: int = Query(3, description="Maximum depth of graph", ge=1),
|
| 30 |
+
max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1),
|
| 31 |
):
|
| 32 |
"""
|
| 33 |
Retrieve a connected subgraph of nodes where the label includes the specified label.
|
|
|
|
| 34 |
When reducing the number of nodes, the prioritization criteria are as follows:
|
| 35 |
+
1. Hops (path) to the starting node take precedence
|
| 36 |
+
2. Followed by the degree of the nodes
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
Args:
|
| 39 |
+
label (str): Label of the starting node
|
| 40 |
+
max_depth (int, optional): Maximum depth of the subgraph. Defaults to 3
|
| 41 |
+
max_nodes (int, optional): Maximum nodes to return. Defaults to 1000
|
|
|
|
| 42 |
|
| 43 |
Returns:
|
| 44 |
Dict[str, List[str]]: Knowledge graph for label
|
|
|
|
| 46 |
return await rag.get_knowledge_graph(
|
| 47 |
node_label=label,
|
| 48 |
max_depth=max_depth,
|
| 49 |
+
max_nodes=max_nodes,
|
|
|
|
| 50 |
)
|
| 51 |
|
| 52 |
return router
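A usage sketch for the reworked query parameters (`label`, `max_depth`, `max_nodes`); the host, API key, and the `nodes`/`edges` keys in the response are assumptions:

```python
import requests

resp = requests.get(
    "http://localhost:9621/graphs",
    params={"label": "Albert Einstein", "max_depth": 3, "max_nodes": 500},
    headers={"X-API-Key": "your-api-key"},
    timeout=30,
)
graph = resp.json()
print(len(graph.get("nodes", [])), "nodes,", len(graph.get("edges", [])), "edges")
```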
|
lightrag/api/run_with_gunicorn.py
CHANGED
|
@@ -7,14 +7,9 @@ import os
|
|
| 7 |
import sys
|
| 8 |
import signal
|
| 9 |
import pipmaster as pm
|
| 10 |
-
from lightrag.api.utils_api import
|
| 11 |
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
|
| 12 |
-
from
|
| 13 |
-
|
| 14 |
-
# use the .env that is inside the current folder
|
| 15 |
-
# allows to use different .env file for each lightrag instance
|
| 16 |
-
# the OS environment variables take precedence over the .env file
|
| 17 |
-
load_dotenv(dotenv_path=".env", override=False)
|
| 18 |
|
| 19 |
|
| 20 |
def check_and_install_dependencies():
|
|
@@ -59,20 +54,17 @@ def main():
|
|
| 59 |
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
| 60 |
signal.signal(signal.SIGTERM, signal_handler) # kill command
|
| 61 |
|
| 62 |
-
# Parse all arguments using parse_args
|
| 63 |
-
args = parse_args(is_uvicorn_mode=False)
|
| 64 |
-
|
| 65 |
# Display startup information
|
| 66 |
-
display_splash_screen(
|
| 67 |
|
| 68 |
print("🚀 Starting LightRAG with Gunicorn")
|
| 69 |
-
print(f"🔄 Worker management: Gunicorn (workers={
|
| 70 |
print("🔍 Preloading app: Enabled")
|
| 71 |
print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
|
| 72 |
print("\n\n" + "=" * 80)
|
| 73 |
print("MAIN PROCESS INITIALIZATION")
|
| 74 |
print(f"Process ID: {os.getpid()}")
|
| 75 |
-
print(f"Workers setting: {
|
| 76 |
print("=" * 80 + "\n")
|
| 77 |
|
| 78 |
# Import Gunicorn's StandaloneApplication
|
|
@@ -128,31 +120,43 @@ def main():
|
|
| 128 |
|
| 129 |
# Set configuration variables in gunicorn_config, prioritizing command line arguments
|
| 130 |
gunicorn_config.workers = (
|
| 131 |
-
|
|
|
|
|
|
|
| 132 |
)
|
| 133 |
|
| 134 |
# Bind configuration prioritizes command line arguments
|
| 135 |
-
host =
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
gunicorn_config.bind = f"{host}:{port}"
|
| 138 |
|
| 139 |
# Log level configuration prioritizes command line arguments
|
| 140 |
gunicorn_config.loglevel = (
|
| 141 |
-
|
| 142 |
-
if
|
| 143 |
else os.getenv("LOG_LEVEL", "info")
|
| 144 |
)
|
| 145 |
|
| 146 |
# Timeout configuration prioritizes command line arguments
|
| 147 |
gunicorn_config.timeout = (
|
| 148 |
-
|
|
|
|
|
|
|
| 149 |
)
|
| 150 |
|
| 151 |
# Keepalive configuration
|
| 152 |
gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
|
| 153 |
|
| 154 |
# SSL configuration prioritizes command line arguments
|
| 155 |
-
if
|
| 156 |
"true",
|
| 157 |
"1",
|
| 158 |
"yes",
|
|
@@ -160,12 +164,14 @@ def main():
|
|
| 160 |
"on",
|
| 161 |
):
|
| 162 |
gunicorn_config.certfile = (
|
| 163 |
-
|
| 164 |
-
if
|
| 165 |
else os.getenv("SSL_CERTFILE")
|
| 166 |
)
|
| 167 |
gunicorn_config.keyfile = (
|
| 168 |
-
|
|
|
|
|
|
|
| 169 |
)
|
| 170 |
|
| 171 |
# Set configuration options from the module
|
|
@@ -190,13 +196,13 @@ def main():
|
|
| 190 |
# Import the application
|
| 191 |
from lightrag.api.lightrag_server import get_application
|
| 192 |
|
| 193 |
-
return get_application(
|
| 194 |
|
| 195 |
# Create the application
|
| 196 |
app = GunicornApp("")
|
| 197 |
|
| 198 |
# Force workers to be an integer and greater than 1 for multi-process mode
|
| 199 |
-
workers_count = int(
|
| 200 |
if workers_count > 1:
|
| 201 |
# Set a flag to indicate we're in the main process
|
| 202 |
os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
|
|
|
|
| 7 |
import sys
|
| 8 |
import signal
|
| 9 |
import pipmaster as pm
|
| 10 |
+
from lightrag.api.utils_api import display_splash_screen, check_env_file
|
| 11 |
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
|
| 12 |
+
from .config import global_args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
def check_and_install_dependencies():
|
|
|
|
| 54 |
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
| 55 |
signal.signal(signal.SIGTERM, signal_handler) # kill command
|
| 56 |
|
|
|
|
|
|
|
|
|
|
| 57 |
# Display startup information
|
| 58 |
+
display_splash_screen(global_args)
|
| 59 |
|
| 60 |
print("🚀 Starting LightRAG with Gunicorn")
|
| 61 |
+
print(f"🔄 Worker management: Gunicorn (workers={global_args.workers})")
|
| 62 |
print("🔍 Preloading app: Enabled")
|
| 63 |
print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
|
| 64 |
print("\n\n" + "=" * 80)
|
| 65 |
print("MAIN PROCESS INITIALIZATION")
|
| 66 |
print(f"Process ID: {os.getpid()}")
|
| 67 |
+
print(f"Workers setting: {global_args.workers}")
|
| 68 |
print("=" * 80 + "\n")
|
| 69 |
|
| 70 |
# Import Gunicorn's StandaloneApplication
|
|
|
|
| 120 |
|
| 121 |
# Set configuration variables in gunicorn_config, prioritizing command line arguments
|
| 122 |
gunicorn_config.workers = (
|
| 123 |
+
global_args.workers
|
| 124 |
+
if global_args.workers
|
| 125 |
+
else int(os.getenv("WORKERS", 1))
|
| 126 |
)
|
| 127 |
|
| 128 |
# Bind configuration prioritizes command line arguments
|
| 129 |
+
host = (
|
| 130 |
+
global_args.host
|
| 131 |
+
if global_args.host != "0.0.0.0"
|
| 132 |
+
else os.getenv("HOST", "0.0.0.0")
|
| 133 |
+
)
|
| 134 |
+
port = (
|
| 135 |
+
global_args.port
|
| 136 |
+
if global_args.port != 9621
|
| 137 |
+
else int(os.getenv("PORT", 9621))
|
| 138 |
+
)
|
| 139 |
gunicorn_config.bind = f"{host}:{port}"
|
| 140 |
|
| 141 |
# Log level configuration prioritizes command line arguments
|
| 142 |
gunicorn_config.loglevel = (
|
| 143 |
+
global_args.log_level.lower()
|
| 144 |
+
if global_args.log_level
|
| 145 |
else os.getenv("LOG_LEVEL", "info")
|
| 146 |
)
|
| 147 |
|
| 148 |
# Timeout configuration prioritizes command line arguments
|
| 149 |
gunicorn_config.timeout = (
|
| 150 |
+
global_args.timeout
|
| 151 |
+
if global_args.timeout * 2
|
| 152 |
+
else int(os.getenv("TIMEOUT", 150 * 2))
|
| 153 |
)
|
| 154 |
|
| 155 |
# Keepalive configuration
|
| 156 |
gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
|
| 157 |
|
| 158 |
# SSL configuration prioritizes command line arguments
|
| 159 |
+
if global_args.ssl or os.getenv("SSL", "").lower() in (
|
| 160 |
"true",
|
| 161 |
"1",
|
| 162 |
"yes",
|
|
|
|
| 164 |
"on",
|
| 165 |
):
|
| 166 |
gunicorn_config.certfile = (
|
| 167 |
+
global_args.ssl_certfile
|
| 168 |
+
if global_args.ssl_certfile
|
| 169 |
else os.getenv("SSL_CERTFILE")
|
| 170 |
)
|
| 171 |
gunicorn_config.keyfile = (
|
| 172 |
+
global_args.ssl_keyfile
|
| 173 |
+
if global_args.ssl_keyfile
|
| 174 |
+
else os.getenv("SSL_KEYFILE")
|
| 175 |
)
|
| 176 |
|
| 177 |
# Set configuration options from the module
|
|
|
|
| 196 |
# Import the application
|
| 197 |
from lightrag.api.lightrag_server import get_application
|
| 198 |
|
| 199 |
+
return get_application(global_args)
|
| 200 |
|
| 201 |
# Create the application
|
| 202 |
app = GunicornApp("")
|
| 203 |
|
| 204 |
# Force workers to be an integer and greater than 1 for multi-process mode
|
| 205 |
+
workers_count = int(global_args.workers)
|
| 206 |
if workers_count > 1:
|
| 207 |
# Set a flag to indicate we're in the main process
|
| 208 |
os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
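The pattern used throughout this block (a `global_args` value wins unless it is still the built-in default, then an environment variable, then a hard default) can be condensed into a small helper. This is an illustrative sketch under that assumption, not part of the actual module:

```python
import os

def resolve(cli_value, default, env_key, cast=str):
    """Illustrative helper: prefer a non-default CLI value, then the
    environment variable, then the hard-coded default."""
    if cli_value is not None and cli_value != default:
        return cli_value
    env_value = os.getenv(env_key)
    return cast(env_value) if env_value is not None else default

# Hypothetical usage mirroring the gunicorn configuration above:
# host = resolve(global_args.host, "0.0.0.0", "HOST")
# port = resolve(global_args.port, 9621, "PORT", int)
```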
|
lightrag/api/utils_api.py
CHANGED
|
@@ -7,15 +7,13 @@ import argparse
|
|
| 7 |
from typing import Optional, List, Tuple
|
| 8 |
import sys
|
| 9 |
from ascii_colors import ASCIIColors
|
| 10 |
-
import logging
|
| 11 |
from lightrag.api import __api_version__ as api_version
|
| 12 |
from lightrag import __version__ as core_version
|
| 13 |
from fastapi import HTTPException, Security, Request, status
|
| 14 |
-
from dotenv import load_dotenv
|
| 15 |
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
| 16 |
from starlette.status import HTTP_403_FORBIDDEN
|
| 17 |
from .auth import auth_handler
|
| 18 |
-
from
|
| 19 |
|
| 20 |
|
| 21 |
def check_env_file():
|
|
@@ -36,16 +34,8 @@ def check_env_file():
|
|
| 36 |
return True
|
| 37 |
|
| 38 |
|
| 39 |
-
#
|
| 40 |
-
|
| 41 |
-
# the OS environment variables take precedence over the .env file
|
| 42 |
-
load_dotenv(dotenv_path=".env", override=False)
|
| 43 |
-
|
| 44 |
-
global_args = {"main_args": None}
|
| 45 |
-
|
| 46 |
-
# Get whitelist paths from environment variable, only once during initialization
|
| 47 |
-
default_whitelist = "/health,/api/*"
|
| 48 |
-
whitelist_paths = os.getenv("WHITELIST_PATHS", default_whitelist).split(",")
|
| 49 |
|
| 50 |
# Pre-compile path matching patterns
|
| 51 |
whitelist_patterns: List[Tuple[str, bool]] = []
|
|
@@ -63,19 +53,6 @@ for path in whitelist_paths:
|
|
| 63 |
auth_configured = bool(auth_handler.accounts)
|
| 64 |
|
| 65 |
|
| 66 |
-
class OllamaServerInfos:
|
| 67 |
-
# Constants for emulated Ollama model information
|
| 68 |
-
LIGHTRAG_NAME = "lightrag"
|
| 69 |
-
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
|
| 70 |
-
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
|
| 71 |
-
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
|
| 72 |
-
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
|
| 73 |
-
LIGHTRAG_DIGEST = "sha256:lightrag"
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
ollama_server_infos = OllamaServerInfos()
|
| 77 |
-
|
| 78 |
-
|
| 79 |
def get_combined_auth_dependency(api_key: Optional[str] = None):
|
| 80 |
"""
|
| 81 |
Create a combined authentication dependency that implements authentication logic
|
|
@@ -186,299 +163,6 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
|
|
| 186 |
return combined_dependency
|
| 187 |
|
| 188 |
|
| 189 |
-
class DefaultRAGStorageConfig:
|
| 190 |
-
KV_STORAGE = "JsonKVStorage"
|
| 191 |
-
VECTOR_STORAGE = "NanoVectorDBStorage"
|
| 192 |
-
GRAPH_STORAGE = "NetworkXStorage"
|
| 193 |
-
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
def get_default_host(binding_type: str) -> str:
|
| 197 |
-
default_hosts = {
|
| 198 |
-
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
|
| 199 |
-
"lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
|
| 200 |
-
"azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
|
| 201 |
-
"openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
|
| 202 |
-
}
|
| 203 |
-
return default_hosts.get(
|
| 204 |
-
binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
|
| 205 |
-
) # fallback to ollama if unknown
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
|
| 209 |
-
"""
|
| 210 |
-
Get value from environment variable with type conversion
|
| 211 |
-
|
| 212 |
-
Args:
|
| 213 |
-
env_key (str): Environment variable key
|
| 214 |
-
default (any): Default value if env variable is not set
|
| 215 |
-
value_type (type): Type to convert the value to
|
| 216 |
-
|
| 217 |
-
Returns:
|
| 218 |
-
any: Converted value from environment or default
|
| 219 |
-
"""
|
| 220 |
-
value = os.getenv(env_key)
|
| 221 |
-
if value is None:
|
| 222 |
-
return default
|
| 223 |
-
|
| 224 |
-
if value_type is bool:
|
| 225 |
-
return value.lower() in ("true", "1", "yes", "t", "on")
|
| 226 |
-
try:
|
| 227 |
-
return value_type(value)
|
| 228 |
-
except ValueError:
|
| 229 |
-
return default
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
|
| 233 |
-
"""
|
| 234 |
-
Parse command line arguments with environment variable fallback
|
| 235 |
-
|
| 236 |
-
Args:
|
| 237 |
-
is_uvicorn_mode: Whether running under uvicorn mode
|
| 238 |
-
|
| 239 |
-
Returns:
|
| 240 |
-
argparse.Namespace: Parsed arguments
|
| 241 |
-
"""
|
| 242 |
-
|
| 243 |
-
parser = argparse.ArgumentParser(
|
| 244 |
-
description="LightRAG FastAPI Server with separate working and input directories"
|
| 245 |
-
)
|
| 246 |
-
|
| 247 |
-
# Server configuration
|
| 248 |
-
parser.add_argument(
|
| 249 |
-
"--host",
|
| 250 |
-
default=get_env_value("HOST", "0.0.0.0"),
|
| 251 |
-
help="Server host (default: from env or 0.0.0.0)",
|
| 252 |
-
)
|
| 253 |
-
parser.add_argument(
|
| 254 |
-
"--port",
|
| 255 |
-
type=int,
|
| 256 |
-
default=get_env_value("PORT", 9621, int),
|
| 257 |
-
help="Server port (default: from env or 9621)",
|
| 258 |
-
)
|
| 259 |
-
|
| 260 |
-
# Directory configuration
|
| 261 |
-
parser.add_argument(
|
| 262 |
-
"--working-dir",
|
| 263 |
-
default=get_env_value("WORKING_DIR", "./rag_storage"),
|
| 264 |
-
help="Working directory for RAG storage (default: from env or ./rag_storage)",
|
| 265 |
-
)
|
| 266 |
-
parser.add_argument(
|
| 267 |
-
"--input-dir",
|
| 268 |
-
default=get_env_value("INPUT_DIR", "./inputs"),
|
| 269 |
-
help="Directory containing input documents (default: from env or ./inputs)",
|
| 270 |
-
)
|
| 271 |
-
|
| 272 |
-
def timeout_type(value):
|
| 273 |
-
if value is None:
|
| 274 |
-
return 150
|
| 275 |
-
if value is None or value == "None":
|
| 276 |
-
return None
|
| 277 |
-
return int(value)
|
| 278 |
-
|
| 279 |
-
parser.add_argument(
|
| 280 |
-
"--timeout",
|
| 281 |
-
default=get_env_value("TIMEOUT", None, timeout_type),
|
| 282 |
-
type=timeout_type,
|
| 283 |
-
help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
|
| 284 |
-
)
|
| 285 |
-
|
| 286 |
-
# RAG configuration
|
| 287 |
-
parser.add_argument(
|
| 288 |
-
"--max-async",
|
| 289 |
-
type=int,
|
| 290 |
-
default=get_env_value("MAX_ASYNC", 4, int),
|
| 291 |
-
help="Maximum async operations (default: from env or 4)",
|
| 292 |
-
)
|
| 293 |
-
parser.add_argument(
|
| 294 |
-
"--max-tokens",
|
| 295 |
-
type=int,
|
| 296 |
-
default=get_env_value("MAX_TOKENS", 32768, int),
|
| 297 |
-
help="Maximum token size (default: from env or 32768)",
|
| 298 |
-
)
|
| 299 |
-
|
| 300 |
-
# Logging configuration
|
| 301 |
-
parser.add_argument(
|
| 302 |
-
"--log-level",
|
| 303 |
-
default=get_env_value("LOG_LEVEL", "INFO"),
|
| 304 |
-
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
| 305 |
-
help="Logging level (default: from env or INFO)",
|
| 306 |
-
)
|
| 307 |
-
parser.add_argument(
|
| 308 |
-
"--verbose",
|
| 309 |
-
action="store_true",
|
| 310 |
-
default=get_env_value("VERBOSE", False, bool),
|
| 311 |
-
help="Enable verbose debug output(only valid for DEBUG log-level)",
|
| 312 |
-
)
|
| 313 |
-
|
| 314 |
-
parser.add_argument(
|
| 315 |
-
"--key",
|
| 316 |
-
type=str,
|
| 317 |
-
default=get_env_value("LIGHTRAG_API_KEY", None),
|
| 318 |
-
help="API key for authentication. This protects lightrag server against unauthorized access",
|
| 319 |
-
)
|
| 320 |
-
|
| 321 |
-
# Optional https parameters
|
| 322 |
-
parser.add_argument(
|
| 323 |
-
"--ssl",
|
| 324 |
-
action="store_true",
|
| 325 |
-
default=get_env_value("SSL", False, bool),
|
| 326 |
-
help="Enable HTTPS (default: from env or False)",
|
| 327 |
-
)
|
| 328 |
-
parser.add_argument(
|
| 329 |
-
"--ssl-certfile",
|
| 330 |
-
default=get_env_value("SSL_CERTFILE", None),
|
| 331 |
-
help="Path to SSL certificate file (required if --ssl is enabled)",
|
| 332 |
-
)
|
| 333 |
-
parser.add_argument(
|
| 334 |
-
"--ssl-keyfile",
|
| 335 |
-
default=get_env_value("SSL_KEYFILE", None),
|
| 336 |
-
help="Path to SSL private key file (required if --ssl is enabled)",
|
| 337 |
-
)
|
| 338 |
-
|
| 339 |
-
parser.add_argument(
|
| 340 |
-
"--history-turns",
|
| 341 |
-
type=int,
|
| 342 |
-
default=get_env_value("HISTORY_TURNS", 3, int),
|
| 343 |
-
help="Number of conversation history turns to include (default: from env or 3)",
|
| 344 |
-
)
|
| 345 |
-
|
| 346 |
-
# Search parameters
|
| 347 |
-
parser.add_argument(
|
| 348 |
-
"--top-k",
|
| 349 |
-
type=int,
|
| 350 |
-
default=get_env_value("TOP_K", 60, int),
|
| 351 |
-
help="Number of most similar results to return (default: from env or 60)",
|
| 352 |
-
)
|
| 353 |
-
parser.add_argument(
|
| 354 |
-
"--cosine-threshold",
|
| 355 |
-
type=float,
|
| 356 |
-
default=get_env_value("COSINE_THRESHOLD", 0.2, float),
|
| 357 |
-
help="Cosine similarity threshold (default: from env or 0.4)",
|
| 358 |
-
)
|
| 359 |
-
|
| 360 |
-
# Ollama model name
|
| 361 |
-
parser.add_argument(
|
| 362 |
-
"--simulated-model-name",
|
| 363 |
-
type=str,
|
| 364 |
-
default=get_env_value(
|
| 365 |
-
"SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
|
| 366 |
-
),
|
| 367 |
-
help="Number of conversation history turns to include (default: from env or 3)",
|
| 368 |
-
)
|
| 369 |
-
|
| 370 |
-
# Namespace
|
| 371 |
-
parser.add_argument(
|
| 372 |
-
"--namespace-prefix",
|
| 373 |
-
type=str,
|
| 374 |
-
default=get_env_value("NAMESPACE_PREFIX", ""),
|
| 375 |
-
help="Prefix of the namespace",
|
| 376 |
-
)
|
| 377 |
-
|
| 378 |
-
parser.add_argument(
|
| 379 |
-
"--auto-scan-at-startup",
|
| 380 |
-
action="store_true",
|
| 381 |
-
default=False,
|
| 382 |
-
help="Enable automatic scanning when the program starts",
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
-
# Server workers configuration
|
| 386 |
-
parser.add_argument(
|
| 387 |
-
"--workers",
|
| 388 |
-
type=int,
|
| 389 |
-
default=get_env_value("WORKERS", 1, int),
|
| 390 |
-
help="Number of worker processes (default: from env or 1)",
|
| 391 |
-
)
|
| 392 |
-
|
| 393 |
-
# LLM and embedding bindings
|
| 394 |
-
parser.add_argument(
|
| 395 |
-
"--llm-binding",
|
| 396 |
-
type=str,
|
| 397 |
-
default=get_env_value("LLM_BINDING", "ollama"),
|
| 398 |
-
choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
|
| 399 |
-
help="LLM binding type (default: from env or ollama)",
|
| 400 |
-
)
|
| 401 |
-
parser.add_argument(
|
| 402 |
-
"--embedding-binding",
|
| 403 |
-
type=str,
|
| 404 |
-
default=get_env_value("EMBEDDING_BINDING", "ollama"),
|
| 405 |
-
choices=["lollms", "ollama", "openai", "azure_openai"],
|
| 406 |
-
help="Embedding binding type (default: from env or ollama)",
|
| 407 |
-
)
|
| 408 |
-
|
| 409 |
-
args = parser.parse_args()
|
| 410 |
-
|
| 411 |
-
# If in uvicorn mode and workers > 1, force it to 1 and log warning
|
| 412 |
-
if is_uvicorn_mode and args.workers > 1:
|
| 413 |
-
original_workers = args.workers
|
| 414 |
-
args.workers = 1
|
| 415 |
-
# Log warning directly here
|
| 416 |
-
logging.warning(
|
| 417 |
-
f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
|
| 418 |
-
)
|
| 419 |
-
|
| 420 |
-
# convert relative path to absolute path
|
| 421 |
-
args.working_dir = os.path.abspath(args.working_dir)
|
| 422 |
-
args.input_dir = os.path.abspath(args.input_dir)
|
| 423 |
-
|
| 424 |
-
# Inject storage configuration from environment variables
|
| 425 |
-
args.kv_storage = get_env_value(
|
| 426 |
-
"LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
|
| 427 |
-
)
|
| 428 |
-
args.doc_status_storage = get_env_value(
|
| 429 |
-
"LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
|
| 430 |
-
)
|
| 431 |
-
args.graph_storage = get_env_value(
|
| 432 |
-
"LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
|
| 433 |
-
)
|
| 434 |
-
args.vector_storage = get_env_value(
|
| 435 |
-
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
|
| 436 |
-
)
|
| 437 |
-
|
| 438 |
-
# Get MAX_PARALLEL_INSERT from environment
|
| 439 |
-
args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
|
| 440 |
-
|
| 441 |
-
# Handle openai-ollama special case
|
| 442 |
-
if args.llm_binding == "openai-ollama":
|
| 443 |
-
args.llm_binding = "openai"
|
| 444 |
-
args.embedding_binding = "ollama"
|
| 445 |
-
|
| 446 |
-
args.llm_binding_host = get_env_value(
|
| 447 |
-
"LLM_BINDING_HOST", get_default_host(args.llm_binding)
|
| 448 |
-
)
|
| 449 |
-
args.embedding_binding_host = get_env_value(
|
| 450 |
-
"EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
|
| 451 |
-
)
|
| 452 |
-
args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
|
| 453 |
-
args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
|
| 454 |
-
|
| 455 |
-
# Inject model configuration
|
| 456 |
-
args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
|
| 457 |
-
args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
|
| 458 |
-
args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
|
| 459 |
-
args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int)
|
| 460 |
-
|
| 461 |
-
# Inject chunk configuration
|
| 462 |
-
args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
|
| 463 |
-
args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
|
| 464 |
-
|
| 465 |
-
# Inject LLM cache configuration
|
| 466 |
-
args.enable_llm_cache_for_extract = get_env_value(
|
| 467 |
-
"ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
|
| 468 |
-
)
|
| 469 |
-
|
| 470 |
-
# Inject LLM temperature configuration
|
| 471 |
-
args.temperature = get_env_value("TEMPERATURE", 0.5, float)
|
| 472 |
-
|
| 473 |
-
# Select Document loading tool (DOCLING, DEFAULT)
|
| 474 |
-
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
|
| 475 |
-
|
| 476 |
-
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
|
| 477 |
-
|
| 478 |
-
global_args["main_args"] = args
|
| 479 |
-
return args
|
| 480 |
-
|
| 481 |
-
|
| 482 |
def display_splash_screen(args: argparse.Namespace) -> None:
|
| 483 |
"""
|
| 484 |
Display a colorful splash screen showing LightRAG server configuration
|
|
@@ -489,7 +173,7 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
|
| 489 |
# Banner
|
| 490 |
ASCIIColors.cyan(f"""
|
| 491 |
╔══════════════════════════════════════════════════════════════╗
|
| 492 |
-
║
|
| 493 |
║ Fast, Lightweight RAG Server Implementation ║
|
| 494 |
╚══════════════════════════════════════════════════════════════╝
|
| 495 |
""")
|
|
@@ -503,7 +187,7 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
|
| 503 |
ASCIIColors.white(" ├─ Workers: ", end="")
|
| 504 |
ASCIIColors.yellow(f"{args.workers}")
|
| 505 |
ASCIIColors.white(" ├─ CORS Origins: ", end="")
|
| 506 |
-
ASCIIColors.yellow(f"{
|
| 507 |
ASCIIColors.white(" ├─ SSL Enabled: ", end="")
|
| 508 |
ASCIIColors.yellow(f"{args.ssl}")
|
| 509 |
if args.ssl:
|
|
@@ -519,8 +203,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
|
| 519 |
ASCIIColors.yellow(f"{args.verbose}")
|
| 520 |
ASCIIColors.white(" ├─ History Turns: ", end="")
|
| 521 |
ASCIIColors.yellow(f"{args.history_turns}")
|
| 522 |
-
ASCIIColors.white("
|
| 523 |
ASCIIColors.yellow("Set" if args.key else "Not Set")
|
|
|
|
|
|
|
| 524 |
|
| 525 |
# Directory Configuration
|
| 526 |
ASCIIColors.magenta("\n📂 Directory Configuration:")
|
|
@@ -558,10 +244,9 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
|
| 558 |
ASCIIColors.yellow(f"{args.embedding_dim}")
|
| 559 |
|
| 560 |
# RAG Configuration
|
| 561 |
-
summary_language = os.getenv("SUMMARY_LANGUAGE", PROMPTS["DEFAULT_LANGUAGE"])
|
| 562 |
ASCIIColors.magenta("\n⚙️ RAG Configuration:")
|
| 563 |
ASCIIColors.white(" ├─ Summary Language: ", end="")
|
| 564 |
-
ASCIIColors.yellow(f"{summary_language}")
|
| 565 |
ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
|
| 566 |
ASCIIColors.yellow(f"{args.max_parallel_insert}")
|
| 567 |
ASCIIColors.white(" ├─ Max Embed Tokens: ", end="")
|
|
@@ -595,19 +280,17 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
|
| 595 |
protocol = "https" if args.ssl else "http"
|
| 596 |
if args.host == "0.0.0.0":
|
| 597 |
ASCIIColors.magenta("\n🌐 Server Access Information:")
|
| 598 |
-
ASCIIColors.white(" ├─
|
| 599 |
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}")
|
| 600 |
ASCIIColors.white(" ├─ Remote Access: ", end="")
|
| 601 |
ASCIIColors.yellow(f"{protocol}://<your-ip-address>:{args.port}")
|
| 602 |
ASCIIColors.white(" ├─ API Documentation (local): ", end="")
|
| 603 |
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs")
|
| 604 |
-
ASCIIColors.white("
|
| 605 |
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc")
|
| 606 |
-
ASCIIColors.white(" └─ WebUI (local): ", end="")
|
| 607 |
-
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/webui")
|
| 608 |
|
| 609 |
-
ASCIIColors.
|
| 610 |
-
ASCIIColors.
|
| 611 |
- Use 'localhost' or '127.0.0.1' for local access
|
| 612 |
- Use your machine's IP address for remote access
|
| 613 |
- To find your IP address:
|
|
@@ -617,42 +300,24 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
|
| 617 |
else:
|
| 618 |
base_url = f"{protocol}://{args.host}:{args.port}"
|
| 619 |
ASCIIColors.magenta("\n🌐 Server Access Information:")
|
| 620 |
-
ASCIIColors.white(" ├─
|
| 621 |
ASCIIColors.yellow(f"{base_url}")
|
| 622 |
ASCIIColors.white(" ├─ API Documentation: ", end="")
|
| 623 |
ASCIIColors.yellow(f"{base_url}/docs")
|
| 624 |
ASCIIColors.white(" └─ Alternative Documentation: ", end="")
|
| 625 |
ASCIIColors.yellow(f"{base_url}/redoc")
|
| 626 |
|
| 627 |
-
# Usage Examples
|
| 628 |
-
ASCIIColors.magenta("\n📚 Quick Start Guide:")
|
| 629 |
-
ASCIIColors.cyan("""
|
| 630 |
-
1. Access the Swagger UI:
|
| 631 |
-
Open your browser and navigate to the API documentation URL above
|
| 632 |
-
|
| 633 |
-
2. API Authentication:""")
|
| 634 |
-
if args.key:
|
| 635 |
-
ASCIIColors.cyan(""" Add the following header to your requests:
|
| 636 |
-
X-API-Key: <your-api-key>
|
| 637 |
-
""")
|
| 638 |
-
else:
|
| 639 |
-
ASCIIColors.cyan(" No authentication required\n")
|
| 640 |
-
|
| 641 |
-
ASCIIColors.cyan(""" 3. Basic Operations:
|
| 642 |
-
- POST /upload_document: Upload new documents to RAG
|
| 643 |
-
- POST /query: Query your document collection
|
| 644 |
-
|
| 645 |
-
4. Monitor the server:
|
| 646 |
-
- Check server logs for detailed operation information
|
| 647 |
-
- Use healthcheck endpoint: GET /health
|
| 648 |
-
""")
|
| 649 |
-
|
| 650 |
# Security Notice
|
| 651 |
if args.key:
|
| 652 |
ASCIIColors.yellow("\n⚠️ Security Notice:")
|
| 653 |
ASCIIColors.white(""" API Key authentication is enabled.
|
| 654 |
Make sure to include the X-API-Key header in all your requests.
|
| 655 |
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 656 |
|
| 657 |
# Ensure splash output flush to system log
|
| 658 |
sys.stdout.flush()
|
|
|
|
| 7 |
from typing import Optional, List, Tuple
|
| 8 |
import sys
|
| 9 |
from ascii_colors import ASCIIColors
|
|
|
|
| 10 |
from lightrag.api import __api_version__ as api_version
|
| 11 |
from lightrag import __version__ as core_version
|
| 12 |
from fastapi import HTTPException, Security, Request, status
|
|
|
|
| 13 |
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
| 14 |
from starlette.status import HTTP_403_FORBIDDEN
|
| 15 |
from .auth import auth_handler
|
| 16 |
+
from .config import ollama_server_infos, global_args
|
| 17 |
|
| 18 |
|
| 19 |
def check_env_file():
|
|
|
|
| 34 |
return True
|
| 35 |
|
| 36 |
|
| 37 |
+
# Get whitelist paths from global_args, only once during initialization
|
| 38 |
+
whitelist_paths = global_args.whitelist_paths.split(",")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
# Pre-compile path matching patterns
|
| 41 |
whitelist_patterns: List[Tuple[str, bool]] = []
|
|
|
|
| 53 |
auth_configured = bool(auth_handler.accounts)
|
| 54 |
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
def get_combined_auth_dependency(api_key: Optional[str] = None):
|
| 57 |
"""
|
| 58 |
Create a combined authentication dependency that implements authentication logic
|
|
|
|
| 163 |
return combined_dependency
|
| 164 |
|
| 165 |
|
|
|
|
|
|
|
| 166 |
def display_splash_screen(args: argparse.Namespace) -> None:
|
| 167 |
"""
|
| 168 |
Display a colorful splash screen showing LightRAG server configuration
|
|
|
|
| 173 |
# Banner
|
| 174 |
ASCIIColors.cyan(f"""
|
| 175 |
╔══════════════════════════════════════════════════════════════╗
|
| 176 |
+
║ 🚀 LightRAG Server v{core_version}/{api_version} ║
|
| 177 |
║ Fast, Lightweight RAG Server Implementation ║
|
| 178 |
╚══════════════════════════════════════════════════════════════╝
|
| 179 |
""")
|
|
|
|
| 187 |
ASCIIColors.white(" ├─ Workers: ", end="")
|
| 188 |
ASCIIColors.yellow(f"{args.workers}")
|
| 189 |
ASCIIColors.white(" ├─ CORS Origins: ", end="")
|
| 190 |
+
ASCIIColors.yellow(f"{args.cors_origins}")
|
| 191 |
ASCIIColors.white(" ├─ SSL Enabled: ", end="")
|
| 192 |
ASCIIColors.yellow(f"{args.ssl}")
|
| 193 |
if args.ssl:
|
|
|
|
| 203 |
ASCIIColors.yellow(f"{args.verbose}")
|
| 204 |
ASCIIColors.white(" ├─ History Turns: ", end="")
|
| 205 |
ASCIIColors.yellow(f"{args.history_turns}")
|
| 206 |
+
ASCIIColors.white(" ├─ API Key: ", end="")
|
| 207 |
ASCIIColors.yellow("Set" if args.key else "Not Set")
|
| 208 |
+
ASCIIColors.white(" └─ JWT Auth: ", end="")
|
| 209 |
+
ASCIIColors.yellow("Enabled" if args.auth_accounts else "Disabled")
|
| 210 |
|
| 211 |
# Directory Configuration
|
| 212 |
ASCIIColors.magenta("\n📂 Directory Configuration:")
|
|
|
|
| 244 |
ASCIIColors.yellow(f"{args.embedding_dim}")
|
| 245 |
|
| 246 |
# RAG Configuration
|
|
|
|
| 247 |
ASCIIColors.magenta("\n⚙️ RAG Configuration:")
|
| 248 |
ASCIIColors.white(" ├─ Summary Language: ", end="")
|
| 249 |
+
ASCIIColors.yellow(f"{args.summary_language}")
|
| 250 |
ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
|
| 251 |
ASCIIColors.yellow(f"{args.max_parallel_insert}")
|
| 252 |
ASCIIColors.white(" ├─ Max Embed Tokens: ", end="")
|
|
|
|
| 280 |
protocol = "https" if args.ssl else "http"
|
| 281 |
if args.host == "0.0.0.0":
|
| 282 |
ASCIIColors.magenta("\n🌐 Server Access Information:")
|
| 283 |
+
ASCIIColors.white(" ├─ WebUI (local): ", end="")
|
| 284 |
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}")
|
| 285 |
ASCIIColors.white(" ├─ Remote Access: ", end="")
|
| 286 |
ASCIIColors.yellow(f"{protocol}://<your-ip-address>:{args.port}")
|
| 287 |
ASCIIColors.white(" ├─ API Documentation (local): ", end="")
|
| 288 |
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs")
|
| 289 |
+
ASCIIColors.white(" └─ Alternative Documentation (local): ", end="")
|
| 290 |
ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc")
|
|
|
|
|
|
|
| 291 |
|
| 292 |
+
ASCIIColors.magenta("\n📝 Note:")
|
| 293 |
+
ASCIIColors.cyan(""" Since the server is running on 0.0.0.0:
|
| 294 |
- Use 'localhost' or '127.0.0.1' for local access
|
| 295 |
- Use your machine's IP address for remote access
|
| 296 |
- To find your IP address:
|
|
|
|
| 300 |
else:
|
| 301 |
base_url = f"{protocol}://{args.host}:{args.port}"
|
| 302 |
ASCIIColors.magenta("\n🌐 Server Access Information:")
|
| 303 |
+
ASCIIColors.white(" ├─ WebUI (local): ", end="")
|
| 304 |
ASCIIColors.yellow(f"{base_url}")
|
| 305 |
ASCIIColors.white(" ├─ API Documentation: ", end="")
|
| 306 |
ASCIIColors.yellow(f"{base_url}/docs")
|
| 307 |
ASCIIColors.white(" └─ Alternative Documentation: ", end="")
|
| 308 |
ASCIIColors.yellow(f"{base_url}/redoc")
|
| 309 |
|
|
|
|
|
| 310 |
# Security Notice
|
| 311 |
if args.key:
|
| 312 |
ASCIIColors.yellow("\n⚠️ Security Notice:")
|
| 313 |
ASCIIColors.white(""" API Key authentication is enabled.
|
| 314 |
Make sure to include the X-API-Key header in all your requests.
|
| 315 |
""")
|
| 316 |
+
if args.auth_accounts:
|
| 317 |
+
ASCIIColors.yellow("\n⚠️ Security Notice:")
|
| 318 |
+
ASCIIColors.white(""" JWT authentication is enabled.
|
| 319 |
+
Make sure to login before making the request, and include the 'Authorization' in the header.
|
| 320 |
+
""")
|
| 321 |
|
| 322 |
# Ensure splash output flush to system log
|
| 323 |
sys.stdout.flush()
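For reference, a sketch of how comma-separated whitelist paths with trailing wildcards can be pre-compiled into `(prefix, is_wildcard)` tuples and matched, which is the shape suggested by `whitelist_patterns: List[Tuple[str, bool]]` above; the helper names are illustrative, not the module's actual functions:

```python
from typing import List, Tuple

def compile_whitelist(paths: str) -> List[Tuple[str, bool]]:
    """Turn "/health,/api/*" into [("/health", False), ("/api/", True)]."""
    patterns: List[Tuple[str, bool]] = []
    for raw in paths.split(","):
        path = raw.strip()
        if path.endswith("/*"):
            patterns.append((path[:-1], True))   # keep trailing slash as the prefix
        elif path:
            patterns.append((path, False))
    return patterns

def is_whitelisted(request_path: str, patterns: List[Tuple[str, bool]]) -> bool:
    return any(
        request_path.startswith(prefix) if wildcard else request_path == prefix
        for prefix, wildcard in patterns
    )

patterns = compile_whitelist("/health,/api/*")
assert is_whitelisted("/api/version", patterns)
assert not is_whitelisted("/documents", patterns)
```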
|
lightrag/api/webui/assets/index-CD5HxTy1.css
DELETED
|
Binary file (55.1 kB)
|
|
|
lightrag/api/webui/assets/{index-raheqJeu.js → index-Cma7xY0-.js}
RENAMED
|
Binary files a/lightrag/api/webui/assets/index-raheqJeu.js and b/lightrag/api/webui/assets/index-Cma7xY0-.js differ
|
|
|
lightrag/api/webui/assets/index-QU59h9JG.css
ADDED
|
Binary file (57.1 kB). View file
|
|
|
lightrag/api/webui/index.html
CHANGED
|
Binary files a/lightrag/api/webui/index.html and b/lightrag/api/webui/index.html differ
|
|
|
lightrag/base.py
CHANGED
|
@@ -112,6 +112,32 @@ class StorageNameSpace(ABC):
|
|
| 112 |
async def index_done_callback(self) -> None:
|
| 113 |
"""Commit the storage operations after indexing"""
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
@dataclass
|
| 117 |
class BaseVectorStorage(StorageNameSpace, ABC):
|
|
@@ -127,15 +153,33 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
|
| 127 |
|
| 128 |
@abstractmethod
|
| 129 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
| 130 |
-
"""Insert or update vectors in the storage.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
@abstractmethod
|
| 133 |
async def delete_entity(self, entity_name: str) -> None:
|
| 134 |
-
"""Delete a single entity by its name.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
@abstractmethod
|
| 137 |
async def delete_entity_relation(self, entity_name: str) -> None:
|
| 138 |
-
"""Delete relations for a given entity.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
@abstractmethod
|
| 141 |
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
|
@@ -161,6 +205,19 @@ class BaseVectorStorage(StorageNameSpace, ABC):
|
|
| 161 |
"""
|
| 162 |
pass
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
@dataclass
|
| 166 |
class BaseKVStorage(StorageNameSpace, ABC):
|
|
@@ -180,7 +237,42 @@ class BaseKVStorage(StorageNameSpace, ABC):
|
|
| 180 |
|
| 181 |
@abstractmethod
|
| 182 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
| 183 |
-
"""Upsert data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
|
| 186 |
@dataclass
|
|
@@ -205,13 +297,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|
| 205 |
|
| 206 |
@abstractmethod
|
| 207 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
| 208 |
-
"""Get
|
| 209 |
|
| 210 |
@abstractmethod
|
| 211 |
async def get_edge(
|
| 212 |
self, source_node_id: str, target_node_id: str
|
| 213 |
) -> dict[str, str] | None:
|
| 214 |
-
"""Get
|
| 215 |
|
| 216 |
@abstractmethod
|
| 217 |
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
|
@@ -225,7 +317,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|
| 225 |
async def upsert_edge(
|
| 226 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 227 |
) -> None:
|
| 228 |
-
"""Delete a node from the graph.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
@abstractmethod
|
| 231 |
async def delete_node(self, node_id: str) -> None:
|
|
@@ -243,9 +341,20 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|
| 243 |
|
| 244 |
@abstractmethod
|
| 245 |
async def get_knowledge_graph(
|
| 246 |
-
self, node_label: str, max_depth: int = 3
|
| 247 |
) -> KnowledgeGraph:
|
| 248 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
|
| 250 |
|
| 251 |
class DocStatus(str, Enum):
|
|
@@ -297,6 +406,10 @@ class DocStatusStorage(BaseKVStorage, ABC):
|
|
| 297 |
) -> dict[str, DocProcessingStatus]:
|
| 298 |
"""Get all documents with a specific status"""
|
| 299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
|
| 301 |
class StoragesStatus(str, Enum):
|
| 302 |
"""Storages status"""
|
|
|
|
| 112 |
async def index_done_callback(self) -> None:
|
| 113 |
"""Commit the storage operations after indexing"""
|
| 114 |
|
| 115 |
+
@abstractmethod
|
| 116 |
+
async def drop(self) -> dict[str, str]:
|
| 117 |
+
"""Drop all data from storage and clean up resources
|
| 118 |
+
|
| 119 |
+
This abstract method defines the contract for dropping all data from a storage implementation.
|
| 120 |
+
Each storage type must implement this method to:
|
| 121 |
+
1. Clear all data from memory and/or external storage
|
| 122 |
+
2. Remove any associated storage files if applicable
|
| 123 |
+
3. Reset the storage to its initial state
|
| 124 |
+
4. Handle cleanup of any resources
|
| 125 |
+
5. Notify other processes if necessary
|
| 126 |
+
6. This action should persist the data to disk immediately.
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
dict[str, str]: Operation status and message with the following format:
|
| 130 |
+
{
|
| 131 |
+
"status": str, # "success" or "error"
|
| 132 |
+
"message": str # "data dropped" on success, error details on failure
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
Implementation specific:
|
| 136 |
+
- On success: return {"status": "success", "message": "data dropped"}
|
| 137 |
+
- On failure: return {"status": "error", "message": "<error details>"}
|
| 138 |
+
- If not supported: return {"status": "error", "message": "unsupported"}
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
|
| 142 |
@dataclass
|
| 143 |
class BaseVectorStorage(StorageNameSpace, ABC):
|
|
|
|
| 153 |
|
| 154 |
@abstractmethod
|
| 155 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
| 156 |
+
"""Insert or update vectors in the storage.
|
| 157 |
+
|
| 158 |
+
Important notes for in-memory storage:
|
| 159 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
| 160 |
+
2. Only one process should update the storage at a time before index_done_callback,
|
| 161 |
+
KG-storage-log should be used to avoid data corruption
|
| 162 |
+
"""
|
| 163 |
|
| 164 |
@abstractmethod
|
| 165 |
async def delete_entity(self, entity_name: str) -> None:
|
| 166 |
+
"""Delete a single entity by its name.
|
| 167 |
+
|
| 168 |
+
Important notes for in-memory storage:
|
| 169 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
| 170 |
+
2. Only one process should update the storage at a time before index_done_callback,
|
| 171 |
+
KG-storage-log should be used to avoid data corruption
|
| 172 |
+
"""
|
| 173 |
|
| 174 |
@abstractmethod
|
| 175 |
async def delete_entity_relation(self, entity_name: str) -> None:
|
| 176 |
+
"""Delete relations for a given entity.
|
| 177 |
+
|
| 178 |
+
Important notes for in-memory storage:
|
| 179 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
| 180 |
+
2. Only one process should update the storage at a time before index_done_callback,
|
| 181 |
+
KG-storage-log should be used to avoid data corruption
|
| 182 |
+
"""
|
| 183 |
|
| 184 |
@abstractmethod
|
| 185 |
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
|
|
|
| 205 |
"""
|
| 206 |
pass
|
| 207 |
|
| 208 |
+
@abstractmethod
|
| 209 |
+
async def delete(self, ids: list[str]):
|
| 210 |
+
"""Delete vectors with specified IDs
|
| 211 |
+
|
| 212 |
+
Important notes for in-memory storage:
|
| 213 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
| 214 |
+
2. Only one process should update the storage at a time before index_done_callback,
|
| 215 |
+
KG-storage-log should be used to avoid data corruption
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
ids: List of vector IDs to be deleted
|
| 219 |
+
"""
|
| 220 |
+
|
| 221 |
|
| 222 |
@dataclass
|
| 223 |
class BaseKVStorage(StorageNameSpace, ABC):
|
|
|
|
| 237 |
|
| 238 |
@abstractmethod
|
| 239 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
| 240 |
+
"""Upsert data
|
| 241 |
+
|
| 242 |
+
Important notes for in-memory storage:
|
| 243 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
| 244 |
+
2. update flags to notify other processes that data persistence is needed
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
@abstractmethod
|
| 248 |
+
async def delete(self, ids: list[str]) -> None:
|
| 249 |
+
"""Delete specific records from storage by their IDs
|
| 250 |
+
|
| 251 |
+
Important notes for in-memory storage:
|
| 252 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
| 253 |
+
2. update flags to notify other processes that data persistence is needed
|
| 254 |
+
|
| 255 |
+
Args:
|
| 256 |
+
ids (list[str]): List of document IDs to be deleted from storage
|
| 257 |
+
|
| 258 |
+
Returns:
|
| 259 |
+
None
|
| 260 |
+
"""
|
| 261 |
+
|
| 262 |
+
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
| 263 |
+
"""Delete specific records from storage by cache mode
|
| 264 |
+
|
| 265 |
+
Important notes for in-memory storage:
|
| 266 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
| 267 |
+
2. update flags to notify other processes that data persistence is needed
|
| 268 |
+
|
| 269 |
+
Args:
|
| 270 |
+
modes (list[str]): List of cache modes to be dropped from storage
|
| 271 |
+
|
| 272 |
+
Returns:
|
| 273 |
+
True: if the cache was dropped successfully
|
| 274 |
+
False: if the cache drop failed, or the cache mode is not supported
|
| 275 |
+
"""
|
| 276 |
|
| 277 |
|
| 278 |
@dataclass
|
|
|
|
| 297 |
|
| 298 |
@abstractmethod
|
| 299 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
| 300 |
+
"""Get node by its label identifier, return only node properties"""
|
| 301 |
|
| 302 |
@abstractmethod
|
| 303 |
async def get_edge(
|
| 304 |
self, source_node_id: str, target_node_id: str
|
| 305 |
) -> dict[str, str] | None:
|
| 306 |
+
"""Get edge properties between two nodes"""
|
| 307 |
|
| 308 |
@abstractmethod
|
| 309 |
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
|
|
|
| 317 |
async def upsert_edge(
|
| 318 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 319 |
) -> None:
|
| 320 |
+
"""Delete a node from the graph.
|
| 321 |
+
|
| 322 |
+
Importance notes for in-memory storage:
|
| 323 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
| 324 |
+
2. Only one process should updating the storage at a time before index_done_callback,
|
| 325 |
+
KG-storage-log should be used to avoid data corruption
|
| 326 |
+
"""
|
| 327 |
|
| 328 |
@abstractmethod
|
| 329 |
async def delete_node(self, node_id: str) -> None:
|
|
|
|
| 341 |
|
| 342 |
@abstractmethod
|
| 343 |
async def get_knowledge_graph(
|
| 344 |
+
self, node_label: str, max_depth: int = 3, max_nodes: int = 1000
|
| 345 |
) -> KnowledgeGraph:
|
| 346 |
+
"""
|
| 347 |
+
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
| 348 |
+
|
| 349 |
+
Args:
|
| 350 |
+
node_label: Label of the starting node,* means all nodes
|
| 351 |
+
max_depth: Maximum depth of the subgraph, Defaults to 3
|
| 352 |
+
max_nodes: Maxiumu nodes to return, Defaults to 1000(BFS if possible)
|
| 353 |
+
|
| 354 |
+
Returns:
|
| 355 |
+
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
| 356 |
+
indicating whether the graph was truncated due to max_nodes limit
|
| 357 |
+
"""
|
| 358 |
|
| 359 |
|
| 360 |
class DocStatus(str, Enum):
|
|
|
|
| 406 |
) -> dict[str, DocProcessingStatus]:
|
| 407 |
"""Get all documents with a specific status"""
|
| 408 |
|
| 409 |
+
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
| 410 |
+
"""Drop cache is not supported for Doc Status storage"""
|
| 411 |
+
return False
|
| 412 |
+
|
| 413 |
|
| 414 |
class StoragesStatus(str, Enum):
|
| 415 |
"""Storages status"""
|
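Note (not part of the diff): a minimal, standalone sketch of the delete / drop_cache_by_modes contract that the hunks above add to BaseKVStorage. It deliberately avoids importing lightrag; the class name InMemoryKV and the _dirty flag are illustrative stand-ins for a concrete storage and its cross-process update flag.

import asyncio


class InMemoryKV:
    """Toy storage that mimics the contract documented above."""

    def __init__(self) -> None:
        self._data: dict[str, dict] = {}
        self._dirty = False  # stand-in for the cross-process update flag

    async def upsert(self, data: dict[str, dict]) -> None:
        # Persistence itself would happen later, in index_done_callback
        self._data.update(data)
        self._dirty = True

    async def delete(self, ids: list[str]) -> None:
        for key in ids:
            if self._data.pop(key, None) is not None:
                self._dirty = True

    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
        # Mirrors the JsonKVStorage hunk further down: cache modes are deletable keys
        if not modes:
            return False
        await self.delete(modes)
        return True


async def demo() -> None:
    kv = InMemoryKV()
    await kv.upsert({"doc-1": {"content": "hello"}})
    await kv.delete(["doc-1"])
    print(await kv.drop_cache_by_modes(["local"]))  # True


asyncio.run(demo())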
lightrag/kg/__init__.py
CHANGED
@@ -2,11 +2,10 @@ STORAGE_IMPLEMENTATIONS = {
     "KV_STORAGE": {
         "implementations": [
             "JsonKVStorage",
-            "MongoKVStorage",
             "RedisKVStorage",
-            "TiDBKVStorage",
             "PGKVStorage",
-            "OracleKVStorage",
+            "MongoKVStorage",
+            # "TiDBKVStorage",
         ],
         "required_methods": ["get_by_id", "upsert"],
     },
@@ -14,12 +13,11 @@ STORAGE_IMPLEMENTATIONS = {
         "implementations": [
             "NetworkXStorage",
             "Neo4JStorage",
-            "MongoGraphStorage",
-            "TiDBGraphStorage",
-            "AGEStorage",
-            "GremlinStorage",
             "PGGraphStorage",
-            "OracleGraphStorage",
+            # "AGEStorage",
+            # "MongoGraphStorage",
+            # "TiDBGraphStorage",
+            # "GremlinStorage",
         ],
         "required_methods": ["upsert_node", "upsert_edge"],
     },
@@ -28,12 +26,11 @@ STORAGE_IMPLEMENTATIONS = {
             "NanoVectorDBStorage",
             "MilvusVectorDBStorage",
             "ChromaVectorDBStorage",
-            "TiDBVectorDBStorage",
             "PGVectorStorage",
            "FaissVectorDBStorage",
            "QdrantVectorDBStorage",
-            "OracleVectorDBStorage",
            "MongoVectorDBStorage",
+            # "TiDBVectorDBStorage",
         ],
         "required_methods": ["query", "upsert"],
     },
@@ -41,7 +38,6 @@ STORAGE_IMPLEMENTATIONS = {
         "implementations": [
             "JsonDocStatusStorage",
             "PGDocStatusStorage",
-            "PGDocStatusStorage",
             "MongoDocStatusStorage",
         ],
         "required_methods": ["get_docs_by_status"],
@@ -54,50 +50,32 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
     "JsonKVStorage": [],
     "MongoKVStorage": [],
     "RedisKVStorage": ["REDIS_URI"],
-    "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
+    # "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
     "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
-    "OracleKVStorage": [
-        "ORACLE_DSN",
-        "ORACLE_USER",
-        "ORACLE_PASSWORD",
-        "ORACLE_CONFIG_DIR",
-    ],
     # Graph Storage Implementations
     "NetworkXStorage": [],
     "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
     "MongoGraphStorage": [],
-    "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
+    # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
     "AGEStorage": [
         "AGE_POSTGRES_DB",
         "AGE_POSTGRES_USER",
         "AGE_POSTGRES_PASSWORD",
     ],
-    "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
+    # "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
     "PGGraphStorage": [
         "POSTGRES_USER",
         "POSTGRES_PASSWORD",
         "POSTGRES_DATABASE",
     ],
-    "OracleGraphStorage": [
-        "ORACLE_DSN",
-        "ORACLE_USER",
-        "ORACLE_PASSWORD",
-        "ORACLE_CONFIG_DIR",
-    ],
     # Vector Storage Implementations
     "NanoVectorDBStorage": [],
     "MilvusVectorDBStorage": [],
     "ChromaVectorDBStorage": [],
-    "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
+    # "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
     "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
     "FaissVectorDBStorage": [],
     "QdrantVectorDBStorage": ["QDRANT_URL"],  # QDRANT_API_KEY has default value None
-    "OracleVectorDBStorage": [
-        "ORACLE_DSN",
-        "ORACLE_USER",
-        "ORACLE_PASSWORD",
-        "ORACLE_CONFIG_DIR",
-    ],
     "MongoVectorDBStorage": [],
     # Document Status Storage Implementations
     "JsonDocStatusStorage": [],
@@ -112,9 +90,6 @@ STORAGES = {
     "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
     "JsonDocStatusStorage": ".kg.json_doc_status_impl",
     "Neo4JStorage": ".kg.neo4j_impl",
-    "OracleKVStorage": ".kg.oracle_impl",
-    "OracleGraphStorage": ".kg.oracle_impl",
-    "OracleVectorDBStorage": ".kg.oracle_impl",
     "MilvusVectorDBStorage": ".kg.milvus_impl",
     "MongoKVStorage": ".kg.mongo_impl",
     "MongoDocStatusStorage": ".kg.mongo_impl",
@@ -122,14 +97,14 @@ STORAGES = {
     "MongoVectorDBStorage": ".kg.mongo_impl",
     "RedisKVStorage": ".kg.redis_impl",
     "ChromaVectorDBStorage": ".kg.chroma_impl",
-    "TiDBKVStorage": ".kg.tidb_impl",
-    "TiDBVectorDBStorage": ".kg.tidb_impl",
-    "TiDBGraphStorage": ".kg.tidb_impl",
+    # "TiDBKVStorage": ".kg.tidb_impl",
+    # "TiDBVectorDBStorage": ".kg.tidb_impl",
+    # "TiDBGraphStorage": ".kg.tidb_impl",
     "PGKVStorage": ".kg.postgres_impl",
     "PGVectorStorage": ".kg.postgres_impl",
     "AGEStorage": ".kg.age_impl",
     "PGGraphStorage": ".kg.postgres_impl",
-    "GremlinStorage": ".kg.gremlin_impl",
+    # "GremlinStorage": ".kg.gremlin_impl",
     "PGDocStatusStorage": ".kg.postgres_impl",
     "FaissVectorDBStorage": ".kg.faiss_impl",
     "QdrantVectorDBStorage": ".kg.qdrant_impl",
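A small usage sketch for the registry above (not part of the diff): before selecting a backend by name, check that it is still listed and that its required environment variables are set. It assumes STORAGE_IMPLEMENTATIONS and STORAGE_ENV_REQUIREMENTS are importable from lightrag.kg, where this diff defines them; the backend name is only an example.

import os

from lightrag.kg import STORAGE_ENV_REQUIREMENTS, STORAGE_IMPLEMENTATIONS

backend = "PGKVStorage"  # example KV_STORAGE choice

# The name must still be listed as a supported implementation...
assert backend in STORAGE_IMPLEMENTATIONS["KV_STORAGE"]["implementations"]

# ...and every environment variable it requires must be present.
missing = [v for v in STORAGE_ENV_REQUIREMENTS.get(backend, []) if not os.getenv(v)]
if missing:
    raise RuntimeError(f"{backend} needs env vars: {', '.join(missing)}")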
lightrag/kg/age_impl.py
CHANGED
@@ -34,9 +34,9 @@ if not pm.is_installed("psycopg-pool"):
 if not pm.is_installed("asyncpg"):
     pm.install("asyncpg")

-import psycopg
-from psycopg.rows import namedtuple_row
-from psycopg_pool import AsyncConnectionPool, PoolTimeout
+import psycopg  # type: ignore
+from psycopg.rows import namedtuple_row  # type: ignore
+from psycopg_pool import AsyncConnectionPool, PoolTimeout  # type: ignore


 class AGEQueryException(Exception):
@@ -871,3 +871,21 @@ class AGEStorage(BaseGraphStorage):
     async def index_done_callback(self) -> None:
         # AGES handles persistence automatically
         pass
+
+    async def drop(self) -> dict[str, str]:
+        """Drop the storage by removing all nodes and relationships in the graph.
+
+        Returns:
+            dict[str, str]: Status of the operation with keys 'status' and 'message'
+        """
+        try:
+            query = """
+            MATCH (n)
+            DETACH DELETE n
+            """
+            await self._query(query)
+            logger.info(f"Successfully dropped all data from graph {self.graph_name}")
+            return {"status": "success", "message": "graph data dropped"}
+        except Exception as e:
+            logger.error(f"Error dropping graph {self.graph_name}: {e}")
+            return {"status": "error", "message": str(e)}
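Caller-side sketch of the drop() contract introduced here and mirrored by the other backends below (not part of the diff): failures come back as a status dict rather than an exception. The graph_storage argument stands for any configured AGEStorage-like instance; construction is omitted.

async def reset_graph(graph_storage) -> None:
    # drop() reports failure through the returned dict instead of raising
    result = await graph_storage.drop()
    if result.get("status") != "success":
        raise RuntimeError(f"drop failed: {result.get('message')}")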
lightrag/kg/chroma_impl.py
CHANGED
@@ -1,4 +1,5 @@
 import asyncio
+import os
 from dataclasses import dataclass
 from typing import Any, final
 import numpy as np
@@ -10,8 +11,8 @@ import pipmaster as pm
 if not pm.is_installed("chromadb"):
     pm.install("chromadb")

-from chromadb import HttpClient, PersistentClient
-from chromadb.config import Settings
+from chromadb import HttpClient, PersistentClient  # type: ignore
+from chromadb.config import Settings  # type: ignore


 @final
@@ -335,3 +336,28 @@ class ChromaVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
             return []
+
+    async def drop(self) -> dict[str, str]:
+        """Drop all vector data from storage and clean up resources
+
+        This method will delete all documents from the ChromaDB collection.
+
+        Returns:
+            dict[str, str]: Operation status and message
+            - On success: {"status": "success", "message": "data dropped"}
+            - On failure: {"status": "error", "message": "<error details>"}
+        """
+        try:
+            # Get all IDs in the collection
+            result = self._collection.get(include=[])
+            if result and result["ids"] and len(result["ids"]) > 0:
+                # Delete all documents
+                self._collection.delete(ids=result["ids"])
+
+            logger.info(
+                f"Process {os.getpid()} drop ChromaDB collection {self.namespace}"
+            )
+            return {"status": "success", "message": "data dropped"}
+        except Exception as e:
+            logger.error(f"Error dropping ChromaDB collection {self.namespace}: {e}")
+            return {"status": "error", "message": str(e)}
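A standalone illustration of the `get(include=[])` idiom used in the Chroma drop() above (not part of the diff): an empty include list fetches only the IDs, which are then passed to a bulk delete. Client, path, and collection names are hypothetical, and this assumes a local PersistentClient.

import chromadb

client = chromadb.PersistentClient(path="./chroma_demo")
col = client.get_or_create_collection("demo")
col.add(ids=["a", "b"], documents=["doc a", "doc b"])

only_ids = col.get(include=[])   # "ids" are always returned, payloads are skipped
col.delete(ids=only_ids["ids"])  # bulk-delete everything that was fetched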
lightrag/kg/faiss_impl.py
CHANGED
@@ -11,16 +11,20 @@ import pipmaster as pm
 from lightrag.utils import logger, compute_mdhash_id
 from lightrag.base import BaseVectorStorage

-if not pm.is_installed("faiss"):
-    pm.install("faiss")
-
-import faiss  # type: ignore
 from .shared_storage import (
     get_storage_lock,
     get_update_flag,
     set_all_update_flags,
 )

+import faiss  # type: ignore
+
+USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1"
+FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu"
+
+if not pm.is_installed(FAISS_PACKAGE):
+    pm.install(FAISS_PACKAGE)
+

 @final
 @dataclass
@@ -217,6 +221,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
     async def delete(self, ids: list[str]):
         """
         Delete vectors for the provided custom IDs.
+
+        Importance notes:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. Only one process should updating the storage at a time before index_done_callback,
+           KG-storage-log should be used to avoid data corruption
         """
         logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
         to_remove = []
@@ -232,13 +241,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
         )

     async def delete_entity(self, entity_name: str) -> None:
+        """
+        Importance notes:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. Only one process should updating the storage at a time before index_done_callback,
+           KG-storage-log should be used to avoid data corruption
+        """
         entity_id = compute_mdhash_id(entity_name, prefix="ent-")
         logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
         await self.delete([entity_id])

     async def delete_entity_relation(self, entity_name: str) -> None:
         """
-
+        Importance notes:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. Only one process should updating the storage at a time before index_done_callback,
+           KG-storage-log should be used to avoid data corruption
         """
         logger.debug(f"Searching relations for entity {entity_name}")
         relations = []
@@ -429,3 +447,44 @@ class FaissVectorDBStorage(BaseVectorStorage):
             results.append({**metadata, "id": metadata.get("__id__")})

         return results
+
+    async def drop(self) -> dict[str, str]:
+        """Drop all vector data from storage and clean up resources
+
+        This method will:
+        1. Remove the vector database storage file if it exists
+        2. Reinitialize the vector database client
+        3. Update flags to notify other processes
+        4. Changes is persisted to disk immediately
+
+        This method will remove all vectors from the Faiss index and delete the storage files.
+
+        Returns:
+            dict[str, str]: Operation status and message
+            - On success: {"status": "success", "message": "data dropped"}
+            - On failure: {"status": "error", "message": "<error details>"}
+        """
+        try:
+            async with self._storage_lock:
+                # Reset the index
+                self._index = faiss.IndexFlatIP(self._dim)
+                self._id_to_meta = {}
+
+                # Remove storage files if they exist
+                if os.path.exists(self._faiss_index_file):
+                    os.remove(self._faiss_index_file)
+                if os.path.exists(self._meta_file):
+                    os.remove(self._meta_file)
+
+                self._id_to_meta = {}
+                self._load_faiss_index()
+
+                # Notify other processes
+                await set_all_update_flags(self.namespace)
+                self.storage_updated.value = False
+
+            logger.info(f"Process {os.getpid()} drop FAISS index {self.namespace}")
+            return {"status": "success", "message": "data dropped"}
+        except Exception as e:
+            logger.error(f"Error dropping FAISS index {self.namespace}: {e}")
+            return {"status": "error", "message": str(e)}
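The first hunk above switches FAISS installation to an environment-driven package choice. A tiny standalone sketch of the same selection logic (not part of the diff); nothing here touches the real backend.

import os

use_gpu = os.getenv("FAISS_USE_GPU", "0") == "1"
faiss_package = "faiss-gpu" if use_gpu else "faiss-cpu"
print(f"pipmaster would be asked to install: {faiss_package}")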
lightrag/kg/gremlin_impl.py
CHANGED
@@ -24,9 +24,9 @@ from ..base import BaseGraphStorage
 if not pm.is_installed("gremlinpython"):
     pm.install("gremlinpython")

-from gremlin_python.driver import client, serializer
-from gremlin_python.driver.aiohttp.transport import AiohttpTransport
-from gremlin_python.driver.protocol import GremlinServerError
+from gremlin_python.driver import client, serializer  # type: ignore
+from gremlin_python.driver.aiohttp.transport import AiohttpTransport  # type: ignore
+from gremlin_python.driver.protocol import GremlinServerError  # type: ignore


 @final
@@ -695,3 +695,24 @@ class GremlinStorage(BaseGraphStorage):
         except Exception as e:
             logger.error(f"Error during edge deletion: {str(e)}")
             raise
+
+    async def drop(self) -> dict[str, str]:
+        """Drop the storage by removing all nodes and relationships in the graph.
+
+        This function deletes all nodes with the specified graph name property,
+        which automatically removes all associated edges.
+
+        Returns:
+            dict[str, str]: Status of the operation with keys 'status' and 'message'
+        """
+        try:
+            query = f"""g
+            .V().has('graph', {self.graph_name})
+            .drop()
+            """
+            await self._query(query)
+            logger.info(f"Successfully dropped all data from graph {self.graph_name}")
+            return {"status": "success", "message": "graph data dropped"}
+        except Exception as e:
+            logger.error(f"Error dropping graph {self.graph_name}: {e}")
+            return {"status": "error", "message": str(e)}
lightrag/kg/json_doc_status_impl.py
CHANGED
@@ -109,6 +109,11 @@ class JsonDocStatusStorage(DocStatusStorage):
         await clear_all_update_flags(self.namespace)

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        """
+        Importance notes for in-memory storage:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. update flags to notify other processes that data persistence is needed
+        """
         if not data:
             return
         logger.info(f"Inserting {len(data)} records to {self.namespace}")
@@ -122,16 +127,50 @@ class JsonDocStatusStorage(DocStatusStorage):
         async with self._storage_lock:
             return self._data.get(id)

-    async def delete(self, doc_ids: list[str]):
+    async def delete(self, doc_ids: list[str]) -> None:
+        """Delete specific records from storage by their IDs
+
+        Importance notes for in-memory storage:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. update flags to notify other processes that data persistence is needed
+
+        Args:
+            ids (list[str]): List of document IDs to be deleted from storage
+
+        Returns:
+            None
+        """
         async with self._storage_lock:
+            any_deleted = False
             for doc_id in doc_ids:
-                self._data.pop(doc_id, None)
+                result = self._data.pop(doc_id, None)
+                if result is not None:
+                    any_deleted = True

+            if any_deleted:
+                await set_all_update_flags(self.namespace)
+
+    async def drop(self) -> dict[str, str]:
+        """Drop all document status data from storage and clean up resources
+
+        This method will:
+        1. Clear all document status data from memory
+        2. Update flags to notify other processes
+        3. Trigger index_done_callback to save the empty state
+
+        Returns:
+            dict[str, str]: Operation status and message
+            - On success: {"status": "success", "message": "data dropped"}
+            - On failure: {"status": "error", "message": "<error details>"}
+        """
+        try:
+            async with self._storage_lock:
+                self._data.clear()
+                await set_all_update_flags(self.namespace)
+
+            await self.index_done_callback()
+            logger.info(f"Process {os.getpid()} drop {self.namespace}")
+            return {"status": "success", "message": "data dropped"}
+        except Exception as e:
+            logger.error(f"Error dropping {self.namespace}: {e}")
+            return {"status": "error", "message": str(e)}
lightrag/kg/json_kv_impl.py
CHANGED
@@ -114,6 +114,11 @@ class JsonKVStorage(BaseKVStorage):
         return set(keys) - set(self._data.keys())

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        """
+        Importance notes for in-memory storage:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. update flags to notify other processes that data persistence is needed
+        """
         if not data:
             return
         logger.info(f"Inserting {len(data)} records to {self.namespace}")
@@ -122,8 +127,73 @@ class JsonKVStorage(BaseKVStorage):
         await set_all_update_flags(self.namespace)

     async def delete(self, ids: list[str]) -> None:
+        """Delete specific records from storage by their IDs
+
+        Importance notes for in-memory storage:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. update flags to notify other processes that data persistence is needed
+
+        Args:
+            ids (list[str]): List of document IDs to be deleted from storage
+
+        Returns:
+            None
+        """
         async with self._storage_lock:
+            any_deleted = False
             for doc_id in ids:
-                self._data.pop(doc_id, None)
+                result = self._data.pop(doc_id, None)
+                if result is not None:
+                    any_deleted = True
+
+            if any_deleted:
+                await set_all_update_flags(self.namespace)
+
+    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
+        """Delete specific records from storage by by cache mode
+
+        Importance notes for in-memory storage:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. update flags to notify other processes that data persistence is needed
+
+        Args:
+            ids (list[str]): List of cache mode to be drop from storage
+
+        Returns:
+            True: if the cache drop successfully
+            False: if the cache drop failed
+        """
+        if not modes:
+            return False
+
+        try:
+            await self.delete(modes)
+            return True
+        except Exception:
+            return False
+
+    async def drop(self) -> dict[str, str]:
+        """Drop all data from storage and clean up resources
+        This action will persistent the data to disk immediately.
+
+        This method will:
+        1. Clear all data from memory
+        2. Update flags to notify other processes
+        3. Trigger index_done_callback to save the empty state
+
+        Returns:
+            dict[str, str]: Operation status and message
+            - On success: {"status": "success", "message": "data dropped"}
+            - On failure: {"status": "error", "message": "<error details>"}
+        """
+        try:
+            async with self._storage_lock:
+                self._data.clear()
+                await set_all_update_flags(self.namespace)
+
+            await self.index_done_callback()
+            logger.info(f"Process {os.getpid()} drop {self.namespace}")
+            return {"status": "success", "message": "data dropped"}
+        except Exception as e:
+            logger.error(f"Error dropping {self.namespace}: {e}")
+            return {"status": "error", "message": str(e)}
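Why JsonKVStorage can implement drop_cache_by_modes() as a plain delete(modes): in the JSON-backed LLM-response cache each query mode is a top-level key. The cache shape below is an assumed illustration (not copied from the codebase) and is not part of the diff.

llm_response_cache = {
    "local": {"args-hash-1": {"return": "...", "original_prompt": "..."}},
    "global": {"args-hash-2": {"return": "...", "original_prompt": "..."}},
}

for mode in ["local"]:
    llm_response_cache.pop(mode, None)  # what delete(modes) amounts to here

print(list(llm_response_cache))  # ['global']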
lightrag/kg/milvus_impl.py
CHANGED
@@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
     pm.install("pymilvus")

 import configparser
-from pymilvus import MilvusClient
+from pymilvus import MilvusClient  # type: ignore

 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")
@@ -287,3 +287,33 @@ class MilvusVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
             return []
+
+    async def drop(self) -> dict[str, str]:
+        """Drop all vector data from storage and clean up resources
+
+        This method will delete all data from the Milvus collection.
+
+        Returns:
+            dict[str, str]: Operation status and message
+            - On success: {"status": "success", "message": "data dropped"}
+            - On failure: {"status": "error", "message": "<error details>"}
+        """
+        try:
+            # Drop the collection and recreate it
+            if self._client.has_collection(self.namespace):
+                self._client.drop_collection(self.namespace)
+
+            # Recreate the collection
+            MilvusVectorDBStorage.create_collection_if_not_exist(
+                self._client,
+                self.namespace,
+                dimension=self.embedding_func.embedding_dim,
+            )
+
+            logger.info(
+                f"Process {os.getpid()} drop Milvus collection {self.namespace}"
+            )
+            return {"status": "success", "message": "data dropped"}
+        except Exception as e:
+            logger.error(f"Error dropping Milvus collection {self.namespace}: {e}")
+            return {"status": "error", "message": str(e)}
lightrag/kg/mongo_impl.py
CHANGED
@@ -25,13 +25,13 @@ if not pm.is_installed("pymongo"):
 if not pm.is_installed("motor"):
     pm.install("motor")

-from motor.motor_asyncio import (
+from motor.motor_asyncio import (  # type: ignore
     AsyncIOMotorClient,
     AsyncIOMotorDatabase,
     AsyncIOMotorCollection,
 )
-from pymongo.operations import SearchIndexModel
-from pymongo.errors import PyMongoError
+from pymongo.operations import SearchIndexModel  # type: ignore
+from pymongo.errors import PyMongoError  # type: ignore

 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")
@@ -150,6 +150,66 @@ class MongoKVStorage(BaseKVStorage):
         # Mongo handles persistence automatically
         pass

+    async def delete(self, ids: list[str]) -> None:
+        """Delete documents with specified IDs
+
+        Args:
+            ids: List of document IDs to be deleted
+        """
+        if not ids:
+            return
+
+        try:
+            result = await self._data.delete_many({"_id": {"$in": ids}})
+            logger.info(
+                f"Deleted {result.deleted_count} documents from {self.namespace}"
+            )
+        except PyMongoError as e:
+            logger.error(f"Error deleting documents from {self.namespace}: {e}")
+
+    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
+        """Delete specific records from storage by cache mode
+
+        Args:
+            modes (list[str]): List of cache modes to be dropped from storage
+
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        if not modes:
+            return False
+
+        try:
+            # Build regex pattern to match documents with the specified modes
+            pattern = f"^({'|'.join(modes)})_"
+            result = await self._data.delete_many({"_id": {"$regex": pattern}})
+            logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
+            return True
+        except Exception as e:
+            logger.error(f"Error deleting cache by modes {modes}: {e}")
+            return False
+
+    async def drop(self) -> dict[str, str]:
+        """Drop the storage by removing all documents in the collection.
+
+        Returns:
+            dict[str, str]: Status of the operation with keys 'status' and 'message'
+        """
+        try:
+            result = await self._data.delete_many({})
+            deleted_count = result.deleted_count
+
+            logger.info(
+                f"Dropped {deleted_count} documents from doc status {self._collection_name}"
+            )
+            return {
+                "status": "success",
+                "message": f"{deleted_count} documents dropped",
+            }
+        except PyMongoError as e:
+            logger.error(f"Error dropping doc status {self._collection_name}: {e}")
+            return {"status": "error", "message": str(e)}
+

 @final
 @dataclass
@@ -230,6 +290,27 @@ class MongoDocStatusStorage(DocStatusStorage):
         # Mongo handles persistence automatically
         pass

+    async def drop(self) -> dict[str, str]:
+        """Drop the storage by removing all documents in the collection.
+
+        Returns:
+            dict[str, str]: Status of the operation with keys 'status' and 'message'
+        """
+        try:
+            result = await self._data.delete_many({})
+            deleted_count = result.deleted_count
+
+            logger.info(
+                f"Dropped {deleted_count} documents from doc status {self._collection_name}"
+            )
+            return {
+                "status": "success",
+                "message": f"{deleted_count} documents dropped",
+            }
+        except PyMongoError as e:
+            logger.error(f"Error dropping doc status {self._collection_name}: {e}")
+            return {"status": "error", "message": str(e)}
+

 @final
 @dataclass
@@ -840,6 +921,27 @@ class MongoGraphStorage(BaseGraphStorage):

         logger.debug(f"Successfully deleted edges: {edges}")

+    async def drop(self) -> dict[str, str]:
+        """Drop the storage by removing all documents in the collection.
+
+        Returns:
+            dict[str, str]: Status of the operation with keys 'status' and 'message'
+        """
+        try:
+            result = await self.collection.delete_many({})
+            deleted_count = result.deleted_count
+
+            logger.info(
+                f"Dropped {deleted_count} documents from graph {self._collection_name}"
+            )
+            return {
+                "status": "success",
+                "message": f"{deleted_count} documents dropped",
+            }
+        except PyMongoError as e:
+            logger.error(f"Error dropping graph {self._collection_name}: {e}")
+            return {"status": "error", "message": str(e)}
+

 @final
 @dataclass
@@ -1127,6 +1229,31 @@ class MongoVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
             return []

+    async def drop(self) -> dict[str, str]:
+        """Drop the storage by removing all documents in the collection and recreating vector index.
+
+        Returns:
+            dict[str, str]: Status of the operation with keys 'status' and 'message'
+        """
+        try:
+            # Delete all documents
+            result = await self._data.delete_many({})
+            deleted_count = result.deleted_count
+
+            # Recreate vector index
+            await self.create_vector_index_if_not_exists()
+
+            logger.info(
+                f"Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
+            )
+            return {
+                "status": "success",
+                "message": f"{deleted_count} documents dropped and vector index recreated",
+            }
+        except PyMongoError as e:
+            logger.error(f"Error dropping vector storage {self._collection_name}: {e}")
+            return {"status": "error", "message": str(e)}
+

 async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
     collection_names = await db.list_collection_names()
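A quick standalone check of the `^(mode1|mode2)_` prefix regex that the Mongo drop_cache_by_modes() above builds (not part of the diff); the _id values here are hypothetical.

import re

modes = ["local", "global"]
pattern = f"^({'|'.join(modes)})_"

ids = ["local_abc123", "global_def456", "hybrid_xyz789"]
print([i for i in ids if re.match(pattern, i)])  # ['local_abc123', 'global_def456']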
lightrag/kg/nano_vector_db_impl.py
CHANGED
@@ -78,6 +78,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
         return self._client

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        """
+        Importance notes:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. Only one process should updating the storage at a time before index_done_callback,
+           KG-storage-log should be used to avoid data corruption
+        """
+
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
@@ -146,6 +153,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
     async def delete(self, ids: list[str]):
         """Delete vectors with specified IDs

+        Importance notes:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. Only one process should updating the storage at a time before index_done_callback,
+           KG-storage-log should be used to avoid data corruption
+
         Args:
             ids: List of vector IDs to be deleted
         """
@@ -159,6 +171,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error while deleting vectors from {self.namespace}: {e}")

     async def delete_entity(self, entity_name: str) -> None:
+        """
+        Importance notes:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. Only one process should updating the storage at a time before index_done_callback,
+           KG-storage-log should be used to avoid data corruption
+        """
+
         try:
             entity_id = compute_mdhash_id(entity_name, prefix="ent-")
             logger.debug(
@@ -176,6 +195,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
             logger.error(f"Error deleting entity {entity_name}: {e}")

     async def delete_entity_relation(self, entity_name: str) -> None:
+        """
+        Importance notes:
+        1. Changes will be persisted to disk during the next index_done_callback
+        2. Only one process should updating the storage at a time before index_done_callback,
+           KG-storage-log should be used to avoid data corruption
+        """
+
         try:
             client = await self._get_client()
             storage = getattr(client, "_NanoVectorDB__storage")
@@ -280,3 +306,43 @@ class NanoVectorDBStorage(BaseVectorStorage):

         client = await self._get_client()
         return client.get(ids)
+
+    async def drop(self) -> dict[str, str]:
+        """Drop all vector data from storage and clean up resources
+
+        This method will:
+        1. Remove the vector database storage file if it exists
+        2. Reinitialize the vector database client
+        3. Update flags to notify other processes
+        4. Changes is persisted to disk immediately
+
+        This method is intended for use in scenarios where all data needs to be removed,
+
+        Returns:
+            dict[str, str]: Operation status and message
+            - On success: {"status": "success", "message": "data dropped"}
+            - On failure: {"status": "error", "message": "<error details>"}
+        """
+        try:
+            async with self._storage_lock:
+                # delete _client_file_name
+                if os.path.exists(self._client_file_name):
+                    os.remove(self._client_file_name)
+
+                self._client = NanoVectorDB(
+                    self.embedding_func.embedding_dim,
+                    storage_file=self._client_file_name,
+                )
+
+                # Notify other processes that data has been updated
+                await set_all_update_flags(self.namespace)
+                # Reset own update flag to avoid self-reloading
+                self.storage_updated.value = False
+
+            logger.info(
+                f"Process {os.getpid()} drop {self.namespace}(file:{self._client_file_name})"
+            )
+            return {"status": "success", "message": "data dropped"}
+        except Exception as e:
+            logger.error(f"Error dropping {self.namespace}: {e}")
+            return {"status": "error", "message": str(e)}
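Both the NanoVectorDB and Faiss backends delete an entity's vectors by recomputing a deterministic ID from the entity name. A standalone sketch of that mapping (not part of the diff); the md5-based body is a plausible re-implementation of compute_mdhash_id, whose real version lives in lightrag.utils, shown only to make the prefix-plus-hash shape concrete.

from hashlib import md5


def compute_mdhash_id(content: str, prefix: str = "") -> str:
    # assumed shape: "<prefix>" + md5 hex digest of the content
    return prefix + md5(content.encode()).hexdigest()


entity_id = compute_mdhash_id("Alan Turing", prefix="ent-")
print(entity_id)  # the ID that delete_entity() passes to delete([entity_id])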
lightrag/kg/neo4j_impl.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
| 1 |
-
import asyncio
|
| 2 |
import inspect
|
| 3 |
import os
|
| 4 |
import re
|
| 5 |
from dataclasses import dataclass
|
| 6 |
-
from typing import Any, final
|
| 7 |
import numpy as np
|
| 8 |
import configparser
|
| 9 |
|
|
@@ -29,7 +28,6 @@ from neo4j import ( # type: ignore
|
|
| 29 |
exceptions as neo4jExceptions,
|
| 30 |
AsyncDriver,
|
| 31 |
AsyncManagedTransaction,
|
| 32 |
-
GraphDatabase,
|
| 33 |
)
|
| 34 |
|
| 35 |
config = configparser.ConfigParser()
|
|
@@ -52,8 +50,13 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 52 |
embedding_func=embedding_func,
|
| 53 |
)
|
| 54 |
self._driver = None
|
| 55 |
-
self._driver_lock = asyncio.Lock()
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
|
| 58 |
USERNAME = os.environ.get(
|
| 59 |
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
|
|
@@ -86,7 +89,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 86 |
),
|
| 87 |
)
|
| 88 |
DATABASE = os.environ.get(
|
| 89 |
-
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
|
| 90 |
)
|
| 91 |
|
| 92 |
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
|
@@ -98,71 +101,92 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 98 |
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
|
| 99 |
)
|
| 100 |
|
| 101 |
-
# Try to connect to the database
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
|
| 106 |
-
connection_timeout=CONNECTION_TIMEOUT,
|
| 107 |
-
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
|
| 108 |
-
) as _sync_driver:
|
| 109 |
-
for database in (DATABASE, None):
|
| 110 |
-
self._DATABASE = database
|
| 111 |
-
connected = False
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
raise e
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
)
|
|
|
|
|
|
|
|
|
|
| 132 |
try:
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
)
|
| 137 |
-
logger.info(f"{database} at {URI} created".capitalize())
|
| 138 |
-
connected = True
|
| 139 |
-
except (
|
| 140 |
-
neo4jExceptions.ClientError,
|
| 141 |
-
neo4jExceptions.DatabaseError,
|
| 142 |
-
) as e:
|
| 143 |
-
if (
|
| 144 |
-
e.code
|
| 145 |
-
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
|
| 146 |
-
) or (
|
| 147 |
-
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
|
| 148 |
-
):
|
| 149 |
-
if database is not None:
|
| 150 |
-
logger.warning(
|
| 151 |
-
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
|
| 152 |
-
)
|
| 153 |
-
if database is None:
|
| 154 |
-
logger.error(f"Failed to create {database} at {URI}")
|
| 155 |
-
raise e
|
| 156 |
|
| 157 |
-
|
| 158 |
-
break
|
| 159 |
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
-
async def
|
| 166 |
"""Close the Neo4j driver and release all resources"""
|
| 167 |
if self._driver:
|
| 168 |
await self._driver.close()
|
|
@@ -170,7 +194,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 170 |
|
| 171 |
async def __aexit__(self, exc_type, exc, tb):
|
| 172 |
"""Ensure driver is closed when context manager exits"""
|
| 173 |
-
await self.
|
| 174 |
|
| 175 |
async def index_done_callback(self) -> None:
|
| 176 |
# Noe4J handles persistence automatically
|
|
@@ -243,7 +267,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 243 |
raise
|
| 244 |
|
| 245 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
| 246 |
-
"""Get node by its label identifier
|
| 247 |
|
| 248 |
Args:
|
| 249 |
node_id: The node label to look up
|
|
@@ -428,13 +452,8 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 428 |
logger.debug(
|
| 429 |
f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
|
| 430 |
)
|
| 431 |
-
# Return
|
| 432 |
-
return
|
| 433 |
-
"weight": 0.0,
|
| 434 |
-
"source_id": None,
|
| 435 |
-
"description": None,
|
| 436 |
-
"keywords": None,
|
| 437 |
-
}
|
| 438 |
finally:
|
| 439 |
await result.consume() # Ensure result is fully consumed
|
| 440 |
|
|
@@ -526,7 +545,6 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 526 |
"""
|
| 527 |
properties = node_data
|
| 528 |
entity_type = properties["entity_type"]
|
| 529 |
-
entity_id = properties["entity_id"]
|
| 530 |
if "entity_id" not in properties:
|
| 531 |
raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
|
| 532 |
|
|
@@ -536,15 +554,17 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 536 |
async def execute_upsert(tx: AsyncManagedTransaction):
|
| 537 |
query = (
|
| 538 |
"""
|
| 539 |
-
MERGE (n:base {entity_id: $
|
| 540 |
SET n += $properties
|
| 541 |
SET n:`%s`
|
| 542 |
"""
|
| 543 |
% entity_type
|
| 544 |
)
|
| 545 |
-
result = await tx.run(
|
|
|
|
|
|
|
| 546 |
logger.debug(
|
| 547 |
-
f"Upserted node with entity_id '{
|
| 548 |
)
|
| 549 |
await result.consume() # Ensure result is fully consumed
|
| 550 |
|
|
@@ -622,25 +642,19 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 622 |
self,
|
| 623 |
node_label: str,
|
| 624 |
max_depth: int = 3,
|
| 625 |
-
|
| 626 |
-
inclusive: bool = False,
|
| 627 |
) -> KnowledgeGraph:
|
| 628 |
"""
|
| 629 |
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
| 630 |
-
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
| 631 |
-
When reducing the number of nodes, the prioritization criteria are as follows:
|
| 632 |
-
1. min_degree does not affect nodes directly connected to the matching nodes
|
| 633 |
-
2. Label matching nodes take precedence
|
| 634 |
-
3. Followed by nodes directly connected to the matching nodes
|
| 635 |
-
4. Finally, the degree of the nodes
|
| 636 |
|
| 637 |
Args:
|
| 638 |
-
node_label: Label of the starting node
|
| 639 |
-
max_depth: Maximum depth of the subgraph
|
| 640 |
-
|
| 641 |
-
|
| 642 |
Returns:
|
| 643 |
-
KnowledgeGraph
|
|
|
|
| 644 |
"""
|
| 645 |
result = KnowledgeGraph()
|
| 646 |
seen_nodes = set()
|
|
@@ -651,11 +665,27 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 651 |
) as session:
|
| 652 |
try:
|
| 653 |
if node_label == "*":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 654 |
main_query = """
|
| 655 |
MATCH (n)
|
| 656 |
OPTIONAL MATCH (n)-[r]-()
|
| 657 |
WITH n, COALESCE(count(r), 0) AS degree
|
| 658 |
-
WHERE degree >= $min_degree
|
| 659 |
ORDER BY degree DESC
|
| 660 |
LIMIT $max_nodes
|
| 661 |
WITH collect({node: n}) AS filtered_nodes
|
|
@@ -666,20 +696,23 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 666 |
RETURN filtered_nodes AS node_info,
|
| 667 |
collect(DISTINCT r) AS relationships
|
| 668 |
"""
|
| 669 |
-
result_set =
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 673 |
|
| 674 |
else:
|
| 675 |
-
#
|
| 676 |
-
|
|
|
|
| 677 |
MATCH (start)
|
| 678 |
-
WHERE
|
| 679 |
-
CASE
|
| 680 |
-
WHEN $inclusive THEN start.entity_id CONTAINS $entity_id
|
| 681 |
-
ELSE start.entity_id = $entity_id
|
| 682 |
-
END
|
| 683 |
WITH start
|
| 684 |
CALL apoc.path.subgraphAll(start, {
|
| 685 |
relationshipFilter: '',
|
|
@@ -688,78 +721,115 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 688 |
bfs: true
|
| 689 |
})
|
| 690 |
YIELD nodes, relationships
|
| 691 |
-
WITH
|
| 692 |
UNWIND nodes AS node
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
|
| 696 |
-
ORDER BY
|
| 697 |
-
CASE
|
| 698 |
-
WHEN node = start THEN 3
|
| 699 |
-
WHEN EXISTS((start)--(node)) THEN 2
|
| 700 |
-
ELSE 1
|
| 701 |
-
END DESC,
|
| 702 |
-
degree DESC
|
| 703 |
-
LIMIT $max_nodes
|
| 704 |
-
WITH collect({node: node}) AS filtered_nodes
|
| 705 |
-
UNWIND filtered_nodes AS node_info
|
| 706 |
-
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
|
| 707 |
-
OPTIONAL MATCH (a)-[r]-(b)
|
| 708 |
-
WHERE a IN kept_nodes AND b IN kept_nodes
|
| 709 |
-
RETURN filtered_nodes AS node_info,
|
| 710 |
-
collect(DISTINCT r) AS relationships
|
| 711 |
"""
|
| 712 |
-
result_set = await session.run(
|
| 713 |
-
main_query,
|
| 714 |
-
{
|
| 715 |
-
"max_nodes": MAX_GRAPH_NODES,
|
| 716 |
-
"entity_id": node_label,
|
| 717 |
-
"inclusive": inclusive,
|
| 718 |
-
"max_depth": max_depth,
|
| 719 |
-
"min_degree": min_degree,
|
| 720 |
-
},
|
| 721 |
-
)
|
| 722 |
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
|
|
|
|
| 738 |
)
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
|
|
|
|
| 755 |
)
|
| 756 |
-
|
|
|
|
| 757 |
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
finally:
|
| 762 |
-
await result_set.consume() # Ensure result set is consumed
|
| 763 |
|
| 764 |
except neo4jExceptions.ClientError as e:
|
| 765 |
logger.warning(f"APOC plugin error: {str(e)}")
|
|
@@ -767,46 +837,89 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 767 |
logger.warning(
|
| 768 |
"Neo4j: falling back to basic Cypher recursive search..."
|
| 769 |
)
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
return await self._robust_fallback(
|
| 775 |
-
node_label, max_depth, min_degree
|
| 776 |
)
|
| 777 |
|
| 778 |
return result
|
| 779 |
|
| 780 |
async def _robust_fallback(
|
| 781 |
-
self, node_label: str, max_depth: int,
|
| 782 |
) -> KnowledgeGraph:
|
| 783 |
"""
|
| 784 |
Fallback implementation when APOC plugin is not available or incompatible.
|
| 785 |
This method implements the same functionality as get_knowledge_graph but uses
|
| 786 |
-
only basic Cypher queries and
|
| 787 |
"""
|
|
|
|
|
|
|
| 788 |
result = KnowledgeGraph()
|
| 789 |
visited_nodes = set()
|
| 790 |
visited_edges = set()
|
|
|
|
| 791 |
|
| 792 |
-
|
| 793 |
-
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
|
|
|
|
|
| 804 |
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
|
| 808 |
|
| 809 |
-
|
|
|
|
|
| 810 |
async with self._driver.session(
|
| 811 |
database=self._DATABASE, default_access_mode="READ"
|
| 812 |
) as session:
|
|
@@ -815,32 +928,17 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 815 |
WITH r, b, id(r) as edge_id, id(b) as target_id
|
| 816 |
RETURN r, b, edge_id, target_id
|
| 817 |
"""
|
| 818 |
-
results = await session.run(query, entity_id=
|
| 819 |
|
| 820 |
# Get all records and release database connection
|
| 821 |
-
records = await results.fetch(
|
| 822 |
-
1000
|
| 823 |
-
) # Max neighbour nodes we can handled
|
| 824 |
await results.consume() # Ensure results are consumed
|
| 825 |
|
| 826 |
-
#
|
| 827 |
-
if current_depth > 1 and len(records) < min_degree:
|
| 828 |
-
return
|
| 829 |
-
|
| 830 |
-
# Add current node to result
|
| 831 |
-
result.nodes.append(node)
|
| 832 |
-
visited_nodes.add(node.id)
|
| 833 |
-
|
| 834 |
-
# Add edge to result if it exists and not already added
|
| 835 |
-
if edge and edge.id not in visited_edges:
|
| 836 |
-
result.edges.append(edge)
|
| 837 |
-
visited_edges.add(edge.id)
|
| 838 |
-
|
| 839 |
-
# Prepare nodes and edges for recursive processing
|
| 840 |
-
nodes_to_process = []
|
| 841 |
for record in records:
|
| 842 |
rel = record["r"]
|
| 843 |
edge_id = str(record["edge_id"])
|
|
|
|
| 844 |
if edge_id not in visited_edges:
|
| 845 |
b_node = record["b"]
|
| 846 |
target_id = b_node.get("entity_id")
|
|
@@ -849,55 +947,59 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 849 |
# Create KnowledgeGraphNode for target
|
| 850 |
target_node = KnowledgeGraphNode(
|
| 851 |
id=f"{target_id}",
|
| 852 |
-
labels=
|
| 853 |
-
properties=dict(b_node.
|
| 854 |
)
|
| 855 |
|
| 856 |
# Create KnowledgeGraphEdge
|
| 857 |
target_edge = KnowledgeGraphEdge(
|
| 858 |
id=f"{edge_id}",
|
| 859 |
type=rel.type,
|
| 860 |
-
source=f"{
|
| 861 |
target=f"{target_id}",
|
| 862 |
properties=dict(rel),
|
| 863 |
)
|
| 864 |
|
| 865 |
-
|
|
|
|
|
| 866 |
else:
|
| 867 |
logger.warning(
|
| 868 |
-
f"Skipping edge {edge_id} due to missing
|
| 869 |
)
|
| 870 |
|
| 871 |
-
|
| 872 |
-
|
| 873 |
-
|
| 874 |
-
|
| 875 |
-
# Get the starting node's data
|
| 876 |
-
async with self._driver.session(
|
| 877 |
-
database=self._DATABASE, default_access_mode="READ"
|
| 878 |
-
) as session:
|
| 879 |
-
query = """
|
| 880 |
-
MATCH (n:base {entity_id: $entity_id})
|
| 881 |
-
RETURN id(n) as node_id, n
|
| 882 |
-
"""
|
| 883 |
-
node_result = await session.run(query, entity_id=node_label)
|
| 884 |
-
try:
|
| 885 |
-
node_record = await node_result.single()
|
| 886 |
-
if not node_record:
|
| 887 |
-
return result
|
| 888 |
-
|
| 889 |
-
# Create initial KnowledgeGraphNode
|
| 890 |
-
start_node = KnowledgeGraphNode(
|
| 891 |
-
id=f"{node_record['n'].get('entity_id')}",
|
| 892 |
-
labels=list(f"{node_record['n'].get('entity_id')}"),
|
| 893 |
-
properties=dict(node_record["n"].properties),
|
| 894 |
-
)
|
| 895 |
-
finally:
|
| 896 |
-
await node_result.consume() # Ensure results are consumed
|
| 897 |
-
|
| 898 |
-
# Start traversal with the initial node
|
| 899 |
-
await traverse(start_node, None, 0)
|
| 900 |
-
|
| 901 |
return result
|
| 902 |
|
| 903 |
async def get_all_labels(self) -> list[str]:
|
|
@@ -914,7 +1016,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 914 |
|
| 915 |
# Method 2: Query compatible with older versions
|
| 916 |
query = """
|
| 917 |
-
MATCH (n)
|
| 918 |
WHERE n.entity_id IS NOT NULL
|
| 919 |
RETURN DISTINCT n.entity_id AS label
|
| 920 |
ORDER BY label
|
|
@@ -1028,3 +1130,28 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 1028 |
self, algorithm: str
|
| 1029 |
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
| 1030 |
raise NotImplementedError
|
|
|
|
|
| 1 |
import inspect
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, final
|
| 6 |
import numpy as np
|
| 7 |
import configparser
|
| 8 |
|
|
|
|
| 28 |
exceptions as neo4jExceptions,
|
| 29 |
AsyncDriver,
|
| 30 |
AsyncManagedTransaction,
|
|
|
|
| 31 |
)
|
| 32 |
|
| 33 |
config = configparser.ConfigParser()
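
The connection settings resolved in initialize() below follow one lookup order: an environment variable wins, the [neo4j] section of config.ini is the fallback, and None means "not configured". A minimal sketch of that precedence (the variable name follows the NEO4J_* keys used below; the error message is illustrative, not part of LightRAG):

    import os
    import configparser

    config = configparser.ConfigParser()
    config.read("config.ini", "utf-8")

    # Environment variable first, then config.ini, then None.
    uri = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
    if uri is None:
        raise ValueError("NEO4J_URI is not set in the environment or in config.ini")
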
|
|
|
|
| 50 |
embedding_func=embedding_func,
|
| 51 |
)
|
| 52 |
self._driver = None
|
|
|
|
| 53 |
|
| 54 |
+
def __post_init__(self):
|
| 55 |
+
self._node_embed_algorithms = {
|
| 56 |
+
"node2vec": self._node2vec_embed,
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
async def initialize(self):
|
| 60 |
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
|
| 61 |
USERNAME = os.environ.get(
|
| 62 |
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
|
|
|
|
| 89 |
),
|
| 90 |
)
|
| 91 |
DATABASE = os.environ.get(
|
| 92 |
+
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
|
| 93 |
)
|
| 94 |
|
| 95 |
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
|
|
|
| 101 |
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
|
| 102 |
)
|
| 103 |
|
| 104 |
+
# Try to connect to the database and create it if it doesn't exist
|
| 105 |
+
for database in (DATABASE, None):
|
| 106 |
+
self._DATABASE = database
|
| 107 |
+
connected = False
|
|
|
|
|
| 108 |
|
| 109 |
+
try:
|
| 110 |
+
async with self._driver.session(database=database) as session:
|
| 111 |
+
try:
|
| 112 |
+
result = await session.run("MATCH (n) RETURN n LIMIT 0")
|
| 113 |
+
await result.consume() # Ensure result is consumed
|
| 114 |
+
logger.info(f"Connected to {database} at {URI}")
|
| 115 |
+
connected = True
|
| 116 |
+
except neo4jExceptions.ServiceUnavailable as e:
|
| 117 |
+
logger.error(
|
| 118 |
+
f"{database} at {URI} is not available".capitalize()
|
| 119 |
+
)
|
| 120 |
+
raise e
|
| 121 |
+
except neo4jExceptions.AuthError as e:
|
| 122 |
+
logger.error(f"Authentication failed for {database} at {URI}")
|
| 123 |
+
raise e
|
| 124 |
+
except neo4jExceptions.ClientError as e:
|
| 125 |
+
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
|
| 126 |
+
logger.info(
|
| 127 |
+
f"{database} at {URI} not found. Try to create specified database.".capitalize()
|
| 128 |
+
)
|
| 129 |
+
try:
|
| 130 |
+
async with self._driver.session() as session:
|
| 131 |
+
result = await session.run(
|
| 132 |
+
f"CREATE DATABASE `{database}` IF NOT EXISTS"
|
| 133 |
)
|
| 134 |
+
await result.consume() # Ensure result is consumed
|
| 135 |
+
logger.info(f"{database} at {URI} created".capitalize())
|
| 136 |
+
connected = True
|
| 137 |
+
except (
|
| 138 |
+
neo4jExceptions.ClientError,
|
| 139 |
+
neo4jExceptions.DatabaseError,
|
| 140 |
+
) as e:
|
| 141 |
+
if (
|
| 142 |
+
e.code
|
| 143 |
+
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
|
| 144 |
+
) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
|
| 145 |
+
if database is not None:
|
| 146 |
+
logger.warning(
|
| 147 |
+
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
|
| 148 |
+
)
|
| 149 |
+
if database is None:
|
| 150 |
+
logger.error(f"Failed to create {database} at {URI}")
|
| 151 |
raise e
|
| 152 |
+
|
| 153 |
+
if connected:
|
| 154 |
+
# Create index for base nodes on entity_id if it doesn't exist
|
| 155 |
+
try:
|
| 156 |
+
async with self._driver.session(database=database) as session:
|
| 157 |
+
# Check if index exists first
|
| 158 |
+
check_query = """
|
| 159 |
+
CALL db.indexes() YIELD name, labelsOrTypes, properties
|
| 160 |
+
WHERE labelsOrTypes = ['base'] AND properties = ['entity_id']
|
| 161 |
+
RETURN count(*) > 0 AS exists
|
| 162 |
+
"""
|
| 163 |
try:
|
| 164 |
+
check_result = await session.run(check_query)
|
| 165 |
+
record = await check_result.single()
|
| 166 |
+
await check_result.consume()
|
|
|
|
|
| 167 |
|
| 168 |
+
index_exists = record and record.get("exists", False)
|
|
|
|
| 169 |
|
| 170 |
+
if not index_exists:
|
| 171 |
+
# Create index only if it doesn't exist
|
| 172 |
+
result = await session.run(
|
| 173 |
+
"CREATE INDEX FOR (n:base) ON (n.entity_id)"
|
| 174 |
+
)
|
| 175 |
+
await result.consume()
|
| 176 |
+
logger.info(
|
| 177 |
+
f"Created index for base nodes on entity_id in {database}"
|
| 178 |
+
)
|
| 179 |
+
except Exception:
|
| 180 |
+
# Fallback if db.indexes() is not supported in this Neo4j version
|
| 181 |
+
result = await session.run(
|
| 182 |
+
"CREATE INDEX IF NOT EXISTS FOR (n:base) ON (n.entity_id)"
|
| 183 |
+
)
|
| 184 |
+
await result.consume()
|
| 185 |
+
except Exception as e:
|
| 186 |
+
logger.warning(f"Failed to create index: {str(e)}")
|
| 187 |
+
break
|
| 188 |
|
| 189 |
+
async def finalize(self):
|
| 190 |
"""Close the Neo4j driver and release all resources"""
|
| 191 |
if self._driver:
|
| 192 |
await self._driver.close()
|
|
|
|
| 194 |
|
| 195 |
async def __aexit__(self, exc_type, exc, tb):
|
| 196 |
"""Ensure driver is closed when context manager exits"""
|
| 197 |
+
await self.finalize()
|
| 198 |
|
| 199 |
async def index_done_callback(self) -> None:
|
| 200 |
# Neo4j handles persistence automatically
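
For reference, the connect-or-create behaviour added in initialize() above (probe the configured database, create it when the server reports Neo.ClientError.Database.DatabaseNotFound, and fall back to the default database on editions that cannot create databases) can be sketched in isolation roughly as follows. This is a minimal sketch assuming the official async neo4j driver; ensure_database is an illustrative name, not part of LightRAG.

    from neo4j import AsyncGraphDatabase, exceptions as neo4jExceptions

    async def ensure_database(uri: str, auth: tuple, database: str) -> None:
        # Illustrative helper: probe the database, create it if the server says it is missing.
        driver = AsyncGraphDatabase.driver(uri, auth=auth)
        try:
            try:
                async with driver.session(database=database) as session:
                    result = await session.run("MATCH (n) RETURN n LIMIT 0")
                    await result.consume()
            except neo4jExceptions.ClientError as e:
                if e.code != "Neo.ClientError.Database.DatabaseNotFound":
                    raise
                # Needs Enterprise/DozerDB; Community Edition rejects CREATE DATABASE,
                # which is why the code above falls back to the default database.
                async with driver.session() as session:
                    result = await session.run(f"CREATE DATABASE `{database}` IF NOT EXISTS")
                    await result.consume()
        finally:
            await driver.close()
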
|
|
|
|
| 267 |
raise
|
| 268 |
|
| 269 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
| 270 |
+
"""Get node by its label identifier, return only node properties
|
| 271 |
|
| 272 |
Args:
|
| 273 |
node_id: The node label to look up
|
|
|
|
| 452 |
logger.debug(
|
| 453 |
f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
|
| 454 |
)
|
| 455 |
+
# Return None when no edge found
|
| 456 |
+
return None
|
|
|
|
|
| 457 |
finally:
|
| 458 |
await result.consume() # Ensure result is fully consumed
|
| 459 |
|
|
|
|
| 545 |
"""
|
| 546 |
properties = node_data
|
| 547 |
entity_type = properties["entity_type"]
|
|
|
|
| 548 |
if "entity_id" not in properties:
|
| 549 |
raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
|
| 550 |
|
|
|
|
| 554 |
async def execute_upsert(tx: AsyncManagedTransaction):
|
| 555 |
query = (
|
| 556 |
"""
|
| 557 |
+
MERGE (n:base {entity_id: $entity_id})
|
| 558 |
SET n += $properties
|
| 559 |
SET n:`%s`
|
| 560 |
"""
|
| 561 |
% entity_type
|
| 562 |
)
|
| 563 |
+
result = await tx.run(
|
| 564 |
+
query, entity_id=node_id, properties=properties
|
| 565 |
+
)
|
| 566 |
logger.debug(
|
| 567 |
+
f"Upserted node with entity_id '{node_id}' and properties: {properties}"
|
| 568 |
)
|
| 569 |
await result.consume() # Ensure result is fully consumed
|
| 570 |
|
|
|
|
| 642 |
self,
|
| 643 |
node_label: str,
|
| 644 |
max_depth: int = 3,
|
| 645 |
+
max_nodes: int = MAX_GRAPH_NODES,
|
|
|
|
| 646 |
) -> KnowledgeGraph:
|
| 647 |
"""
|
| 648 |
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
|
|
|
|
| 649 |
|
| 650 |
Args:
|
| 651 |
+
node_label: Label of the starting node, * means all nodes
|
| 652 |
+
max_depth: Maximum depth of the subgraph. Defaults to 3
|
| 653 |
+
max_nodes: Maximum number of nodes to return by BFS. Defaults to 1000
|
| 654 |
+
|
| 655 |
Returns:
|
| 656 |
+
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
| 657 |
+
indicating whether the graph was truncated due to max_nodes limit
|
| 658 |
"""
|
| 659 |
result = KnowledgeGraph()
|
| 660 |
seen_nodes = set()
|
|
|
|
| 665 |
) as session:
|
| 666 |
try:
|
| 667 |
if node_label == "*":
|
| 668 |
+
# First check total node count to determine if graph is truncated
|
| 669 |
+
count_query = "MATCH (n) RETURN count(n) as total"
|
| 670 |
+
count_result = None
|
| 671 |
+
try:
|
| 672 |
+
count_result = await session.run(count_query)
|
| 673 |
+
count_record = await count_result.single()
|
| 674 |
+
|
| 675 |
+
if count_record and count_record["total"] > max_nodes:
|
| 676 |
+
result.is_truncated = True
|
| 677 |
+
logger.info(
|
| 678 |
+
f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
|
| 679 |
+
)
|
| 680 |
+
finally:
|
| 681 |
+
if count_result:
|
| 682 |
+
await count_result.consume()
|
| 683 |
+
|
| 684 |
+
# Run main query to get nodes with highest degree
|
| 685 |
main_query = """
|
| 686 |
MATCH (n)
|
| 687 |
OPTIONAL MATCH (n)-[r]-()
|
| 688 |
WITH n, COALESCE(count(r), 0) AS degree
|
|
|
|
| 689 |
ORDER BY degree DESC
|
| 690 |
LIMIT $max_nodes
|
| 691 |
WITH collect({node: n}) AS filtered_nodes
|
|
|
|
| 696 |
RETURN filtered_nodes AS node_info,
|
| 697 |
collect(DISTINCT r) AS relationships
|
| 698 |
"""
|
| 699 |
+
result_set = None
|
| 700 |
+
try:
|
| 701 |
+
result_set = await session.run(
|
| 702 |
+
main_query,
|
| 703 |
+
{"max_nodes": max_nodes},
|
| 704 |
+
)
|
| 705 |
+
record = await result_set.single()
|
| 706 |
+
finally:
|
| 707 |
+
if result_set:
|
| 708 |
+
await result_set.consume()
|
| 709 |
|
| 710 |
else:
|
| 711 |
+
# return await self._robust_fallback(node_label, max_depth, max_nodes)
|
| 712 |
+
# First try without limit to check if we need to truncate
|
| 713 |
+
full_query = """
|
| 714 |
MATCH (start)
|
| 715 |
+
WHERE start.entity_id = $entity_id
|
|
|
|
|
| 716 |
WITH start
|
| 717 |
CALL apoc.path.subgraphAll(start, {
|
| 718 |
relationshipFilter: '',
|
|
|
|
| 721 |
bfs: true
|
| 722 |
})
|
| 723 |
YIELD nodes, relationships
|
| 724 |
+
WITH nodes, relationships, size(nodes) AS total_nodes
|
| 725 |
UNWIND nodes AS node
|
| 726 |
+
WITH collect({node: node}) AS node_info, relationships, total_nodes
|
| 727 |
+
RETURN node_info, relationships, total_nodes
|
|
|
|
|
| 728 |
"""
|
|
|
|
|
|
| 729 |
|
| 730 |
+
# Try to get full result
|
| 731 |
+
full_result = None
|
| 732 |
+
try:
|
| 733 |
+
full_result = await session.run(
|
| 734 |
+
full_query,
|
| 735 |
+
{
|
| 736 |
+
"entity_id": node_label,
|
| 737 |
+
"max_depth": max_depth,
|
| 738 |
+
},
|
| 739 |
+
)
|
| 740 |
+
full_record = await full_result.single()
|
| 741 |
+
|
| 742 |
+
# If no record found, return empty KnowledgeGraph
|
| 743 |
+
if not full_record:
|
| 744 |
+
logger.debug(f"No nodes found for entity_id: {node_label}")
|
| 745 |
+
return result
|
| 746 |
+
|
| 747 |
+
# If record found, check node count
|
| 748 |
+
total_nodes = full_record["total_nodes"]
|
| 749 |
+
|
| 750 |
+
if total_nodes <= max_nodes:
|
| 751 |
+
# If node count is within limit, use full result directly
|
| 752 |
+
logger.debug(
|
| 753 |
+
f"Using full result with {total_nodes} nodes (no truncation needed)"
|
| 754 |
+
)
|
| 755 |
+
record = full_record
|
| 756 |
+
else:
|
| 757 |
+
# If node count exceeds limit, set truncated flag and run limited query
|
| 758 |
+
result.is_truncated = True
|
| 759 |
+
logger.info(
|
| 760 |
+
f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
# Run limited query
|
| 764 |
+
limited_query = """
|
| 765 |
+
MATCH (start)
|
| 766 |
+
WHERE start.entity_id = $entity_id
|
| 767 |
+
WITH start
|
| 768 |
+
CALL apoc.path.subgraphAll(start, {
|
| 769 |
+
relationshipFilter: '',
|
| 770 |
+
minLevel: 0,
|
| 771 |
+
maxLevel: $max_depth,
|
| 772 |
+
limit: $max_nodes,
|
| 773 |
+
bfs: true
|
| 774 |
+
})
|
| 775 |
+
YIELD nodes, relationships
|
| 776 |
+
UNWIND nodes AS node
|
| 777 |
+
WITH collect({node: node}) AS node_info, relationships
|
| 778 |
+
RETURN node_info, relationships
|
| 779 |
+
"""
|
| 780 |
+
result_set = None
|
| 781 |
+
try:
|
| 782 |
+
result_set = await session.run(
|
| 783 |
+
limited_query,
|
| 784 |
+
{
|
| 785 |
+
"entity_id": node_label,
|
| 786 |
+
"max_depth": max_depth,
|
| 787 |
+
"max_nodes": max_nodes,
|
| 788 |
+
},
|
| 789 |
)
|
| 790 |
+
record = await result_set.single()
|
| 791 |
+
finally:
|
| 792 |
+
if result_set:
|
| 793 |
+
await result_set.consume()
|
| 794 |
+
finally:
|
| 795 |
+
if full_result:
|
| 796 |
+
await full_result.consume()
|
| 797 |
+
|
| 798 |
+
if record:
|
| 799 |
+
# Handle nodes (compatible with multi-label cases)
|
| 800 |
+
for node_info in record["node_info"]:
|
| 801 |
+
node = node_info["node"]
|
| 802 |
+
node_id = node.id
|
| 803 |
+
if node_id not in seen_nodes:
|
| 804 |
+
result.nodes.append(
|
| 805 |
+
KnowledgeGraphNode(
|
| 806 |
+
id=f"{node_id}",
|
| 807 |
+
labels=[node.get("entity_id")],
|
| 808 |
+
properties=dict(node),
|
| 809 |
+
)
|
| 810 |
+
)
|
| 811 |
+
seen_nodes.add(node_id)
|
| 812 |
+
|
| 813 |
+
# Handle relationships (including direction information)
|
| 814 |
+
for rel in record["relationships"]:
|
| 815 |
+
edge_id = rel.id
|
| 816 |
+
if edge_id not in seen_edges:
|
| 817 |
+
start = rel.start_node
|
| 818 |
+
end = rel.end_node
|
| 819 |
+
result.edges.append(
|
| 820 |
+
KnowledgeGraphEdge(
|
| 821 |
+
id=f"{edge_id}",
|
| 822 |
+
type=rel.type,
|
| 823 |
+
source=f"{start.id}",
|
| 824 |
+
target=f"{end.id}",
|
| 825 |
+
properties=dict(rel),
|
| 826 |
)
|
| 827 |
+
)
|
| 828 |
+
seen_edges.add(edge_id)
|
| 829 |
|
| 830 |
+
logger.info(
|
| 831 |
+
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
| 832 |
+
)
|
|
|
|
|
|
|
| 833 |
|
| 834 |
except neo4jExceptions.ClientError as e:
|
| 835 |
logger.warning(f"APOC plugin error: {str(e)}")
|
|
|
|
| 837 |
logger.warning(
|
| 838 |
"Neo4j: falling back to basic Cypher recursive search..."
|
| 839 |
)
|
| 840 |
+
return await self._robust_fallback(node_label, max_depth, max_nodes)
|
| 841 |
+
else:
|
| 842 |
+
logger.warning(
|
| 843 |
+
"Neo4j: APOC plugin error with wildcard query, returning empty result"
|
|
|
|
|
|
|
| 844 |
)
|
| 845 |
|
| 846 |
return result
|
| 847 |
|
| 848 |
async def _robust_fallback(
|
| 849 |
+
self, node_label: str, max_depth: int, max_nodes: int
|
| 850 |
) -> KnowledgeGraph:
|
| 851 |
"""
|
| 852 |
Fallback implementation when APOC plugin is not available or incompatible.
|
| 853 |
This method implements the same functionality as get_knowledge_graph but uses
|
| 854 |
+
only basic Cypher queries and true breadth-first traversal instead of APOC procedures.
|
| 855 |
"""
|
| 856 |
+
from collections import deque
|
| 857 |
+
|
| 858 |
result = KnowledgeGraph()
|
| 859 |
visited_nodes = set()
|
| 860 |
visited_edges = set()
|
| 861 |
+
visited_edge_pairs = set()  # Track processed edge pairs as (sorted source_id, target_id)
|
| 862 |
|
| 863 |
+
# Get the starting node's data
|
| 864 |
+
async with self._driver.session(
|
| 865 |
+
database=self._DATABASE, default_access_mode="READ"
|
| 866 |
+
) as session:
|
| 867 |
+
query = """
|
| 868 |
+
MATCH (n:base {entity_id: $entity_id})
|
| 869 |
+
RETURN id(n) as node_id, n
|
| 870 |
+
"""
|
| 871 |
+
node_result = await session.run(query, entity_id=node_label)
|
| 872 |
+
try:
|
| 873 |
+
node_record = await node_result.single()
|
| 874 |
+
if not node_record:
|
| 875 |
+
return result
|
| 876 |
+
|
| 877 |
+
# Create initial KnowledgeGraphNode
|
| 878 |
+
start_node = KnowledgeGraphNode(
|
| 879 |
+
id=f"{node_record['n'].get('entity_id')}",
|
| 880 |
+
labels=[node_record["n"].get("entity_id")],
|
| 881 |
+
properties=dict(node_record["n"]._properties),
|
| 882 |
+
)
|
| 883 |
+
finally:
|
| 884 |
+
await node_result.consume() # Ensure results are consumed
|
| 885 |
|
| 886 |
+
# Initialize queue for BFS with (node, edge, depth) tuples
|
| 887 |
+
# edge is None for the starting node
|
| 888 |
+
queue = deque([(start_node, None, 0)])
|
| 889 |
|
| 890 |
+
# True BFS implementation using a queue
|
| 891 |
+
while queue and len(visited_nodes) < max_nodes:
|
| 892 |
+
# Dequeue the next node to process
|
| 893 |
+
current_node, current_edge, current_depth = queue.popleft()
|
| 894 |
+
|
| 895 |
+
# Skip if already visited or exceeds max depth
|
| 896 |
+
if current_node.id in visited_nodes:
|
| 897 |
+
continue
|
| 898 |
+
|
| 899 |
+
if current_depth > max_depth:
|
| 900 |
+
logger.debug(
|
| 901 |
+
f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
|
| 902 |
+
)
|
| 903 |
+
continue
|
| 904 |
+
|
| 905 |
+
# Add current node to result
|
| 906 |
+
result.nodes.append(current_node)
|
| 907 |
+
visited_nodes.add(current_node.id)
|
| 908 |
+
|
| 909 |
+
# Add edge to result if it exists and not already added
|
| 910 |
+
if current_edge and current_edge.id not in visited_edges:
|
| 911 |
+
result.edges.append(current_edge)
|
| 912 |
+
visited_edges.add(current_edge.id)
|
| 913 |
+
|
| 914 |
+
# Stop if we've reached the node limit
|
| 915 |
+
if len(visited_nodes) >= max_nodes:
|
| 916 |
+
result.is_truncated = True
|
| 917 |
+
logger.info(
|
| 918 |
+
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
|
| 919 |
+
)
|
| 920 |
+
break
|
| 921 |
+
|
| 922 |
+
# Get all edges and target nodes for the current node (even at max_depth)
|
| 923 |
async with self._driver.session(
|
| 924 |
database=self._DATABASE, default_access_mode="READ"
|
| 925 |
) as session:
|
|
|
|
| 928 |
WITH r, b, id(r) as edge_id, id(b) as target_id
|
| 929 |
RETURN r, b, edge_id, target_id
|
| 930 |
"""
|
| 931 |
+
results = await session.run(query, entity_id=current_node.id)
|
| 932 |
|
| 933 |
# Get all records and release database connection
|
| 934 |
+
records = await results.fetch(1000) # Max neighbor nodes we can handle
|
|
|
|
|
|
|
| 935 |
await results.consume() # Ensure results are consumed
|
| 936 |
|
| 937 |
+
# Process all neighbors - capture all edges but only queue unvisited nodes
|
|
|
|
|
| 938 |
for record in records:
|
| 939 |
rel = record["r"]
|
| 940 |
edge_id = str(record["edge_id"])
|
| 941 |
+
|
| 942 |
if edge_id not in visited_edges:
|
| 943 |
b_node = record["b"]
|
| 944 |
target_id = b_node.get("entity_id")
|
|
|
|
| 947 |
# Create KnowledgeGraphNode for target
|
| 948 |
target_node = KnowledgeGraphNode(
|
| 949 |
id=f"{target_id}",
|
| 950 |
+
labels=[target_id],
|
| 951 |
+
properties=dict(b_node._properties),
|
| 952 |
)
|
| 953 |
|
| 954 |
# Create KnowledgeGraphEdge
|
| 955 |
target_edge = KnowledgeGraphEdge(
|
| 956 |
id=f"{edge_id}",
|
| 957 |
type=rel.type,
|
| 958 |
+
source=f"{current_node.id}",
|
| 959 |
target=f"{target_id}",
|
| 960 |
properties=dict(rel),
|
| 961 |
)
|
| 962 |
|
| 963 |
+
# Sort source_id and target_id so that (A,B) and (B,A) are treated as the same edge
|
| 964 |
+
sorted_pair = tuple(sorted([current_node.id, target_id]))
|
| 965 |
+
|
| 966 |
+
# Check whether the same edge already exists (edges are treated as undirected)
|
| 967 |
+
if sorted_pair not in visited_edge_pairs:
|
| 968 |
+
# Only add the edge when the target node is already in the result or will be added to it
|
| 969 |
+
if target_id in visited_nodes or (
|
| 970 |
+
target_id not in visited_nodes
|
| 971 |
+
and current_depth < max_depth
|
| 972 |
+
):
|
| 973 |
+
result.edges.append(target_edge)
|
| 974 |
+
visited_edges.add(edge_id)
|
| 975 |
+
visited_edge_pairs.add(sorted_pair)
|
| 976 |
+
|
| 977 |
+
# Only add unvisited nodes to the queue for further expansion
|
| 978 |
+
if target_id not in visited_nodes:
|
| 979 |
+
# Only add to queue if we're not at max depth yet
|
| 980 |
+
if current_depth < max_depth:
|
| 981 |
+
# Add node to queue with incremented depth
|
| 982 |
+
# Edge is already added to result, so we pass None as edge
|
| 983 |
+
queue.append((target_node, None, current_depth + 1))
|
| 984 |
+
else:
|
| 985 |
+
# At max depth, we've already added the edge but we don't add the node
|
| 986 |
+
# This prevents adding nodes beyond max_depth to the result
|
| 987 |
+
logger.debug(
|
| 988 |
+
f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
|
| 989 |
+
)
|
| 990 |
+
else:
|
| 991 |
+
# If target node already exists in result, we don't need to add it again
|
| 992 |
+
logger.debug(
|
| 993 |
+
f"Node {target_id} already visited, edge added but node not queued"
|
| 994 |
+
)
|
| 995 |
else:
|
| 996 |
logger.warning(
|
| 997 |
+
f"Skipping edge {edge_id} due to missing entity_id on target node"
|
| 998 |
)
|
| 999 |
|
| 1000 |
+
logger.info(
|
| 1001 |
+
f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
| 1002 |
+
)
|
|
|
|
|
| 1003 |
return result
|
| 1004 |
|
| 1005 |
async def get_all_labels(self) -> list[str]:
|
|
|
|
| 1016 |
|
| 1017 |
# Method 2: Query compatible with older versions
|
| 1018 |
query = """
|
| 1019 |
+
MATCH (n:base)
|
| 1020 |
WHERE n.entity_id IS NOT NULL
|
| 1021 |
RETURN DISTINCT n.entity_id AS label
|
| 1022 |
ORDER BY label
|
|
|
|
| 1130 |
self, algorithm: str
|
| 1131 |
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
| 1132 |
raise NotImplementedError
|
| 1133 |
+
|
| 1134 |
+
async def drop(self) -> dict[str, str]:
|
| 1135 |
+
"""Drop all data from storage and clean up resources
|
| 1136 |
+
|
| 1137 |
+
This method will delete all nodes and relationships in the Neo4j database.
|
| 1138 |
+
|
| 1139 |
+
Returns:
|
| 1140 |
+
dict[str, str]: Operation status and message
|
| 1141 |
+
- On success: {"status": "success", "message": "data dropped"}
|
| 1142 |
+
- On failure: {"status": "error", "message": "<error details>"}
|
| 1143 |
+
"""
|
| 1144 |
+
try:
|
| 1145 |
+
async with self._driver.session(database=self._DATABASE) as session:
|
| 1146 |
+
# Delete all nodes and relationships
|
| 1147 |
+
query = "MATCH (n) DETACH DELETE n"
|
| 1148 |
+
result = await session.run(query)
|
| 1149 |
+
await result.consume() # Ensure result is fully consumed
|
| 1150 |
+
|
| 1151 |
+
logger.info(
|
| 1152 |
+
f"Process {os.getpid()} drop Neo4j database {self._DATABASE}"
|
| 1153 |
+
)
|
| 1154 |
+
return {"status": "success", "message": "data dropped"}
|
| 1155 |
+
except Exception as e:
|
| 1156 |
+
logger.error(f"Error dropping Neo4j database {self._DATABASE}: {e}")
|
| 1157 |
+
return {"status": "error", "message": str(e)}
|
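
Because drop() reports its outcome as a status dict instead of raising, callers can handle failures explicitly. A short sketch (assuming an initialized storage instance; the NetworkX implementation below follows the same contract):

    # Wipe all graph data for this storage and check the reported status.
    status = await storage.drop()
    if status["status"] != "success":
        raise RuntimeError(f"Graph drop failed: {status['message']}")
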
lightrag/kg/networkx_impl.py
CHANGED
|
@@ -42,6 +42,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|
| 42 |
)
|
| 43 |
nx.write_graphml(graph, file_name)
|
| 44 |
|
|
|
|
| 45 |
@staticmethod
|
| 46 |
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
|
| 47 |
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
|
|
@@ -155,16 +156,34 @@ class NetworkXStorage(BaseGraphStorage):
|
|
| 155 |
return None
|
| 156 |
|
| 157 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
graph = await self._get_graph()
|
| 159 |
graph.add_node(node_id, **node_data)
|
| 160 |
|
| 161 |
async def upsert_edge(
|
| 162 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 163 |
) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
graph = await self._get_graph()
|
| 165 |
graph.add_edge(source_node_id, target_node_id, **edge_data)
|
| 166 |
|
| 167 |
async def delete_node(self, node_id: str) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
graph = await self._get_graph()
|
| 169 |
if graph.has_node(node_id):
|
| 170 |
graph.remove_node(node_id)
|
|
@@ -172,6 +191,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|
| 172 |
else:
|
| 173 |
logger.warning(f"Node {node_id} not found in the graph for deletion.")
|
| 174 |
|
|
|
|
| 175 |
async def embed_nodes(
|
| 176 |
self, algorithm: str
|
| 177 |
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
|
@@ -192,6 +212,11 @@ class NetworkXStorage(BaseGraphStorage):
|
|
| 192 |
async def remove_nodes(self, nodes: list[str]):
|
| 193 |
"""Delete multiple nodes
|
| 194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
Args:
|
| 196 |
nodes: List of node IDs to be deleted
|
| 197 |
"""
|
|
@@ -203,6 +228,11 @@ class NetworkXStorage(BaseGraphStorage):
|
|
| 203 |
async def remove_edges(self, edges: list[tuple[str, str]]):
|
| 204 |
"""Delete multiple edges
|
| 205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
Args:
|
| 207 |
edges: List of edges to be deleted, each edge is a (source, target) tuple
|
| 208 |
"""
|
|
@@ -229,118 +259,81 @@ class NetworkXStorage(BaseGraphStorage):
|
|
| 229 |
self,
|
| 230 |
node_label: str,
|
| 231 |
max_depth: int = 3,
|
| 232 |
-
|
| 233 |
-
inclusive: bool = False,
|
| 234 |
) -> KnowledgeGraph:
|
| 235 |
"""
|
| 236 |
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
| 237 |
-
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
| 238 |
-
When reducing the number of nodes, the prioritization criteria are as follows:
|
| 239 |
-
1. min_degree does not affect nodes directly connected to the matching nodes
|
| 240 |
-
2. Label matching nodes take precedence
|
| 241 |
-
3. Followed by nodes directly connected to the matching nodes
|
| 242 |
-
4. Finally, the degree of the nodes
|
| 243 |
|
| 244 |
Args:
|
| 245 |
-
node_label: Label of the starting node
|
| 246 |
-
max_depth: Maximum depth of the subgraph
|
| 247 |
-
|
| 248 |
-
inclusive: Do an inclusive search if true
|
| 249 |
|
| 250 |
Returns:
|
| 251 |
-
KnowledgeGraph object containing nodes and edges
|
|
|
|
| 252 |
"""
|
| 253 |
-
result = KnowledgeGraph()
|
| 254 |
-
seen_nodes = set()
|
| 255 |
-
seen_edges = set()
|
| 256 |
-
|
| 257 |
graph = await self._get_graph()
|
| 258 |
|
| 259 |
-
|
| 260 |
-
start_nodes = set()
|
| 261 |
-
direct_connected_nodes = set()
|
| 262 |
|
| 263 |
# Handle special case for "*" label
|
| 264 |
if node_label == "*":
|
| 265 |
-
#
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
|
|
|
|
|
| 269 |
else:
|
| 270 |
-
#
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
subgraph
|
| 304 |
-
|
| 305 |
-
# Filter nodes based on min_degree, but keep start nodes and direct connected nodes
|
| 306 |
-
if min_degree > 0:
|
| 307 |
-
nodes_to_keep = [
|
| 308 |
-
node
|
| 309 |
-
for node, degree in subgraph.degree()
|
| 310 |
-
if node in start_nodes
|
| 311 |
-
or node in direct_connected_nodes
|
| 312 |
-
or degree >= min_degree
|
| 313 |
-
]
|
| 314 |
-
subgraph = subgraph.subgraph(nodes_to_keep)
|
| 315 |
-
|
| 316 |
-
# Check if number of nodes exceeds max_graph_nodes
|
| 317 |
-
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
|
| 318 |
-
origin_nodes = len(subgraph.nodes())
|
| 319 |
-
node_degrees = dict(subgraph.degree())
|
| 320 |
-
|
| 321 |
-
def priority_key(node_item):
|
| 322 |
-
node, degree = node_item
|
| 323 |
-
# Priority order: start(2) > directly connected(1) > other nodes(0)
|
| 324 |
-
if node in start_nodes:
|
| 325 |
-
priority = 2
|
| 326 |
-
elif node in direct_connected_nodes:
|
| 327 |
-
priority = 1
|
| 328 |
-
else:
|
| 329 |
-
priority = 0
|
| 330 |
-
return (priority, degree)
|
| 331 |
-
|
| 332 |
-
# Sort by priority and degree and select top MAX_GRAPH_NODES nodes
|
| 333 |
-
top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
|
| 334 |
-
:MAX_GRAPH_NODES
|
| 335 |
-
]
|
| 336 |
-
top_node_ids = [node[0] for node in top_nodes]
|
| 337 |
-
# Create new subgraph and keep nodes only with most degree
|
| 338 |
-
subgraph = subgraph.subgraph(top_node_ids)
|
| 339 |
-
logger.info(
|
| 340 |
-
f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
|
| 341 |
-
)
|
| 342 |
|
| 343 |
# Add nodes to result
|
|
|
|
|
|
|
| 344 |
for node in subgraph.nodes():
|
| 345 |
if str(node) in seen_nodes:
|
| 346 |
continue
|
|
@@ -368,7 +361,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|
| 368 |
for edge in subgraph.edges():
|
| 369 |
source, target = edge
|
| 370 |
# Ensure unique edge_id for undirected graph
|
| 371 |
-
if source > target:
|
| 372 |
source, target = target, source
|
| 373 |
edge_id = f"{source}-{target}"
|
| 374 |
if edge_id in seen_edges:
|
|
@@ -424,3 +417,35 @@ class NetworkXStorage(BaseGraphStorage):
|
|
| 424 |
return False # Return error
|
| 425 |
|
| 426 |
return True
|
|
|
|
|
|
|
|
| 42 |
)
|
| 43 |
nx.write_graphml(graph, file_name)
|
| 44 |
|
| 45 |
+
# TODO: deprecated, remove later
|
| 46 |
@staticmethod
|
| 47 |
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
|
| 48 |
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
|
|
|
|
| 156 |
return None
|
| 157 |
|
| 158 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
| 159 |
+
"""
|
| 160 |
+
Important notes:
|
| 161 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
| 162 |
+
2. Only one process should update the storage at a time before index_done_callback,
|
| 163 |
+
KG-storage-log should be used to avoid data corruption
|
| 164 |
+
"""
|
| 165 |
graph = await self._get_graph()
|
| 166 |
graph.add_node(node_id, **node_data)
|
| 167 |
|
| 168 |
async def upsert_edge(
|
| 169 |
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 170 |
) -> None:
|
| 171 |
+
"""
|
| 172 |
+
Important notes:
|
| 173 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
| 174 |
+
2. Only one process should update the storage at a time before index_done_callback,
|
| 175 |
+
KG-storage-log should be used to avoid data corruption
|
| 176 |
+
"""
|
| 177 |
graph = await self._get_graph()
|
| 178 |
graph.add_edge(source_node_id, target_node_id, **edge_data)
|
| 179 |
|
| 180 |
async def delete_node(self, node_id: str) -> None:
|
| 181 |
+
"""
|
| 182 |
+
Important notes:
|
| 183 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
| 184 |
+
2. Only one process should update the storage at a time before index_done_callback,
|
| 185 |
+
KG-storage-log should be used to avoid data corruption
|
| 186 |
+
"""
|
| 187 |
graph = await self._get_graph()
|
| 188 |
if graph.has_node(node_id):
|
| 189 |
graph.remove_node(node_id)
|
|
|
|
| 191 |
else:
|
| 192 |
logger.warning(f"Node {node_id} not found in the graph for deletion.")
|
| 193 |
|
| 194 |
+
# TODO: NOT USED
|
| 195 |
async def embed_nodes(
|
| 196 |
self, algorithm: str
|
| 197 |
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
|
|
|
| 212 |
async def remove_nodes(self, nodes: list[str]):
|
| 213 |
"""Delete multiple nodes
|
| 214 |
|
| 215 |
+
Important notes:
|
| 216 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
| 217 |
+
2. Only one process should update the storage at a time before index_done_callback,
|
| 218 |
+
KG-storage-log should be used to avoid data corruption
|
| 219 |
+
|
| 220 |
Args:
|
| 221 |
nodes: List of node IDs to be deleted
|
| 222 |
"""
|
|
|
|
| 228 |
async def remove_edges(self, edges: list[tuple[str, str]]):
|
| 229 |
"""Delete multiple edges
|
| 230 |
|
| 231 |
+
Important notes:
|
| 232 |
+
1. Changes will be persisted to disk during the next index_done_callback
|
| 233 |
+
2. Only one process should update the storage at a time before index_done_callback,
|
| 234 |
+
KG-storage-log should be used to avoid data corruption
|
| 235 |
+
|
| 236 |
Args:
|
| 237 |
edges: List of edges to be deleted, each edge is a (source, target) tuple
|
| 238 |
"""
|
|
|
|
| 259 |
self,
|
| 260 |
node_label: str,
|
| 261 |
max_depth: int = 3,
|
| 262 |
+
max_nodes: int = MAX_GRAPH_NODES,
|
|
|
|
| 263 |
) -> KnowledgeGraph:
|
| 264 |
"""
|
| 265 |
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
|
| 267 |
Args:
|
| 268 |
+
node_label: Label of the starting node, * means all nodes
|
| 269 |
+
max_depth: Maximum depth of the subgraph. Defaults to 3
|
| 270 |
+
max_nodes: Maximum number of nodes to return by BFS. Defaults to 1000
|
|
|
|
| 271 |
|
| 272 |
Returns:
|
| 273 |
+
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
| 274 |
+
indicating whether the graph was truncated due to max_nodes limit
|
| 275 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
graph = await self._get_graph()
|
| 277 |
|
| 278 |
+
result = KnowledgeGraph()
|
|
|
|
|
|
|
| 279 |
|
| 280 |
# Handle special case for "*" label
|
| 281 |
if node_label == "*":
|
| 282 |
+
# Get degrees of all nodes
|
| 283 |
+
degrees = dict(graph.degree())
|
| 284 |
+
# Sort nodes by degree in descending order and take top max_nodes
|
| 285 |
+
sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
|
| 286 |
+
|
| 287 |
+
# Check if graph is truncated
|
| 288 |
+
if len(sorted_nodes) > max_nodes:
|
| 289 |
+
result.is_truncated = True
|
| 290 |
+
logger.info(
|
| 291 |
+
f"Graph truncated: {len(sorted_nodes)} nodes found, limited to {max_nodes}"
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]]
|
| 295 |
+
# Create subgraph with the highest degree nodes
|
| 296 |
+
subgraph = graph.subgraph(limited_nodes)
|
| 297 |
else:
|
| 298 |
+
# Check if node exists
|
| 299 |
+
if node_label not in graph:
|
| 300 |
+
logger.warning(f"Node {node_label} not found in the graph")
|
| 301 |
+
return KnowledgeGraph() # Return empty graph
|
| 302 |
+
|
| 303 |
+
# Use BFS to get nodes
|
| 304 |
+
bfs_nodes = []
|
| 305 |
+
visited = set()
|
| 306 |
+
queue = [(node_label, 0)] # (node, depth) tuple
|
| 307 |
+
|
| 308 |
+
# Breadth-first search
|
| 309 |
+
while queue and len(bfs_nodes) < max_nodes:
|
| 310 |
+
current, depth = queue.pop(0)
|
| 311 |
+
if current not in visited:
|
| 312 |
+
visited.add(current)
|
| 313 |
+
bfs_nodes.append(current)
|
| 314 |
+
|
| 315 |
+
# Only explore neighbors if we haven't reached max_depth
|
| 316 |
+
if depth < max_depth:
|
| 317 |
+
# Add neighbor nodes to queue with incremented depth
|
| 318 |
+
neighbors = list(graph.neighbors(current))
|
| 319 |
+
queue.extend(
|
| 320 |
+
[(n, depth + 1) for n in neighbors if n not in visited]
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
# Check if graph is truncated - if we still have nodes in the queue
|
| 324 |
+
# and we've reached max_nodes, then the graph is truncated
|
| 325 |
+
if queue and len(bfs_nodes) >= max_nodes:
|
| 326 |
+
result.is_truncated = True
|
| 327 |
+
logger.info(
|
| 328 |
+
f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
# Create subgraph with BFS discovered nodes
|
| 332 |
+
subgraph = graph.subgraph(bfs_nodes)
|
|
|
|
|
| 333 |
|
| 334 |
# Add nodes to result
|
| 335 |
+
seen_nodes = set()
|
| 336 |
+
seen_edges = set()
|
| 337 |
for node in subgraph.nodes():
|
| 338 |
if str(node) in seen_nodes:
|
| 339 |
continue
|
|
|
|
| 361 |
for edge in subgraph.edges():
|
| 362 |
source, target = edge
|
| 363 |
# Ensure unique edge_id for undirected graph
|
| 364 |
+
if str(source) > str(target):
|
| 365 |
source, target = target, source
|
| 366 |
edge_id = f"{source}-{target}"
|
| 367 |
if edge_id in seen_edges:
|
|
|
|
| 417 |
return False # Return error
|
| 418 |
|
| 419 |
return True
|
| 420 |
+
|
| 421 |
+
async def drop(self) -> dict[str, str]:
|
| 422 |
+
"""Drop all graph data from storage and clean up resources
|
| 423 |
+
|
| 424 |
+
This method will:
|
| 425 |
+
1. Remove the graph storage file if it exists
|
| 426 |
+
2. Reset the graph to an empty state
|
| 427 |
+
3. Update flags to notify other processes
|
| 428 |
+
4. Changes are persisted to disk immediately
|
| 429 |
+
|
| 430 |
+
Returns:
|
| 431 |
+
dict[str, str]: Operation status and message
|
| 432 |
+
- On success: {"status": "success", "message": "data dropped"}
|
| 433 |
+
- On failure: {"status": "error", "message": "<error details>"}
|
| 434 |
+
"""
|
| 435 |
+
try:
|
| 436 |
+
async with self._storage_lock:
|
| 437 |
+
# delete _client_file_name
|
| 438 |
+
if os.path.exists(self._graphml_xml_file):
|
| 439 |
+
os.remove(self._graphml_xml_file)
|
| 440 |
+
self._graph = nx.Graph()
|
| 441 |
+
# Notify other processes that data has been updated
|
| 442 |
+
await set_all_update_flags(self.namespace)
|
| 443 |
+
# Reset own update flag to avoid self-reloading
|
| 444 |
+
self.storage_updated.value = False
|
| 445 |
+
logger.info(
|
| 446 |
+
f"Process {os.getpid()} drop graph {self.namespace} (file:{self._graphml_xml_file})"
|
| 447 |
+
)
|
| 448 |
+
return {"status": "success", "message": "data dropped"}
|
| 449 |
+
except Exception as e:
|
| 450 |
+
logger.error(f"Error dropping graph {self.namespace}: {e}")
|
| 451 |
+
return {"status": "error", "message": str(e)}
|
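
The depth- and size-bounded BFS that get_knowledge_graph uses above can be reproduced on a plain networkx graph. A small self-contained sketch of the same idea (the toy path graph and limits are illustrative; it marks nodes as seen on enqueue, a compact variant of the traversal above):

    from collections import deque
    import networkx as nx

    def bfs_subgraph(graph: nx.Graph, start: str, max_depth: int, max_nodes: int) -> nx.Graph:
        """Collect up to max_nodes nodes reachable from start within max_depth hops."""
        kept: list[str] = []
        seen = {start}
        queue = deque([(start, 0)])
        while queue and len(kept) < max_nodes:
            node, depth = queue.popleft()
            kept.append(node)
            if depth < max_depth:
                for neighbor in graph.neighbors(node):
                    if neighbor not in seen:
                        seen.add(neighbor)
                        queue.append((neighbor, depth + 1))
        return graph.subgraph(kept)

    g = nx.path_graph([f"n{i}" for i in range(10)])
    sub = bfs_subgraph(g, "n0", max_depth=3, max_nodes=5)
    print(sorted(sub.nodes()))  # ['n0', 'n1', 'n2', 'n3'] -- the depth limit binds before max_nodes
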
lightrag/kg/oracle_impl.py
DELETED
|
@@ -1,1346 +0,0 @@
|
|
| 1 |
-
import array
|
| 2 |
-
import asyncio
|
| 3 |
-
|
| 4 |
-
# import html
|
| 5 |
-
import os
|
| 6 |
-
from dataclasses import dataclass, field
|
| 7 |
-
from typing import Any, Union, final
|
| 8 |
-
import numpy as np
|
| 9 |
-
import configparser
|
| 10 |
-
|
| 11 |
-
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
| 12 |
-
|
| 13 |
-
from ..base import (
|
| 14 |
-
BaseGraphStorage,
|
| 15 |
-
BaseKVStorage,
|
| 16 |
-
BaseVectorStorage,
|
| 17 |
-
)
|
| 18 |
-
from ..namespace import NameSpace, is_namespace
|
| 19 |
-
from ..utils import logger
|
| 20 |
-
|
| 21 |
-
import pipmaster as pm
|
| 22 |
-
|
| 23 |
-
if not pm.is_installed("graspologic"):
|
| 24 |
-
pm.install("graspologic")
|
| 25 |
-
|
| 26 |
-
if not pm.is_installed("oracledb"):
|
| 27 |
-
pm.install("oracledb")
|
| 28 |
-
|
| 29 |
-
from graspologic import embed
|
| 30 |
-
import oracledb
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
class OracleDB:
|
| 34 |
-
def __init__(self, config, **kwargs):
|
| 35 |
-
self.host = config.get("host", None)
|
| 36 |
-
self.port = config.get("port", None)
|
| 37 |
-
self.user = config.get("user", None)
|
| 38 |
-
self.password = config.get("password", None)
|
| 39 |
-
self.dsn = config.get("dsn", None)
|
| 40 |
-
self.config_dir = config.get("config_dir", None)
|
| 41 |
-
self.wallet_location = config.get("wallet_location", None)
|
| 42 |
-
self.wallet_password = config.get("wallet_password", None)
|
| 43 |
-
self.workspace = config.get("workspace", None)
|
| 44 |
-
self.max = 12
|
| 45 |
-
self.increment = 1
|
| 46 |
-
logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
|
| 47 |
-
if self.user is None or self.password is None:
|
| 48 |
-
raise ValueError("Missing database user or password")
|
| 49 |
-
|
| 50 |
-
try:
|
| 51 |
-
oracledb.defaults.fetch_lobs = False
|
| 52 |
-
|
| 53 |
-
self.pool = oracledb.create_pool_async(
|
| 54 |
-
user=self.user,
|
| 55 |
-
password=self.password,
|
| 56 |
-
dsn=self.dsn,
|
| 57 |
-
config_dir=self.config_dir,
|
| 58 |
-
wallet_location=self.wallet_location,
|
| 59 |
-
wallet_password=self.wallet_password,
|
| 60 |
-
min=1,
|
| 61 |
-
max=self.max,
|
| 62 |
-
increment=self.increment,
|
| 63 |
-
)
|
| 64 |
-
logger.info(f"Connected to Oracle database at {self.dsn}")
|
| 65 |
-
except Exception as e:
|
| 66 |
-
logger.error(f"Failed to connect to Oracle database at {self.dsn}")
|
| 67 |
-
logger.error(f"Oracle database error: {e}")
|
| 68 |
-
raise
|
| 69 |
-
|
| 70 |
-
def numpy_converter_in(self, value):
|
| 71 |
-
"""Convert numpy array to array.array"""
|
| 72 |
-
if value.dtype == np.float64:
|
| 73 |
-
dtype = "d"
|
| 74 |
-
elif value.dtype == np.float32:
|
| 75 |
-
dtype = "f"
|
| 76 |
-
else:
|
| 77 |
-
dtype = "b"
|
| 78 |
-
return array.array(dtype, value)
|
| 79 |
-
|
| 80 |
-
def input_type_handler(self, cursor, value, arraysize):
|
| 81 |
-
"""Set the type handler for the input data"""
|
| 82 |
-
if isinstance(value, np.ndarray):
|
| 83 |
-
return cursor.var(
|
| 84 |
-
oracledb.DB_TYPE_VECTOR,
|
| 85 |
-
arraysize=arraysize,
|
| 86 |
-
inconverter=self.numpy_converter_in,
|
| 87 |
-
)
|
| 88 |
-
|
| 89 |
-
def numpy_converter_out(self, value):
|
| 90 |
-
"""Convert array.array to numpy array"""
|
| 91 |
-
if value.typecode == "b":
|
| 92 |
-
dtype = np.int8
|
| 93 |
-
elif value.typecode == "f":
|
| 94 |
-
dtype = np.float32
|
| 95 |
-
else:
|
| 96 |
-
dtype = np.float64
|
| 97 |
-
return np.array(value, copy=False, dtype=dtype)
|
| 98 |
-
|
| 99 |
-
def output_type_handler(self, cursor, metadata):
|
| 100 |
-
"""Set the type handler for the output data"""
|
| 101 |
-
if metadata.type_code is oracledb.DB_TYPE_VECTOR:
|
| 102 |
-
return cursor.var(
|
| 103 |
-
metadata.type_code,
|
| 104 |
-
arraysize=cursor.arraysize,
|
| 105 |
-
outconverter=self.numpy_converter_out,
|
| 106 |
-
)
|
| 107 |
-
|
| 108 |
-
async def check_tables(self):
|
| 109 |
-
for k, v in TABLES.items():
|
| 110 |
-
try:
|
| 111 |
-
if k.lower() == "lightrag_graph":
|
| 112 |
-
await self.query(
|
| 113 |
-
"SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only"
|
| 114 |
-
)
|
| 115 |
-
else:
|
| 116 |
-
await self.query(f"SELECT 1 FROM {k}")
|
| 117 |
-
except Exception as e:
|
| 118 |
-
logger.error(f"Failed to check table {k} in Oracle database")
|
| 119 |
-
logger.error(f"Oracle database error: {e}")
|
| 120 |
-
try:
|
| 121 |
-
# print(v["ddl"])
|
| 122 |
-
await self.execute(v["ddl"])
|
| 123 |
-
logger.info(f"Created table {k} in Oracle database")
|
| 124 |
-
except Exception as e:
|
| 125 |
-
logger.error(f"Failed to create table {k} in Oracle database")
|
| 126 |
-
logger.error(f"Oracle database error: {e}")
|
| 127 |
-
|
| 128 |
-
logger.info("Finished check all tables in Oracle database")
|
| 129 |
-
|
| 130 |
-
async def query(
|
| 131 |
-
self, sql: str, params: dict = None, multirows: bool = False
|
| 132 |
-
) -> Union[dict, None]:
|
| 133 |
-
async with self.pool.acquire() as connection:
|
| 134 |
-
connection.inputtypehandler = self.input_type_handler
|
| 135 |
-
connection.outputtypehandler = self.output_type_handler
|
| 136 |
-
with connection.cursor() as cursor:
|
| 137 |
-
try:
|
| 138 |
-
await cursor.execute(sql, params)
|
| 139 |
-
except Exception as e:
|
| 140 |
-
logger.error(f"Oracle database error: {e}")
|
| 141 |
-
raise
|
| 142 |
-
columns = [column[0].lower() for column in cursor.description]
|
| 143 |
-
if multirows:
|
| 144 |
-
rows = await cursor.fetchall()
|
| 145 |
-
if rows:
|
| 146 |
-
data = [dict(zip(columns, row)) for row in rows]
|
| 147 |
-
else:
|
| 148 |
-
data = []
|
| 149 |
-
else:
|
| 150 |
-
row = await cursor.fetchone()
|
| 151 |
-
if row:
|
| 152 |
-
data = dict(zip(columns, row))
|
| 153 |
-
else:
|
| 154 |
-
data = None
|
| 155 |
-
return data
|
| 156 |
-
|
| 157 |
-
async def execute(self, sql: str, data: Union[list, dict] = None):
|
| 158 |
-
# logger.info("go into OracleDB execute method")
|
| 159 |
-
try:
|
| 160 |
-
async with self.pool.acquire() as connection:
|
| 161 |
-
connection.inputtypehandler = self.input_type_handler
|
| 162 |
-
connection.outputtypehandler = self.output_type_handler
|
| 163 |
-
with connection.cursor() as cursor:
|
| 164 |
-
if data is None:
|
| 165 |
-
await cursor.execute(sql)
|
| 166 |
-
else:
|
| 167 |
-
await cursor.execute(sql, data)
|
| 168 |
-
await connection.commit()
|
| 169 |
-
except Exception as e:
|
| 170 |
-
logger.error(f"Oracle database error: {e}")
|
| 171 |
-
raise
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
class ClientManager:
|
| 175 |
-
_instances: dict[str, Any] = {"db": None, "ref_count": 0}
|
| 176 |
-
_lock = asyncio.Lock()
|
| 177 |
-
|
| 178 |
-
@staticmethod
|
| 179 |
-
def get_config() -> dict[str, Any]:
|
| 180 |
-
config = configparser.ConfigParser()
|
| 181 |
-
config.read("config.ini", "utf-8")
|
| 182 |
-
|
| 183 |
-
return {
|
| 184 |
-
"user": os.environ.get(
|
| 185 |
-
"ORACLE_USER",
|
| 186 |
-
config.get("oracle", "user", fallback=None),
|
| 187 |
-
),
|
| 188 |
-
"password": os.environ.get(
|
| 189 |
-
"ORACLE_PASSWORD",
|
| 190 |
-
config.get("oracle", "password", fallback=None),
|
| 191 |
-
),
|
| 192 |
-
"dsn": os.environ.get(
|
| 193 |
-
"ORACLE_DSN",
|
| 194 |
-
config.get("oracle", "dsn", fallback=None),
|
| 195 |
-
),
|
| 196 |
-
"config_dir": os.environ.get(
|
| 197 |
-
"ORACLE_CONFIG_DIR",
|
| 198 |
-
config.get("oracle", "config_dir", fallback=None),
|
| 199 |
-
),
|
| 200 |
-
"wallet_location": os.environ.get(
|
| 201 |
-
"ORACLE_WALLET_LOCATION",
|
| 202 |
-
config.get("oracle", "wallet_location", fallback=None),
|
| 203 |
-
),
|
| 204 |
-
"wallet_password": os.environ.get(
|
| 205 |
-
"ORACLE_WALLET_PASSWORD",
|
| 206 |
-
config.get("oracle", "wallet_password", fallback=None),
|
| 207 |
-
),
|
| 208 |
-
"workspace": os.environ.get(
|
| 209 |
-
"ORACLE_WORKSPACE",
|
| 210 |
-
config.get("oracle", "workspace", fallback="default"),
|
| 211 |
-
),
|
| 212 |
-
}
|
| 213 |
-
|
| 214 |
-
@classmethod
|
| 215 |
-
async def get_client(cls) -> OracleDB:
|
| 216 |
-
async with cls._lock:
|
| 217 |
-
if cls._instances["db"] is None:
|
| 218 |
-
config = ClientManager.get_config()
|
| 219 |
-
db = OracleDB(config)
|
| 220 |
-
await db.check_tables()
|
| 221 |
-
cls._instances["db"] = db
|
| 222 |
-
cls._instances["ref_count"] = 0
|
| 223 |
-
cls._instances["ref_count"] += 1
|
| 224 |
-
return cls._instances["db"]
|
| 225 |
-
|
| 226 |
-
@classmethod
|
| 227 |
-
async def release_client(cls, db: OracleDB):
|
| 228 |
-
async with cls._lock:
|
| 229 |
-
if db is not None:
|
| 230 |
-
if db is cls._instances["db"]:
|
| 231 |
-
cls._instances["ref_count"] -= 1
|
| 232 |
-
if cls._instances["ref_count"] == 0:
|
| 233 |
-
await db.pool.close()
|
| 234 |
-
logger.info("Closed OracleDB database connection pool")
|
| 235 |
-
cls._instances["db"] = None
|
| 236 |
-
else:
|
| 237 |
-
await db.pool.close()
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
@final
|
| 241 |
-
@dataclass
|
| 242 |
-
class OracleKVStorage(BaseKVStorage):
|
| 243 |
-
db: OracleDB = field(default=None)
|
| 244 |
-
meta_fields = None
|
| 245 |
-
|
| 246 |
-
def __post_init__(self):
|
| 247 |
-
self._data = {}
|
| 248 |
-
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
| 249 |
-
|
| 250 |
-
async def initialize(self):
|
| 251 |
-
if self.db is None:
|
| 252 |
-
self.db = await ClientManager.get_client()
|
| 253 |
-
|
| 254 |
-
async def finalize(self):
|
| 255 |
-
if self.db is not None:
|
| 256 |
-
await ClientManager.release_client(self.db)
|
| 257 |
-
self.db = None
|
| 258 |
-
|
| 259 |
-
################ QUERY METHODS ################
|
| 260 |
-
|
| 261 |
-
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
| 262 |
-
"""Get doc_full data based on id."""
|
| 263 |
-
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
| 264 |
-
params = {"workspace": self.db.workspace, "id": id}
|
| 265 |
-
# print("get_by_id:"+SQL)
|
| 266 |
-
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
| 267 |
-
array_res = await self.db.query(SQL, params, multirows=True)
|
| 268 |
-
res = {}
|
| 269 |
-
for row in array_res:
|
| 270 |
-
res[row["id"]] = row
|
| 271 |
-
if res:
|
| 272 |
-
return res
|
| 273 |
-
else:
|
| 274 |
-
return None
|
| 275 |
-
else:
|
| 276 |
-
return await self.db.query(SQL, params)
|
| 277 |
-
|
| 278 |
-
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
| 279 |
-
"""Specifically for llm_response_cache."""
|
| 280 |
-
SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
|
| 281 |
-
params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id}
|
| 282 |
-
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
| 283 |
-
array_res = await self.db.query(SQL, params, multirows=True)
|
| 284 |
-
res = {}
|
| 285 |
-
for row in array_res:
|
| 286 |
-
res[row["id"]] = row
|
| 287 |
-
return res
|
| 288 |
-
else:
|
| 289 |
-
return None
|
| 290 |
-
|
| 291 |
-
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
| 292 |
-
"""Get doc_chunks data based on id"""
|
| 293 |
-
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
| 294 |
-
ids=",".join([f"'{id}'" for id in ids])
|
| 295 |
-
)
|
| 296 |
-
params = {"workspace": self.db.workspace}
|
| 297 |
-
# print("get_by_ids:"+SQL)
|
| 298 |
-
res = await self.db.query(SQL, params, multirows=True)
|
| 299 |
-
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
| 300 |
-
modes = set()
|
| 301 |
-
dict_res: dict[str, dict] = {}
|
| 302 |
-
for row in res:
|
| 303 |
-
modes.add(row["mode"])
|
| 304 |
-
for mode in modes:
|
| 305 |
-
if mode not in dict_res:
|
| 306 |
-
dict_res[mode] = {}
|
| 307 |
-
for row in res:
|
| 308 |
-
dict_res[row["mode"]][row["id"]] = row
|
| 309 |
-
res = [{k: v} for k, v in dict_res.items()]
|
| 310 |
-
return res
|
| 311 |
-
|
| 312 |
-
async def filter_keys(self, keys: set[str]) -> set[str]:
|
| 313 |
-
"""Return keys that don't exist in storage"""
|
| 314 |
-
SQL = SQL_TEMPLATES["filter_keys"].format(
|
| 315 |
-
table_name=namespace_to_table_name(self.namespace),
|
| 316 |
-
ids=",".join([f"'{id}'" for id in keys]),
|
| 317 |
-
)
|
| 318 |
-
params = {"workspace": self.db.workspace}
|
| 319 |
-
res = await self.db.query(SQL, params, multirows=True)
|
| 320 |
-
if res:
|
| 321 |
-
exist_keys = [key["id"] for key in res]
|
| 322 |
-
data = set([s for s in keys if s not in exist_keys])
|
| 323 |
-
return data
|
| 324 |
-
else:
|
| 325 |
-
return set(keys)
|
| 326 |
-
|
| 327 |
-
################ INSERT METHODS ################
|
| 328 |
-
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
| 329 |
-
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
| 330 |
-
if not data:
|
| 331 |
-
return
|
| 332 |
-
|
| 333 |
-
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
|
| 334 |
-
list_data = [
|
| 335 |
-
{
|
| 336 |
-
"id": k,
|
| 337 |
-
**{k1: v1 for k1, v1 in v.items()},
|
| 338 |
-
}
|
| 339 |
-
for k, v in data.items()
|
| 340 |
-
]
|
| 341 |
-
contents = [v["content"] for v in data.values()]
|
| 342 |
-
batches = [
|
| 343 |
-
contents[i : i + self._max_batch_size]
|
| 344 |
-
for i in range(0, len(contents), self._max_batch_size)
|
| 345 |
-
]
|
| 346 |
-
embeddings_list = await asyncio.gather(
|
| 347 |
-
*[self.embedding_func(batch) for batch in batches]
|
| 348 |
-
)
|
| 349 |
-
embeddings = np.concatenate(embeddings_list)
|
| 350 |
-
for i, d in enumerate(list_data):
|
| 351 |
-
d["__vector__"] = embeddings[i]
|
| 352 |
-
|
| 353 |
-
merge_sql = SQL_TEMPLATES["merge_chunk"]
|
| 354 |
-
for item in list_data:
|
| 355 |
-
_data = {
|
| 356 |
-
"id": item["id"],
|
| 357 |
-
"content": item["content"],
|
| 358 |
-
"workspace": self.db.workspace,
|
| 359 |
-
"tokens": item["tokens"],
|
| 360 |
-
"chunk_order_index": item["chunk_order_index"],
|
| 361 |
-
"full_doc_id": item["full_doc_id"],
|
| 362 |
-
"content_vector": item["__vector__"],
|
| 363 |
-
"status": item["status"],
|
| 364 |
-
}
|
| 365 |
-
await self.db.execute(merge_sql, _data)
|
| 366 |
-
if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
|
| 367 |
-
for k, v in data.items():
|
| 368 |
-
# values.clear()
|
| 369 |
-
merge_sql = SQL_TEMPLATES["merge_doc_full"]
|
| 370 |
-
_data = {
|
| 371 |
-
"id": k,
|
| 372 |
-
"content": v["content"],
|
| 373 |
-
"workspace": self.db.workspace,
|
| 374 |
-
}
|
| 375 |
-
await self.db.execute(merge_sql, _data)
|
| 376 |
-
|
| 377 |
-
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
| 378 |
-
for mode, items in data.items():
|
| 379 |
-
for k, v in items.items():
|
| 380 |
-
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
|
| 381 |
-
_data = {
|
| 382 |
-
"workspace": self.db.workspace,
|
| 383 |
-
"id": k,
|
| 384 |
-
"original_prompt": v["original_prompt"],
|
| 385 |
-
"return_value": v["return"],
|
| 386 |
-
"cache_mode": mode,
|
| 387 |
-
}
|
| 388 |
-
|
| 389 |
-
await self.db.execute(upsert_sql, _data)
|
| 390 |
-
|
| 391 |
-
async def index_done_callback(self) -> None:
|
| 392 |
-
# Oracle handles persistence automatically
|
| 393 |
-
pass
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
@final
|
| 397 |
-
@dataclass
|
| 398 |
-
class OracleVectorDBStorage(BaseVectorStorage):
|
| 399 |
-
db: OracleDB | None = field(default=None)
|
| 400 |
-
|
| 401 |
-
def __post_init__(self):
|
| 402 |
-
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
| 403 |
-
cosine_threshold = config.get("cosine_better_than_threshold")
|
| 404 |
-
if cosine_threshold is None:
|
| 405 |
-
raise ValueError(
|
| 406 |
-
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
| 407 |
-
)
|
| 408 |
-
self.cosine_better_than_threshold = cosine_threshold
|
| 409 |
-
|
| 410 |
-
async def initialize(self):
|
| 411 |
-
if self.db is None:
|
| 412 |
-
self.db = await ClientManager.get_client()
|
| 413 |
-
|
| 414 |
-
async def finalize(self):
|
| 415 |
-
if self.db is not None:
|
| 416 |
-
await ClientManager.release_client(self.db)
|
| 417 |
-
self.db = None
|
| 418 |
-
|
| 419 |
-
#################### query method ###############
|
| 420 |
-
async def query(
|
| 421 |
-
self, query: str, top_k: int, ids: list[str] | None = None
|
| 422 |
-
) -> list[dict[str, Any]]:
|
| 423 |
-
embeddings = await self.embedding_func([query])
|
| 424 |
-
embedding = embeddings[0]
|
| 425 |
-
# Convert precision: derive the embedding dtype and dimension for the vector SQL
|
| 426 |
-
dtype = str(embedding.dtype).upper()
|
| 427 |
-
dimension = embedding.shape[0]
|
| 428 |
-
embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
|
| 429 |
-
|
| 430 |
-
SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
|
| 431 |
-
params = {
|
| 432 |
-
"embedding_string": embedding_string,
|
| 433 |
-
"workspace": self.db.workspace,
|
| 434 |
-
"top_k": top_k,
|
| 435 |
-
"better_than_threshold": self.cosine_better_than_threshold,
|
| 436 |
-
}
|
| 437 |
-
results = await self.db.query(SQL, params=params, multirows=True)
|
| 438 |
-
return results
|
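For reference, the query method above reduces to a single call pattern: the template is formatted with the embedding's dimension and dtype, and the vector itself is passed as a string bind. A condensed sketch using the same names (top_k and the threshold shown as example values only):

```python
async def vector_search(db, embedding_func, query_text: str, top_k: int = 10):
    # Condensed form of OracleVectorDBStorage.query above; the threshold value is illustrative.
    embedding = (await embedding_func([query_text]))[0]
    SQL = SQL_TEMPLATES["chunks"].format(
        dimension=embedding.shape[0],            # vector dimension
        dtype=str(embedding.dtype).upper(),      # e.g. "FLOAT32"
    )
    params = {
        "embedding_string": "[" + ", ".join(map(str, embedding.tolist())) + "]",
        "workspace": db.workspace,
        "top_k": top_k,
        "better_than_threshold": 0.2,            # cosine_better_than_threshold from config
    }
    return await db.query(SQL, params=params, multirows=True)
```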
| 439 |
-
|
| 440 |
-
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
| 441 |
-
raise NotImplementedError
|
| 442 |
-
|
| 443 |
-
async def index_done_callback(self) -> None:
|
| 444 |
-
# Oracle handles persistence automatically
|
| 445 |
-
pass
|
| 446 |
-
|
| 447 |
-
async def delete(self, ids: list[str]) -> None:
|
| 448 |
-
"""Delete vectors with specified IDs
|
| 449 |
-
|
| 450 |
-
Args:
|
| 451 |
-
ids: List of vector IDs to be deleted
|
| 452 |
-
"""
|
| 453 |
-
if not ids:
|
| 454 |
-
return
|
| 455 |
-
|
| 456 |
-
try:
|
| 457 |
-
SQL = SQL_TEMPLATES["delete_vectors"].format(
|
| 458 |
-
ids=",".join([f"'{id}'" for id in ids])
|
| 459 |
-
)
|
| 460 |
-
params = {"workspace": self.db.workspace}
|
| 461 |
-
await self.db.execute(SQL, params)
|
| 462 |
-
logger.info(
|
| 463 |
-
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
|
| 464 |
-
)
|
| 465 |
-
except Exception as e:
|
| 466 |
-
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
|
| 467 |
-
raise
|
| 468 |
-
|
| 469 |
-
async def delete_entity(self, entity_name: str) -> None:
|
| 470 |
-
"""Delete entity by name
|
| 471 |
-
|
| 472 |
-
Args:
|
| 473 |
-
entity_name: Name of the entity to delete
|
| 474 |
-
"""
|
| 475 |
-
try:
|
| 476 |
-
SQL = SQL_TEMPLATES["delete_entity"]
|
| 477 |
-
params = {"workspace": self.db.workspace, "entity_name": entity_name}
|
| 478 |
-
await self.db.execute(SQL, params)
|
| 479 |
-
logger.info(f"Successfully deleted entity {entity_name}")
|
| 480 |
-
except Exception as e:
|
| 481 |
-
logger.error(f"Error deleting entity {entity_name}: {e}")
|
| 482 |
-
raise
|
| 483 |
-
|
| 484 |
-
async def delete_entity_relation(self, entity_name: str) -> None:
|
| 485 |
-
"""Delete all relations connected to an entity
|
| 486 |
-
|
| 487 |
-
Args:
|
| 488 |
-
entity_name: Name of the entity whose relations should be deleted
|
| 489 |
-
"""
|
| 490 |
-
try:
|
| 491 |
-
SQL = SQL_TEMPLATES["delete_entity_relations"]
|
| 492 |
-
params = {"workspace": self.db.workspace, "entity_name": entity_name}
|
| 493 |
-
await self.db.execute(SQL, params)
|
| 494 |
-
logger.info(f"Successfully deleted relations for entity {entity_name}")
|
| 495 |
-
except Exception as e:
|
| 496 |
-
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
|
| 497 |
-
raise
|
| 498 |
-
|
| 499 |
-
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
|
| 500 |
-
"""Search for records with IDs starting with a specific prefix.
|
| 501 |
-
|
| 502 |
-
Args:
|
| 503 |
-
prefix: The prefix to search for in record IDs
|
| 504 |
-
|
| 505 |
-
Returns:
|
| 506 |
-
List of records with matching ID prefixes
|
| 507 |
-
"""
|
| 508 |
-
try:
|
| 509 |
-
# Determine the appropriate table based on namespace
|
| 510 |
-
table_name = namespace_to_table_name(self.namespace)
|
| 511 |
-
|
| 512 |
-
# Create SQL query to find records with IDs starting with prefix
|
| 513 |
-
search_sql = f"""
|
| 514 |
-
SELECT * FROM {table_name}
|
| 515 |
-
WHERE workspace = :workspace
|
| 516 |
-
AND id LIKE :prefix_pattern
|
| 517 |
-
ORDER BY id
|
| 518 |
-
"""
|
| 519 |
-
|
| 520 |
-
params = {"workspace": self.db.workspace, "prefix_pattern": f"{prefix}%"}
|
| 521 |
-
|
| 522 |
-
# Execute query and get results
|
| 523 |
-
results = await self.db.query(search_sql, params, multirows=True)
|
| 524 |
-
|
| 525 |
-
logger.debug(
|
| 526 |
-
f"Found {len(results) if results else 0} records with prefix '{prefix}'"
|
| 527 |
-
)
|
| 528 |
-
return results or []
|
| 529 |
-
|
| 530 |
-
except Exception as e:
|
| 531 |
-
logger.error(f"Error searching records with prefix '{prefix}': {e}")
|
| 532 |
-
return []
|
| 533 |
-
|
| 534 |
-
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
| 535 |
-
"""Get vector data by its ID
|
| 536 |
-
|
| 537 |
-
Args:
|
| 538 |
-
id: The unique identifier of the vector
|
| 539 |
-
|
| 540 |
-
Returns:
|
| 541 |
-
The vector data if found, or None if not found
|
| 542 |
-
"""
|
| 543 |
-
try:
|
| 544 |
-
# Determine the table name based on namespace
|
| 545 |
-
table_name = namespace_to_table_name(self.namespace)
|
| 546 |
-
if not table_name:
|
| 547 |
-
logger.error(f"Unknown namespace for ID lookup: {self.namespace}")
|
| 548 |
-
return None
|
| 549 |
-
|
| 550 |
-
# Create the appropriate ID field name based on namespace
|
| 551 |
-
id_field = "entity_id" if "NODES" in table_name else "relation_id"
|
| 552 |
-
if "CHUNKS" in table_name:
|
| 553 |
-
id_field = "chunk_id"
|
| 554 |
-
|
| 555 |
-
# Prepare and execute the query
|
| 556 |
-
query = f"""
|
| 557 |
-
SELECT * FROM {table_name}
|
| 558 |
-
WHERE {id_field} = :id AND workspace = :workspace
|
| 559 |
-
"""
|
| 560 |
-
params = {"id": id, "workspace": self.db.workspace}
|
| 561 |
-
|
| 562 |
-
result = await self.db.query(query, params)
|
| 563 |
-
return result
|
| 564 |
-
except Exception as e:
|
| 565 |
-
logger.error(f"Error retrieving vector data for ID {id}: {e}")
|
| 566 |
-
return None
|
| 567 |
-
|
| 568 |
-
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
| 569 |
-
"""Get multiple vector data by their IDs
|
| 570 |
-
|
| 571 |
-
Args:
|
| 572 |
-
ids: List of unique identifiers
|
| 573 |
-
|
| 574 |
-
Returns:
|
| 575 |
-
List of vector data objects that were found
|
| 576 |
-
"""
|
| 577 |
-
if not ids:
|
| 578 |
-
return []
|
| 579 |
-
|
| 580 |
-
try:
|
| 581 |
-
# Determine the table name based on namespace
|
| 582 |
-
table_name = namespace_to_table_name(self.namespace)
|
| 583 |
-
if not table_name:
|
| 584 |
-
logger.error(f"Unknown namespace for IDs lookup: {self.namespace}")
|
| 585 |
-
return []
|
| 586 |
-
|
| 587 |
-
# Create the appropriate ID field name based on namespace
|
| 588 |
-
id_field = "entity_id" if "NODES" in table_name else "relation_id"
|
| 589 |
-
if "CHUNKS" in table_name:
|
| 590 |
-
id_field = "chunk_id"
|
| 591 |
-
|
| 592 |
-
# Format the list of IDs for SQL IN clause
|
| 593 |
-
ids_list = ", ".join([f"'{id}'" for id in ids])
|
| 594 |
-
|
| 595 |
-
# Prepare and execute the query
|
| 596 |
-
query = f"""
|
| 597 |
-
SELECT * FROM {table_name}
|
| 598 |
-
WHERE {id_field} IN ({ids_list}) AND workspace = :workspace
|
| 599 |
-
"""
|
| 600 |
-
params = {"workspace": self.db.workspace}
|
| 601 |
-
|
| 602 |
-
results = await self.db.query(query, params, multirows=True)
|
| 603 |
-
return results or []
|
| 604 |
-
except Exception as e:
|
| 605 |
-
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
| 606 |
-
return []
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
@final
|
| 610 |
-
@dataclass
|
| 611 |
-
class OracleGraphStorage(BaseGraphStorage):
|
| 612 |
-
db: OracleDB = field(default=None)
|
| 613 |
-
|
| 614 |
-
def __post_init__(self):
|
| 615 |
-
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
| 616 |
-
|
| 617 |
-
async def initialize(self):
|
| 618 |
-
if self.db is None:
|
| 619 |
-
self.db = await ClientManager.get_client()
|
| 620 |
-
|
| 621 |
-
async def finalize(self):
|
| 622 |
-
if self.db is not None:
|
| 623 |
-
await ClientManager.release_client(self.db)
|
| 624 |
-
self.db = None
|
| 625 |
-
|
| 626 |
-
#################### insert method ################
|
| 627 |
-
|
| 628 |
-
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
| 629 |
-
entity_name = node_id
|
| 630 |
-
entity_type = node_data["entity_type"]
|
| 631 |
-
description = node_data["description"]
|
| 632 |
-
source_id = node_data["source_id"]
|
| 633 |
-
logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}")
|
| 634 |
-
|
| 635 |
-
content = entity_name + description
|
| 636 |
-
contents = [content]
|
| 637 |
-
batches = [
|
| 638 |
-
contents[i : i + self._max_batch_size]
|
| 639 |
-
for i in range(0, len(contents), self._max_batch_size)
|
| 640 |
-
]
|
| 641 |
-
embeddings_list = await asyncio.gather(
|
| 642 |
-
*[self.embedding_func(batch) for batch in batches]
|
| 643 |
-
)
|
| 644 |
-
embeddings = np.concatenate(embeddings_list)
|
| 645 |
-
content_vector = embeddings[0]
|
| 646 |
-
merge_sql = SQL_TEMPLATES["merge_node"]
|
| 647 |
-
data = {
|
| 648 |
-
"workspace": self.db.workspace,
|
| 649 |
-
"name": entity_name,
|
| 650 |
-
"entity_type": entity_type,
|
| 651 |
-
"description": description,
|
| 652 |
-
"source_chunk_id": source_id,
|
| 653 |
-
"content": content,
|
| 654 |
-
"content_vector": content_vector,
|
| 655 |
-
}
|
| 656 |
-
await self.db.execute(merge_sql, data)
|
| 657 |
-
# self._graph.add_node(node_id, **node_data)
|
| 658 |
-
|
| 659 |
-
async def upsert_edge(
|
| 660 |
-
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 661 |
-
) -> None:
|
| 662 |
-
"""插入或更新边"""
|
| 663 |
-
# print("go into upsert edge method")
|
| 664 |
-
source_name = source_node_id
|
| 665 |
-
target_name = target_node_id
|
| 666 |
-
weight = edge_data["weight"]
|
| 667 |
-
keywords = edge_data["keywords"]
|
| 668 |
-
description = edge_data["description"]
|
| 669 |
-
source_chunk_id = edge_data["source_id"]
|
| 670 |
-
logger.debug(
|
| 671 |
-
f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
|
| 672 |
-
)
|
| 673 |
-
|
| 674 |
-
content = keywords + source_name + target_name + description
|
| 675 |
-
contents = [content]
|
| 676 |
-
batches = [
|
| 677 |
-
contents[i : i + self._max_batch_size]
|
| 678 |
-
for i in range(0, len(contents), self._max_batch_size)
|
| 679 |
-
]
|
| 680 |
-
embeddings_list = await asyncio.gather(
|
| 681 |
-
*[self.embedding_func(batch) for batch in batches]
|
| 682 |
-
)
|
| 683 |
-
embeddings = np.concatenate(embeddings_list)
|
| 684 |
-
content_vector = embeddings[0]
|
| 685 |
-
merge_sql = SQL_TEMPLATES["merge_edge"]
|
| 686 |
-
data = {
|
| 687 |
-
"workspace": self.db.workspace,
|
| 688 |
-
"source_name": source_name,
|
| 689 |
-
"target_name": target_name,
|
| 690 |
-
"weight": weight,
|
| 691 |
-
"keywords": keywords,
|
| 692 |
-
"description": description,
|
| 693 |
-
"source_chunk_id": source_chunk_id,
|
| 694 |
-
"content": content,
|
| 695 |
-
"content_vector": content_vector,
|
| 696 |
-
}
|
| 697 |
-
# print(merge_sql)
|
| 698 |
-
await self.db.execute(merge_sql, data)
|
| 699 |
-
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
| 700 |
-
|
| 701 |
-
async def embed_nodes(
|
| 702 |
-
self, algorithm: str
|
| 703 |
-
) -> tuple[np.ndarray[Any, Any], list[str]]:
|
| 704 |
-
if algorithm not in self._node_embed_algorithms:
|
| 705 |
-
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
| 706 |
-
return await self._node_embed_algorithms[algorithm]()
|
| 707 |
-
|
| 708 |
-
async def _node2vec_embed(self):
|
| 709 |
-
"""为节点生成向量"""
|
| 710 |
-
embeddings, nodes = embed.node2vec_embed(
|
| 711 |
-
self._graph,
|
| 712 |
-
**self.config["node2vec_params"],
|
| 713 |
-
)
|
| 714 |
-
|
| 715 |
-
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
| 716 |
-
return embeddings, nodes_ids
|
| 717 |
-
|
| 718 |
-
async def index_done_callback(self) -> None:
|
| 719 |
-
# Oracle handles persistence automatically
|
| 720 |
-
pass
|
| 721 |
-
|
| 722 |
-
#################### query method #################
|
| 723 |
-
async def has_node(self, node_id: str) -> bool:
|
| 724 |
-
"""根据节点id检查节点是否存在"""
|
| 725 |
-
SQL = SQL_TEMPLATES["has_node"]
|
| 726 |
-
params = {"workspace": self.db.workspace, "node_id": node_id}
|
| 727 |
-
res = await self.db.query(SQL, params)
|
| 728 |
-
if res:
|
| 729 |
-
# print("Node exist!",res)
|
| 730 |
-
return True
|
| 731 |
-
else:
|
| 732 |
-
# print("Node not exist!")
|
| 733 |
-
return False
|
| 734 |
-
|
| 735 |
-
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
| 736 |
-
SQL = SQL_TEMPLATES["has_edge"]
|
| 737 |
-
params = {
|
| 738 |
-
"workspace": self.db.workspace,
|
| 739 |
-
"source_node_id": source_node_id,
|
| 740 |
-
"target_node_id": target_node_id,
|
| 741 |
-
}
|
| 742 |
-
res = await self.db.query(SQL, params)
|
| 743 |
-
if res:
|
| 744 |
-
# print("Edge exist!",res)
|
| 745 |
-
return True
|
| 746 |
-
else:
|
| 747 |
-
# print("Edge not exist!")
|
| 748 |
-
return False
|
| 749 |
-
|
| 750 |
-
async def node_degree(self, node_id: str) -> int:
|
| 751 |
-
SQL = SQL_TEMPLATES["node_degree"]
|
| 752 |
-
params = {"workspace": self.db.workspace, "node_id": node_id}
|
| 753 |
-
res = await self.db.query(SQL, params)
|
| 754 |
-
if res:
|
| 755 |
-
return res["degree"]
|
| 756 |
-
else:
|
| 757 |
-
return 0
|
| 758 |
-
|
| 759 |
-
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
| 760 |
-
"""根据源和目标节点id获取边的度"""
|
| 761 |
-
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
|
| 762 |
-
return degree
|
| 763 |
-
|
| 764 |
-
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
| 765 |
-
"""根据节点id获取节点数据"""
|
| 766 |
-
SQL = SQL_TEMPLATES["get_node"]
|
| 767 |
-
params = {"workspace": self.db.workspace, "node_id": node_id}
|
| 768 |
-
res = await self.db.query(SQL, params)
|
| 769 |
-
if res:
|
| 770 |
-
return res
|
| 771 |
-
else:
|
| 772 |
-
return None
|
| 773 |
-
|
| 774 |
-
async def get_edge(
|
| 775 |
-
self, source_node_id: str, target_node_id: str
|
| 776 |
-
) -> dict[str, str] | None:
|
| 777 |
-
SQL = SQL_TEMPLATES["get_edge"]
|
| 778 |
-
params = {
|
| 779 |
-
"workspace": self.db.workspace,
|
| 780 |
-
"source_node_id": source_node_id,
|
| 781 |
-
"target_node_id": target_node_id,
|
| 782 |
-
}
|
| 783 |
-
res = await self.db.query(SQL, params)
|
| 784 |
-
if res:
|
| 785 |
-
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
|
| 786 |
-
return res
|
| 787 |
-
else:
|
| 788 |
-
# print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
|
| 789 |
-
return None
|
| 790 |
-
|
| 791 |
-
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
| 792 |
-
if await self.has_node(source_node_id):
|
| 793 |
-
SQL = SQL_TEMPLATES["get_node_edges"]
|
| 794 |
-
params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
|
| 795 |
-
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
| 796 |
-
if res:
|
| 797 |
-
data = [(i["source_name"], i["target_name"]) for i in res]
|
| 798 |
-
# print("Get node edge!",self.db.workspace, source_node_id,data)
|
| 799 |
-
return data
|
| 800 |
-
else:
|
| 801 |
-
# print("Node Edge not exist!",self.db.workspace, source_node_id)
|
| 802 |
-
return []
|
| 803 |
-
|
| 804 |
-
async def get_all_nodes(self, limit: int):
|
| 805 |
-
"""查询所有节点"""
|
| 806 |
-
SQL = SQL_TEMPLATES["get_all_nodes"]
|
| 807 |
-
params = {"workspace": self.db.workspace, "limit": str(limit)}
|
| 808 |
-
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
| 809 |
-
if res:
|
| 810 |
-
return res
|
| 811 |
-
|
| 812 |
-
async def get_all_edges(self, limit: int):
|
| 813 |
-
"""查询所有边"""
|
| 814 |
-
SQL = SQL_TEMPLATES["get_all_edges"]
|
| 815 |
-
params = {"workspace": self.db.workspace, "limit": str(limit)}
|
| 816 |
-
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
| 817 |
-
if res:
|
| 818 |
-
return res
|
| 819 |
-
|
| 820 |
-
async def get_statistics(self):
|
| 821 |
-
SQL = SQL_TEMPLATES["get_statistics"]
|
| 822 |
-
params = {"workspace": self.db.workspace}
|
| 823 |
-
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
| 824 |
-
if res:
|
| 825 |
-
return res
|
| 826 |
-
|
| 827 |
-
async def delete_node(self, node_id: str) -> None:
|
| 828 |
-
"""Delete a node from the graph
|
| 829 |
-
|
| 830 |
-
Args:
|
| 831 |
-
node_id: ID of the node to delete
|
| 832 |
-
"""
|
| 833 |
-
try:
|
| 834 |
-
# First delete all relations connected to this node
|
| 835 |
-
delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"]
|
| 836 |
-
params_relations = {"workspace": self.db.workspace, "entity_name": node_id}
|
| 837 |
-
await self.db.execute(delete_relations_sql, params_relations)
|
| 838 |
-
|
| 839 |
-
# Then delete the node itself
|
| 840 |
-
delete_node_sql = SQL_TEMPLATES["delete_entity"]
|
| 841 |
-
params_node = {"workspace": self.db.workspace, "entity_name": node_id}
|
| 842 |
-
await self.db.execute(delete_node_sql, params_node)
|
| 843 |
-
|
| 844 |
-
logger.info(
|
| 845 |
-
f"Successfully deleted node {node_id} and all its relationships"
|
| 846 |
-
)
|
| 847 |
-
except Exception as e:
|
| 848 |
-
logger.error(f"Error deleting node {node_id}: {e}")
|
| 849 |
-
raise
|
| 850 |
-
|
| 851 |
-
async def remove_nodes(self, nodes: list[str]) -> None:
|
| 852 |
-
"""Delete multiple nodes from the graph
|
| 853 |
-
|
| 854 |
-
Args:
|
| 855 |
-
nodes: List of node IDs to be deleted
|
| 856 |
-
"""
|
| 857 |
-
if not nodes:
|
| 858 |
-
return
|
| 859 |
-
|
| 860 |
-
try:
|
| 861 |
-
for node in nodes:
|
| 862 |
-
# For each node, first delete all its relationships
|
| 863 |
-
delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"]
|
| 864 |
-
params_relations = {"workspace": self.db.workspace, "entity_name": node}
|
| 865 |
-
await self.db.execute(delete_relations_sql, params_relations)
|
| 866 |
-
|
| 867 |
-
# Then delete the node itself
|
| 868 |
-
delete_node_sql = SQL_TEMPLATES["delete_entity"]
|
| 869 |
-
params_node = {"workspace": self.db.workspace, "entity_name": node}
|
| 870 |
-
await self.db.execute(delete_node_sql, params_node)
|
| 871 |
-
|
| 872 |
-
logger.info(
|
| 873 |
-
f"Successfully deleted {len(nodes)} nodes and their relationships"
|
| 874 |
-
)
|
| 875 |
-
except Exception as e:
|
| 876 |
-
logger.error(f"Error during batch node deletion: {e}")
|
| 877 |
-
raise
|
| 878 |
-
|
| 879 |
-
async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
|
| 880 |
-
"""Delete multiple edges from the graph
|
| 881 |
-
|
| 882 |
-
Args:
|
| 883 |
-
edges: List of edges to be deleted, each edge is a (source, target) tuple
|
| 884 |
-
"""
|
| 885 |
-
if not edges:
|
| 886 |
-
return
|
| 887 |
-
|
| 888 |
-
try:
|
| 889 |
-
for source, target in edges:
|
| 890 |
-
# Check if the edge exists before attempting to delete
|
| 891 |
-
if await self.has_edge(source, target):
|
| 892 |
-
# Delete the edge using a SQL query that matches both source and target
|
| 893 |
-
delete_edge_sql = """
|
| 894 |
-
DELETE FROM LIGHTRAG_GRAPH_EDGES
|
| 895 |
-
WHERE workspace = :workspace
|
| 896 |
-
AND source_name = :source_name
|
| 897 |
-
AND target_name = :target_name
|
| 898 |
-
"""
|
| 899 |
-
params = {
|
| 900 |
-
"workspace": self.db.workspace,
|
| 901 |
-
"source_name": source,
|
| 902 |
-
"target_name": target,
|
| 903 |
-
}
|
| 904 |
-
await self.db.execute(delete_edge_sql, params)
|
| 905 |
-
|
| 906 |
-
logger.info(f"Successfully deleted {len(edges)} edges from the graph")
|
| 907 |
-
except Exception as e:
|
| 908 |
-
logger.error(f"Error during batch edge deletion: {e}")
|
| 909 |
-
raise
|
| 910 |
-
|
| 911 |
-
async def get_all_labels(self) -> list[str]:
|
| 912 |
-
"""Get all unique entity types (labels) in the graph
|
| 913 |
-
|
| 914 |
-
Returns:
|
| 915 |
-
List of unique entity types/labels
|
| 916 |
-
"""
|
| 917 |
-
try:
|
| 918 |
-
SQL = """
|
| 919 |
-
SELECT DISTINCT entity_type
|
| 920 |
-
FROM LIGHTRAG_GRAPH_NODES
|
| 921 |
-
WHERE workspace = :workspace
|
| 922 |
-
ORDER BY entity_type
|
| 923 |
-
"""
|
| 924 |
-
params = {"workspace": self.db.workspace}
|
| 925 |
-
results = await self.db.query(SQL, params, multirows=True)
|
| 926 |
-
|
| 927 |
-
if results:
|
| 928 |
-
labels = [row["entity_type"] for row in results]
|
| 929 |
-
return labels
|
| 930 |
-
else:
|
| 931 |
-
return []
|
| 932 |
-
except Exception as e:
|
| 933 |
-
logger.error(f"Error retrieving entity types: {e}")
|
| 934 |
-
return []
|
| 935 |
-
|
| 936 |
-
async def get_knowledge_graph(
|
| 937 |
-
self, node_label: str, max_depth: int = 5
|
| 938 |
-
) -> KnowledgeGraph:
|
| 939 |
-
"""Retrieve a connected subgraph starting from nodes matching the given label
|
| 940 |
-
|
| 941 |
-
Maximum number of nodes is constrained by MAX_GRAPH_NODES environment variable.
|
| 942 |
-
Prioritizes nodes by:
|
| 943 |
-
1. Nodes matching the specified label
|
| 944 |
-
2. Nodes directly connected to matching nodes
|
| 945 |
-
3. Node degree (number of connections)
|
| 946 |
-
|
| 947 |
-
Args:
|
| 948 |
-
node_label: Label to match for starting nodes (use "*" for all nodes)
|
| 949 |
-
max_depth: Maximum depth of traversal from starting nodes
|
| 950 |
-
|
| 951 |
-
Returns:
|
| 952 |
-
KnowledgeGraph object containing nodes and edges
|
| 953 |
-
"""
|
| 954 |
-
result = KnowledgeGraph()
|
| 955 |
-
|
| 956 |
-
try:
|
| 957 |
-
# Define maximum number of nodes to return
|
| 958 |
-
max_graph_nodes = int(os.environ.get("MAX_GRAPH_NODES", 1000))
|
| 959 |
-
|
| 960 |
-
if node_label == "*":
|
| 961 |
-
# For "*" label, get all nodes up to the limit
|
| 962 |
-
nodes_sql = """
|
| 963 |
-
SELECT name, entity_type, description, source_chunk_id
|
| 964 |
-
FROM LIGHTRAG_GRAPH_NODES
|
| 965 |
-
WHERE workspace = :workspace
|
| 966 |
-
ORDER BY id
|
| 967 |
-
FETCH FIRST :limit ROWS ONLY
|
| 968 |
-
"""
|
| 969 |
-
nodes_params = {
|
| 970 |
-
"workspace": self.db.workspace,
|
| 971 |
-
"limit": max_graph_nodes,
|
| 972 |
-
}
|
| 973 |
-
nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
|
| 974 |
-
else:
|
| 975 |
-
# For specific label, find matching nodes and related nodes
|
| 976 |
-
nodes_sql = """
|
| 977 |
-
WITH matching_nodes AS (
|
| 978 |
-
SELECT name
|
| 979 |
-
FROM LIGHTRAG_GRAPH_NODES
|
| 980 |
-
WHERE workspace = :workspace
|
| 981 |
-
AND (name LIKE '%' || :node_label || '%' OR entity_type LIKE '%' || :node_label || '%')
|
| 982 |
-
)
|
| 983 |
-
SELECT n.name, n.entity_type, n.description, n.source_chunk_id,
|
| 984 |
-
CASE
|
| 985 |
-
WHEN n.name IN (SELECT name FROM matching_nodes) THEN 2
|
| 986 |
-
WHEN EXISTS (
|
| 987 |
-
SELECT 1 FROM LIGHTRAG_GRAPH_EDGES e
|
| 988 |
-
WHERE workspace = :workspace
|
| 989 |
-
AND ((e.source_name = n.name AND e.target_name IN (SELECT name FROM matching_nodes))
|
| 990 |
-
OR (e.target_name = n.name AND e.source_name IN (SELECT name FROM matching_nodes)))
|
| 991 |
-
) THEN 1
|
| 992 |
-
ELSE 0
|
| 993 |
-
END AS priority,
|
| 994 |
-
(SELECT COUNT(*) FROM LIGHTRAG_GRAPH_EDGES e
|
| 995 |
-
WHERE workspace = :workspace
|
| 996 |
-
AND (e.source_name = n.name OR e.target_name = n.name)) AS degree
|
| 997 |
-
FROM LIGHTRAG_GRAPH_NODES n
|
| 998 |
-
WHERE workspace = :workspace
|
| 999 |
-
ORDER BY priority DESC, degree DESC
|
| 1000 |
-
FETCH FIRST :limit ROWS ONLY
|
| 1001 |
-
"""
|
| 1002 |
-
nodes_params = {
|
| 1003 |
-
"workspace": self.db.workspace,
|
| 1004 |
-
"node_label": node_label,
|
| 1005 |
-
"limit": max_graph_nodes,
|
| 1006 |
-
}
|
| 1007 |
-
nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
|
| 1008 |
-
|
| 1009 |
-
if not nodes:
|
| 1010 |
-
logger.warning(f"No nodes found matching '{node_label}'")
|
| 1011 |
-
return result
|
| 1012 |
-
|
| 1013 |
-
# Create mapping of node IDs to be used to filter edges
|
| 1014 |
-
node_names = [node["name"] for node in nodes]
|
| 1015 |
-
|
| 1016 |
-
# Add nodes to result
|
| 1017 |
-
seen_nodes = set()
|
| 1018 |
-
for node in nodes:
|
| 1019 |
-
node_id = node["name"]
|
| 1020 |
-
if node_id in seen_nodes:
|
| 1021 |
-
continue
|
| 1022 |
-
|
| 1023 |
-
# Create node properties dictionary
|
| 1024 |
-
properties = {
|
| 1025 |
-
"entity_type": node["entity_type"],
|
| 1026 |
-
"description": node["description"] or "",
|
| 1027 |
-
"source_id": node["source_chunk_id"] or "",
|
| 1028 |
-
}
|
| 1029 |
-
|
| 1030 |
-
# Add node to result
|
| 1031 |
-
result.nodes.append(
|
| 1032 |
-
KnowledgeGraphNode(
|
| 1033 |
-
id=node_id, labels=[node["entity_type"]], properties=properties
|
| 1034 |
-
)
|
| 1035 |
-
)
|
| 1036 |
-
seen_nodes.add(node_id)
|
| 1037 |
-
|
| 1038 |
-
# Get edges between these nodes
|
| 1039 |
-
edges_sql = """
|
| 1040 |
-
SELECT source_name, target_name, weight, keywords, description, source_chunk_id
|
| 1041 |
-
FROM LIGHTRAG_GRAPH_EDGES
|
| 1042 |
-
WHERE workspace = :workspace
|
| 1043 |
-
AND source_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST)))
|
| 1044 |
-
AND target_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST)))
|
| 1045 |
-
ORDER BY id
|
| 1046 |
-
"""
|
| 1047 |
-
edges_params = {"workspace": self.db.workspace, "node_names": node_names}
|
| 1048 |
-
edges = await self.db.query(edges_sql, edges_params, multirows=True)
|
| 1049 |
-
|
| 1050 |
-
# Add edges to result
|
| 1051 |
-
seen_edges = set()
|
| 1052 |
-
for edge in edges:
|
| 1053 |
-
source = edge["source_name"]
|
| 1054 |
-
target = edge["target_name"]
|
| 1055 |
-
edge_id = f"{source}-{target}"
|
| 1056 |
-
|
| 1057 |
-
if edge_id in seen_edges:
|
| 1058 |
-
continue
|
| 1059 |
-
|
| 1060 |
-
# Create edge properties dictionary
|
| 1061 |
-
properties = {
|
| 1062 |
-
"weight": edge["weight"] or 0.0,
|
| 1063 |
-
"keywords": edge["keywords"] or "",
|
| 1064 |
-
"description": edge["description"] or "",
|
| 1065 |
-
"source_id": edge["source_chunk_id"] or "",
|
| 1066 |
-
}
|
| 1067 |
-
|
| 1068 |
-
# Add edge to result
|
| 1069 |
-
result.edges.append(
|
| 1070 |
-
KnowledgeGraphEdge(
|
| 1071 |
-
id=edge_id,
|
| 1072 |
-
type="RELATED",
|
| 1073 |
-
source=source,
|
| 1074 |
-
target=target,
|
| 1075 |
-
properties=properties,
|
| 1076 |
-
)
|
| 1077 |
-
)
|
| 1078 |
-
seen_edges.add(edge_id)
|
| 1079 |
-
|
| 1080 |
-
logger.info(
|
| 1081 |
-
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
| 1082 |
-
)
|
| 1083 |
-
|
| 1084 |
-
except Exception as e:
|
| 1085 |
-
logger.error(f"Error retrieving knowledge graph: {e}")
|
| 1086 |
-
|
| 1087 |
-
return result
|
| 1088 |
-
|
| 1089 |
-
|
| 1090 |
-
N_T = {
|
| 1091 |
-
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
| 1092 |
-
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
|
| 1093 |
-
NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
|
| 1094 |
-
NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES",
|
| 1095 |
-
NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
|
| 1096 |
-
}
|
| 1097 |
-
|
| 1098 |
-
|
| 1099 |
-
def namespace_to_table_name(namespace: str) -> str:
|
| 1100 |
-
for k, v in N_T.items():
|
| 1101 |
-
if is_namespace(namespace, k):
|
| 1102 |
-
return v
|
| 1103 |
-
|
| 1104 |
-
|
| 1105 |
-
TABLES = {
|
| 1106 |
-
"LIGHTRAG_DOC_FULL": {
|
| 1107 |
-
"ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
|
| 1108 |
-
id varchar(256),
|
| 1109 |
-
workspace varchar(1024),
|
| 1110 |
-
doc_name varchar(1024),
|
| 1111 |
-
content CLOB,
|
| 1112 |
-
meta JSON,
|
| 1113 |
-
content_summary varchar(1024),
|
| 1114 |
-
content_length NUMBER,
|
| 1115 |
-
status varchar(256),
|
| 1116 |
-
chunks_count NUMBER,
|
| 1117 |
-
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 1118 |
-
updatetime TIMESTAMP DEFAULT NULL,
|
| 1119 |
-
error varchar(4096)
|
| 1120 |
-
)"""
|
| 1121 |
-
},
|
| 1122 |
-
"LIGHTRAG_DOC_CHUNKS": {
|
| 1123 |
-
"ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
|
| 1124 |
-
id varchar(256),
|
| 1125 |
-
workspace varchar(1024),
|
| 1126 |
-
full_doc_id varchar(256),
|
| 1127 |
-
status varchar(256),
|
| 1128 |
-
chunk_order_index NUMBER,
|
| 1129 |
-
tokens NUMBER,
|
| 1130 |
-
content CLOB,
|
| 1131 |
-
content_vector VECTOR,
|
| 1132 |
-
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 1133 |
-
updatetime TIMESTAMP DEFAULT NULL
|
| 1134 |
-
)"""
|
| 1135 |
-
},
|
| 1136 |
-
"LIGHTRAG_GRAPH_NODES": {
|
| 1137 |
-
"ddl": """CREATE TABLE LIGHTRAG_GRAPH_NODES (
|
| 1138 |
-
id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
|
| 1139 |
-
workspace varchar(1024),
|
| 1140 |
-
name varchar(2048),
|
| 1141 |
-
entity_type varchar(1024),
|
| 1142 |
-
description CLOB,
|
| 1143 |
-
source_chunk_id varchar(256),
|
| 1144 |
-
content CLOB,
|
| 1145 |
-
content_vector VECTOR,
|
| 1146 |
-
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 1147 |
-
updatetime TIMESTAMP DEFAULT NULL
|
| 1148 |
-
)"""
|
| 1149 |
-
},
|
| 1150 |
-
"LIGHTRAG_GRAPH_EDGES": {
|
| 1151 |
-
"ddl": """CREATE TABLE LIGHTRAG_GRAPH_EDGES (
|
| 1152 |
-
id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
|
| 1153 |
-
workspace varchar(1024),
|
| 1154 |
-
source_name varchar(2048),
|
| 1155 |
-
target_name varchar(2048),
|
| 1156 |
-
weight NUMBER,
|
| 1157 |
-
keywords CLOB,
|
| 1158 |
-
description CLOB,
|
| 1159 |
-
source_chunk_id varchar(256),
|
| 1160 |
-
content CLOB,
|
| 1161 |
-
content_vector VECTOR,
|
| 1162 |
-
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 1163 |
-
updatetime TIMESTAMP DEFAULT NULL
|
| 1164 |
-
)"""
|
| 1165 |
-
},
|
| 1166 |
-
"LIGHTRAG_LLM_CACHE": {
|
| 1167 |
-
"ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
|
| 1168 |
-
id varchar(256) PRIMARY KEY,
|
| 1169 |
-
workspace varchar(1024),
|
| 1170 |
-
cache_mode varchar(256),
|
| 1171 |
-
model_name varchar(256),
|
| 1172 |
-
original_prompt clob,
|
| 1173 |
-
return_value clob,
|
| 1174 |
-
embedding CLOB,
|
| 1175 |
-
embedding_shape NUMBER,
|
| 1176 |
-
embedding_min NUMBER,
|
| 1177 |
-
embedding_max NUMBER,
|
| 1178 |
-
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 1179 |
-
updatetime TIMESTAMP DEFAULT NULL
|
| 1180 |
-
)"""
|
| 1181 |
-
},
|
| 1182 |
-
"LIGHTRAG_GRAPH": {
|
| 1183 |
-
"ddl": """CREATE OR REPLACE PROPERTY GRAPH lightrag_graph
|
| 1184 |
-
VERTEX TABLES (
|
| 1185 |
-
lightrag_graph_nodes KEY (id)
|
| 1186 |
-
LABEL entity
|
| 1187 |
-
PROPERTIES (id,workspace,name) -- ,entity_type,description,source_chunk_id)
|
| 1188 |
-
)
|
| 1189 |
-
EDGE TABLES (
|
| 1190 |
-
lightrag_graph_edges KEY (id)
|
| 1191 |
-
SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name)
|
| 1192 |
-
DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name)
|
| 1193 |
-
LABEL has_relation
|
| 1194 |
-
PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id)
|
| 1195 |
-
) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""
|
| 1196 |
-
},
|
| 1197 |
-
}
|
| 1198 |
-
|
| 1199 |
-
|
| 1200 |
-
SQL_TEMPLATES = {
|
| 1201 |
-
# SQL for KVStorage
|
| 1202 |
-
"get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
|
| 1203 |
-
"get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
|
| 1204 |
-
"get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
|
| 1205 |
-
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""",
|
| 1206 |
-
"get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
|
| 1207 |
-
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""",
|
| 1208 |
-
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
|
| 1209 |
-
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""",
|
| 1210 |
-
"get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})",
|
| 1211 |
-
"get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
|
| 1212 |
-
"get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})",
|
| 1213 |
-
"get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})",
|
| 1214 |
-
"get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
|
| 1215 |
-
"get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
|
| 1216 |
-
"filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
|
| 1217 |
-
"merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
|
| 1218 |
-
USING DUAL
|
| 1219 |
-
ON (a.id = :id and a.workspace = :workspace)
|
| 1220 |
-
WHEN NOT MATCHED THEN
|
| 1221 |
-
INSERT(id,content,workspace) values(:id,:content,:workspace)""",
|
| 1222 |
-
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS
|
| 1223 |
-
USING DUAL
|
| 1224 |
-
ON (id = :id and workspace = :workspace)
|
| 1225 |
-
WHEN NOT MATCHED THEN INSERT
|
| 1226 |
-
(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status)
|
| 1227 |
-
values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """,
|
| 1228 |
-
"upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a
|
| 1229 |
-
USING DUAL
|
| 1230 |
-
ON (a.id = :id)
|
| 1231 |
-
WHEN NOT MATCHED THEN
|
| 1232 |
-
INSERT (workspace,id,original_prompt,return_value,cache_mode)
|
| 1233 |
-
VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode)
|
| 1234 |
-
WHEN MATCHED THEN UPDATE
|
| 1235 |
-
SET original_prompt = :original_prompt,
|
| 1236 |
-
return_value = :return_value,
|
| 1237 |
-
cache_mode = :cache_mode,
|
| 1238 |
-
updatetime = SYSDATE""",
|
| 1239 |
-
# SQL for VectorStorage
|
| 1240 |
-
"entities": """SELECT name as entity_name FROM
|
| 1241 |
-
(SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
|
| 1242 |
-
FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace)
|
| 1243 |
-
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
|
| 1244 |
-
"relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
|
| 1245 |
-
(SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
|
| 1246 |
-
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace)
|
| 1247 |
-
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
|
| 1248 |
-
"chunks": """SELECT id FROM
|
| 1249 |
-
(SELECT id,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
|
| 1250 |
-
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace)
|
| 1251 |
-
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
|
| 1252 |
-
# SQL for GraphStorage
|
| 1253 |
-
"has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
|
| 1254 |
-
MATCH (a)
|
| 1255 |
-
WHERE a.workspace=:workspace AND a.name=:node_id
|
| 1256 |
-
COLUMNS (a.name))""",
|
| 1257 |
-
"has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
|
| 1258 |
-
MATCH (a) -[e]-> (b)
|
| 1259 |
-
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
| 1260 |
-
AND a.name=:source_node_id AND b.name=:target_node_id
|
| 1261 |
-
COLUMNS (e.source_name,e.target_name) )""",
|
| 1262 |
-
"node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
|
| 1263 |
-
MATCH (a)-[e]->(b)
|
| 1264 |
-
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
| 1265 |
-
AND a.name=:node_id or b.name = :node_id
|
| 1266 |
-
COLUMNS (a.name))""",
|
| 1267 |
-
"get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
|
| 1268 |
-
FROM GRAPH_TABLE (lightrag_graph
|
| 1269 |
-
MATCH (a)
|
| 1270 |
-
WHERE a.workspace=:workspace AND a.name=:node_id
|
| 1271 |
-
COLUMNS (a.name)
|
| 1272 |
-
) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
|
| 1273 |
-
WHERE t2.workspace=:workspace""",
|
| 1274 |
-
"get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
|
| 1275 |
-
NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
|
| 1276 |
-
FROM GRAPH_TABLE (lightrag_graph
|
| 1277 |
-
MATCH (a)-[e]->(b)
|
| 1278 |
-
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
| 1279 |
-
AND a.name=:source_node_id and b.name = :target_node_id
|
| 1280 |
-
COLUMNS (e.id,a.name as source_id)
|
| 1281 |
-
) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
|
| 1282 |
-
"get_node_edges": """SELECT source_name,target_name
|
| 1283 |
-
FROM GRAPH_TABLE (lightrag_graph
|
| 1284 |
-
MATCH (a)-[e]->(b)
|
| 1285 |
-
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
| 1286 |
-
AND a.name=:source_node_id
|
| 1287 |
-
COLUMNS (a.name as source_name,b.name as target_name))""",
|
| 1288 |
-
"merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
|
| 1289 |
-
USING DUAL
|
| 1290 |
-
ON (a.workspace=:workspace and a.name=:name)
|
| 1291 |
-
WHEN NOT MATCHED THEN
|
| 1292 |
-
INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
|
| 1293 |
-
values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector)
|
| 1294 |
-
WHEN MATCHED THEN
|
| 1295 |
-
UPDATE SET
|
| 1296 |
-
entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
|
| 1297 |
-
"merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
|
| 1298 |
-
USING DUAL
|
| 1299 |
-
ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name)
|
| 1300 |
-
WHEN NOT MATCHED THEN
|
| 1301 |
-
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
|
| 1302 |
-
values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector)
|
| 1303 |
-
WHEN MATCHED THEN
|
| 1304 |
-
UPDATE SET
|
| 1305 |
-
weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
|
| 1306 |
-
"get_all_nodes": """WITH t0 AS (
|
| 1307 |
-
SELECT name AS id, entity_type AS label, entity_type, description,
|
| 1308 |
-
'["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
|
| 1309 |
-
FROM lightrag_graph_nodes
|
| 1310 |
-
WHERE workspace = :workspace
|
| 1311 |
-
ORDER BY createtime DESC fetch first :limit rows only
|
| 1312 |
-
), t1 AS (
|
| 1313 |
-
SELECT t0.id, source_chunk_id
|
| 1314 |
-
FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) )
|
| 1315 |
-
), t2 AS (
|
| 1316 |
-
SELECT t1.id, LISTAGG(t2.content, '\n') content
|
| 1317 |
-
FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id
|
| 1318 |
-
GROUP BY t1.id
|
| 1319 |
-
)
|
| 1320 |
-
SELECT t0.id, label, entity_type, description, t2.content
|
| 1321 |
-
FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
|
| 1322 |
-
"get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
|
| 1323 |
-
t1.weight,t1.DESCRIPTION,t2.content
|
| 1324 |
-
FROM LIGHTRAG_GRAPH_EDGES t1
|
| 1325 |
-
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
|
| 1326 |
-
WHERE t1.workspace=:workspace
|
| 1327 |
-
order by t1.CREATETIME DESC
|
| 1328 |
-
fetch first :limit rows only""",
|
| 1329 |
-
"get_statistics": """select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
|
| 1330 |
-
count(distinct CASE WHEN type='edge' THEN id END) as edges_count
|
| 1331 |
-
FROM (
|
| 1332 |
-
select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
|
| 1333 |
-
MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
|
| 1334 |
-
UNION
|
| 1335 |
-
select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
|
| 1336 |
-
MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
|
| 1337 |
-
)""",
|
| 1338 |
-
# SQL for deletion
|
| 1339 |
-
"delete_vectors": "DELETE FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace AND id IN ({ids})",
|
| 1340 |
-
"delete_entity": "DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace AND name=:entity_name",
|
| 1341 |
-
"delete_entity_relations": "DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace AND (source_name=:entity_name OR target_name=:entity_name)",
|
| 1342 |
-
"delete_node": """DELETE FROM GRAPH_TABLE (lightrag_graph
|
| 1343 |
-
MATCH (a)
|
| 1344 |
-
WHERE a.workspace=:workspace AND a.name=:node_id
|
| 1345 |
-
ACTION DELETE a)""",
|
| 1346 |
-
}
|
|
|
|
|
|
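All of the removed Oracle storage classes resolve their SQL the same way: an operation name plus a namespace-derived suffix is looked up in the SQL_TEMPLATES mapping above and executed with named binds. A minimal sketch of that lookup (db stands for the OracleDB helper; the key is shown for the full-docs namespace):

```python
from typing import Any

async def get_full_doc(db, doc_id: str) -> dict[str, Any] | None:
    # Same pattern as the deleted OracleKVStorage.get_by_id: namespace-derived
    # template key, executed with named bind parameters (:workspace, :id).
    SQL = SQL_TEMPLATES["get_by_id_full_docs"]
    params = {"workspace": db.workspace, "id": doc_id}
    return await db.query(SQL, params)
```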
lightrag/kg/postgres_impl.py
CHANGED
|
@@ -9,7 +9,6 @@ import configparser
|
|
| 9 |
|
| 10 |
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
| 11 |
|
| 12 |
-
import sys
|
| 13 |
from tenacity import (
|
| 14 |
retry,
|
| 15 |
retry_if_exception_type,
|
|
@@ -28,11 +27,6 @@ from ..base import (
|
|
| 28 |
from ..namespace import NameSpace, is_namespace
|
| 29 |
from ..utils import logger
|
| 30 |
|
| 31 |
-
if sys.platform.startswith("win"):
|
| 32 |
-
import asyncio.windows_events
|
| 33 |
-
|
| 34 |
-
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
| 35 |
-
|
| 36 |
import pipmaster as pm
|
| 37 |
|
| 38 |
if not pm.is_installed("asyncpg"):
|
|
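This hunk removes the module-import-time Windows event loop override. When a selector loop is still required on Windows, it is usually applied once at the application entry point instead; an illustrative pattern, not code added by this diff:

```python
import asyncio
import sys

async def app() -> None:
    ...  # placeholder for the real entry coroutine

def main() -> None:
    # Apply the selector policy at startup rather than at import time.
    if sys.platform.startswith("win"):
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    asyncio.run(app())

if __name__ == "__main__":
    main()
```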
@@ -41,6 +35,9 @@ if not pm.is_installed("asyncpg"):
|
|
| 41 |
import asyncpg # type: ignore
|
| 42 |
from asyncpg import Pool # type: ignore
|
| 43 |
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
class PostgreSQLDB:
|
| 46 |
def __init__(self, config: dict[str, Any], **kwargs: Any):
|
|
@@ -118,6 +115,25 @@ class PostgreSQLDB:
|
|
| 118 |
)
|
| 119 |
raise e
|
| 120 |
|
|
|
|
|
|
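The lines this hunk adds to PostgreSQLDB are not rendered in this view. For orientation, a pool-backed query helper in asyncpg typically looks like the sketch below (an assumption about the shape, not the code added here):

```python
import asyncpg

async def fetch_rows(pool: asyncpg.Pool, sql: str, *args, multirows: bool = True):
    # Sketch of a pool-backed query helper; returns dicts instead of asyncpg.Record.
    async with pool.acquire() as conn:
        if multirows:
            records = await conn.fetch(sql, *args)
            return [dict(r) for r in records]
        record = await conn.fetchrow(sql, *args)
        return dict(record) if record is not None else None
```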
|
|
|
| 121 |
async def query(
|
| 122 |
self,
|
| 123 |
sql: str,
|
|
@@ -254,8 +270,6 @@ class PGKVStorage(BaseKVStorage):
|
|
| 254 |
db: PostgreSQLDB = field(default=None)
|
| 255 |
|
| 256 |
def __post_init__(self):
|
| 257 |
-
namespace_prefix = self.global_config.get("namespace_prefix")
|
| 258 |
-
self.base_namespace = self.namespace.replace(namespace_prefix, "")
|
| 259 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
| 260 |
|
| 261 |
async def initialize(self):
|
|
@@ -271,7 +285,7 @@ class PGKVStorage(BaseKVStorage):
|
|
| 271 |
|
| 272 |
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
| 273 |
"""Get doc_full data by id."""
|
| 274 |
-
sql = SQL_TEMPLATES["get_by_id_" + self.
|
| 275 |
params = {"workspace": self.db.workspace, "id": id}
|
| 276 |
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
| 277 |
array_res = await self.db.query(sql, params, multirows=True)
|
|
@@ -285,7 +299,7 @@ class PGKVStorage(BaseKVStorage):
|
|
| 285 |
|
| 286 |
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
| 287 |
"""Specifically for llm_response_cache."""
|
| 288 |
-
sql = SQL_TEMPLATES["get_by_mode_id_" + self.
|
| 289 |
params = {"workspace": self.db.workspace, mode: mode, "id": id}
|
| 290 |
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
| 291 |
array_res = await self.db.query(sql, params, multirows=True)
|
|
@@ -299,7 +313,7 @@ class PGKVStorage(BaseKVStorage):
|
|
| 299 |
# Query by id
|
| 300 |
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
| 301 |
"""Get doc_chunks data by id"""
|
| 302 |
-
sql = SQL_TEMPLATES["get_by_ids_" + self.
|
| 303 |
ids=",".join([f"'{id}'" for id in ids])
|
| 304 |
)
|
| 305 |
params = {"workspace": self.db.workspace}
|
|
@@ -320,7 +334,7 @@ class PGKVStorage(BaseKVStorage):
|
|
| 320 |
|
| 321 |
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
| 322 |
"""Specifically for llm_response_cache."""
|
| 323 |
-
SQL = SQL_TEMPLATES["get_by_status_" + self.
|
| 324 |
params = {"workspace": self.db.workspace, "status": status}
|
| 325 |
return await self.db.query(SQL, params, multirows=True)
|
| 326 |
|
|
@@ -380,10 +394,85 @@ class PGKVStorage(BaseKVStorage):
|
|
| 380 |
# PG handles persistence automatically
|
| 381 |
pass
|
| 382 |
|
| 383 |
-
async def
|
|
|
|
|
|
|
|
|
|
| 384 |
"""Drop the storage"""
|
| 385 |
-
|
| 386 |
-
|
|
|
|
|
|
|
|
|
| 387 |
|
| 388 |
|
| 389 |
@final
|
|
@@ -393,8 +482,6 @@ class PGVectorStorage(BaseVectorStorage):
|
|
| 393 |
|
| 394 |
def __post_init__(self):
|
| 395 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
| 396 |
-
namespace_prefix = self.global_config.get("namespace_prefix")
|
| 397 |
-
self.base_namespace = self.namespace.replace(namespace_prefix, "")
|
| 398 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
| 399 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
| 400 |
if cosine_threshold is None:
|
|
@@ -523,7 +610,7 @@ class PGVectorStorage(BaseVectorStorage):
|
|
| 523 |
else:
|
| 524 |
formatted_ids = "NULL"
|
| 525 |
|
| 526 |
-
sql = SQL_TEMPLATES[self.
|
| 527 |
embedding_string=embedding_string, doc_ids=formatted_ids
|
| 528 |
)
|
| 529 |
params = {
|
|
@@ -552,13 +639,12 @@ class PGVectorStorage(BaseVectorStorage):
|
|
| 552 |
logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
|
| 553 |
return
|
| 554 |
|
| 555 |
-
|
| 556 |
-
delete_sql = (
|
| 557 |
-
f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})"
|
| 558 |
-
)
|
| 559 |
|
| 560 |
try:
|
| 561 |
-
await self.db.execute(
|
|
|
|
|
|
|
| 562 |
logger.debug(
|
| 563 |
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
|
| 564 |
)
|
|
@@ -690,6 +776,24 @@ class PGVectorStorage(BaseVectorStorage):
|
|
| 690 |
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
| 691 |
return []
|
| 692 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 693 |
|
| 694 |
@final
|
| 695 |
@dataclass
|
|
@@ -810,6 +914,35 @@ class PGDocStatusStorage(DocStatusStorage):
|
|
| 810 |
# PG handles persistence automatically
|
| 811 |
pass
|
| 812 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 813 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
| 814 |
"""Update or insert document status
|
| 815 |
|
|
@@ -846,10 +979,23 @@ class PGDocStatusStorage(DocStatusStorage):
|
|
| 846 |
},
|
| 847 |
)
|
| 848 |
|
| 849 |
-
async def drop(self) ->
|
| 850 |
"""Drop the storage"""
|
| 851 |
-
|
| 852 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 853 |
|
| 854 |
|
| 855 |
class PGGraphQueryException(Exception):
|
|
@@ -937,31 +1083,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 937 |
if v.startswith("[") and v.endswith("]"):
|
| 938 |
if "::vertex" in v:
|
| 939 |
v = v.replace("::vertex", "")
|
| 940 |
-
|
| 941 |
-
dl = []
|
| 942 |
-
for vertex in vertexes:
|
| 943 |
-
prop = vertex.get("properties")
|
| 944 |
-
if not prop:
|
| 945 |
-
prop = {}
|
| 946 |
-
prop["label"] = PGGraphStorage._decode_graph_label(
|
| 947 |
-
prop["node_id"]
|
| 948 |
-
)
|
| 949 |
-
dl.append(prop)
|
| 950 |
-
d[k] = dl
|
| 951 |
|
| 952 |
elif "::edge" in v:
|
| 953 |
v = v.replace("::edge", "")
|
| 954 |
-
|
| 955 |
-
dl = []
|
| 956 |
-
for edge in edges:
|
| 957 |
-
dl.append(
|
| 958 |
-
(
|
| 959 |
-
vertices[edge["start_id"]],
|
| 960 |
-
edge["label"],
|
| 961 |
-
vertices[edge["end_id"]],
|
| 962 |
-
)
|
| 963 |
-
)
|
| 964 |
-
d[k] = dl
|
| 965 |
else:
|
| 966 |
print("WARNING: unsupported type")
|
| 967 |
continue
|
|
@@ -970,32 +1096,19 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 970 |
dtype = v.split("::")[-1]
|
| 971 |
v = v.split("::")[0]
|
| 972 |
if dtype == "vertex":
|
| 973 |
-
|
| 974 |
-
field = vertex.get("properties")
|
| 975 |
-
if not field:
|
| 976 |
-
field = {}
|
| 977 |
-
field["label"] = PGGraphStorage._decode_graph_label(
|
| 978 |
-
field["node_id"]
|
| 979 |
-
)
|
| 980 |
-
d[k] = field
|
| 981 |
-
# convert edge from id-label->id by replacing id with node information
|
| 982 |
-
# we only do this if the vertex was also returned in the query
|
| 983 |
-
# this is an attempt to be consistent with neo4j implementation
|
| 984 |
elif dtype == "edge":
|
| 985 |
-
|
| 986 |
-
d[k] = (
|
| 987 |
-
vertices.get(edge["start_id"], {}),
|
| 988 |
-
edge[
|
| 989 |
-
"label"
|
| 990 |
-
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
|
| 991 |
-
vertices.get(edge["end_id"], {}),
|
| 992 |
-
)
|
| 993 |
else:
|
| 994 |
-
|
| 995 |
-
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 999 |
|
| 1000 |
return d
|
| 1001 |
|
|
@@ -1025,56 +1138,6 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1025 |
)
|
| 1026 |
return "{" + ", ".join(props) + "}"
|
| 1027 |
|
| 1028 |
-
@staticmethod
|
| 1029 |
-
def _encode_graph_label(label: str) -> str:
|
| 1030 |
-
"""
|
| 1031 |
-
Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
|
| 1032 |
-
|
| 1033 |
-
Args:
|
| 1034 |
-
label (str): the original label
|
| 1035 |
-
|
| 1036 |
-
Returns:
|
| 1037 |
-
str: the encoded label
|
| 1038 |
-
"""
|
| 1039 |
-
return "x" + label.encode().hex()
|
| 1040 |
-
|
| 1041 |
-
@staticmethod
|
| 1042 |
-
def _decode_graph_label(encoded_label: str) -> str:
|
| 1043 |
-
"""
|
| 1044 |
-
Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
|
| 1045 |
-
|
| 1046 |
-
Args:
|
| 1047 |
-
encoded_label (str): the encoded label
|
| 1048 |
-
|
| 1049 |
-
Returns:
|
| 1050 |
-
str: the decoded label
|
| 1051 |
-
"""
|
| 1052 |
-
return bytes.fromhex(encoded_label.removeprefix("x")).decode()
|
| 1053 |
-
|
| 1054 |
-
@staticmethod
|
| 1055 |
-
def _get_col_name(field: str, idx: int) -> str:
|
| 1056 |
-
"""
|
| 1057 |
-
Convert a cypher return field to a pgsql select field
|
| 1058 |
-
If possible keep the cypher column name, but create a generic name if necessary
|
| 1059 |
-
|
| 1060 |
-
Args:
|
| 1061 |
-
field (str): a return field from a cypher query to be formatted for pgsql
|
| 1062 |
-
idx (int): the position of the field in the return statement
|
| 1063 |
-
|
| 1064 |
-
Returns:
|
| 1065 |
-
str: the field to be used in the pgsql select statement
|
| 1066 |
-
"""
|
| 1067 |
-
# remove white space
|
| 1068 |
-
field = field.strip()
|
| 1069 |
-
# if an alias is provided for the field, use it
|
| 1070 |
-
if " as " in field:
|
| 1071 |
-
return field.split(" as ")[-1].strip()
|
| 1072 |
-
# if the return value is an unnamed primitive, give it a generic name
|
| 1073 |
-
if field.isnumeric() or field in ("true", "false", "null"):
|
| 1074 |
-
return f"column_{idx}"
|
| 1075 |
-
# otherwise return the value stripping out some common special chars
|
| 1076 |
-
return field.replace("(", "_").replace(")", "")
|
| 1077 |
-
|
| 1078 |
async def _query(
|
| 1079 |
self,
|
| 1080 |
query: str,
|
|
@@ -1125,10 +1188,10 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1125 |
return result
|
| 1126 |
|
| 1127 |
async def has_node(self, node_id: str) -> bool:
|
| 1128 |
-
entity_name_label =
|
| 1129 |
|
| 1130 |
query = """SELECT * FROM cypher('%s', $$
|
| 1131 |
-
MATCH (n:
|
| 1132 |
RETURN count(n) > 0 AS node_exists
|
| 1133 |
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
|
| 1134 |
|
|
@@ -1137,11 +1200,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1137 |
return single_result["node_exists"]
|
| 1138 |
|
| 1139 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
| 1140 |
-
src_label =
|
| 1141 |
-
tgt_label =
|
| 1142 |
|
| 1143 |
query = """SELECT * FROM cypher('%s', $$
|
| 1144 |
-
MATCH (a:
|
| 1145 |
RETURN COUNT(r) > 0 AS edge_exists
|
| 1146 |
$$) AS (edge_exists bool)""" % (
|
| 1147 |
self.graph_name,
|
|
@@ -1154,30 +1217,31 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1154 |
return single_result["edge_exists"]
|
| 1155 |
|
| 1156 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
| 1157 |
-
label
|
|
|
|
|
|
|
| 1158 |
query = """SELECT * FROM cypher('%s', $$
|
| 1159 |
-
MATCH (n:
|
| 1160 |
RETURN n
|
| 1161 |
$$) AS (n agtype)""" % (self.graph_name, label)
|
| 1162 |
record = await self._query(query)
|
| 1163 |
if record:
|
| 1164 |
node = record[0]
|
| 1165 |
-
node_dict = node["n"]
|
| 1166 |
|
| 1167 |
return node_dict
|
| 1168 |
return None
|
| 1169 |
|
| 1170 |
async def node_degree(self, node_id: str) -> int:
|
| 1171 |
-
label =
|
| 1172 |
|
| 1173 |
query = """SELECT * FROM cypher('%s', $$
|
| 1174 |
-
MATCH (n:
|
| 1175 |
RETURN count(x) AS total_edge_count
|
| 1176 |
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
|
| 1177 |
record = (await self._query(query))[0]
|
| 1178 |
if record:
|
| 1179 |
edge_count = int(record["total_edge_count"])
|
| 1180 |
-
|
| 1181 |
return edge_count
|
| 1182 |
|
| 1183 |
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
|
@@ -1195,11 +1259,13 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1195 |
async def get_edge(
|
| 1196 |
self, source_node_id: str, target_node_id: str
|
| 1197 |
) -> dict[str, str] | None:
|
| 1198 |
-
|
| 1199 |
-
|
|
|
|
|
|
|
| 1200 |
|
| 1201 |
query = """SELECT * FROM cypher('%s', $$
|
| 1202 |
-
MATCH (a:
|
| 1203 |
RETURN properties(r) as edge_properties
|
| 1204 |
LIMIT 1
|
| 1205 |
$$) AS (edge_properties agtype)""" % (
|
|
@@ -1218,11 +1284,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1218 |
Retrieves all edges (relationships) for a particular node identified by its label.
|
| 1219 |
:return: list of dictionaries containing edge information
|
| 1220 |
"""
|
| 1221 |
-
label =
|
| 1222 |
|
| 1223 |
query = """SELECT * FROM cypher('%s', $$
|
| 1224 |
-
MATCH (n:
|
| 1225 |
-
OPTIONAL MATCH (n)-[]-(connected)
|
| 1226 |
RETURN n, connected
|
| 1227 |
$$) AS (n agtype, connected agtype)""" % (
|
| 1228 |
self.graph_name,
|
|
@@ -1235,24 +1301,17 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1235 |
source_node = record["n"] if record["n"] else None
|
| 1236 |
connected_node = record["connected"] if record["connected"] else None
|
| 1237 |
|
| 1238 |
-
|
| 1239 |
-
source_node
|
| 1240 |
-
|
| 1241 |
-
|
| 1242 |
-
|
| 1243 |
-
|
| 1244 |
-
|
| 1245 |
-
|
| 1246 |
-
else None
|
| 1247 |
-
)
|
| 1248 |
|
| 1249 |
-
|
| 1250 |
-
|
| 1251 |
-
(
|
| 1252 |
-
self._decode_graph_label(source_label),
|
| 1253 |
-
self._decode_graph_label(target_label),
|
| 1254 |
-
)
|
| 1255 |
-
)
|
| 1256 |
|
| 1257 |
return edges
|
| 1258 |
|
|
@@ -1262,24 +1321,36 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1262 |
retry=retry_if_exception_type((PGGraphQueryException,)),
|
| 1263 |
)
|
| 1264 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
| 1265 |
-
|
| 1266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1267 |
|
| 1268 |
query = """SELECT * FROM cypher('%s', $$
|
| 1269 |
-
MERGE (n:
|
| 1270 |
SET n += %s
|
| 1271 |
RETURN n
|
| 1272 |
$$) AS (n agtype)""" % (
|
| 1273 |
self.graph_name,
|
| 1274 |
label,
|
| 1275 |
-
|
| 1276 |
)
|
| 1277 |
|
| 1278 |
try:
|
| 1279 |
await self._query(query, readonly=False, upsert=True)
|
| 1280 |
|
| 1281 |
-
except Exception
|
| 1282 |
-
logger.error("POSTGRES,
|
| 1283 |
raise
|
| 1284 |
|
| 1285 |
@retry(
|
|
@@ -1298,14 +1369,14 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1298 |
target_node_id (str): Label of the target node (used as identifier)
|
| 1299 |
edge_data (dict): dictionary of properties to set on the edge
|
| 1300 |
"""
|
| 1301 |
-
src_label =
|
| 1302 |
-
tgt_label =
|
| 1303 |
-
edge_properties = edge_data
|
| 1304 |
|
| 1305 |
query = """SELECT * FROM cypher('%s', $$
|
| 1306 |
-
MATCH (source:
|
| 1307 |
WITH source
|
| 1308 |
-
MATCH (target:
|
| 1309 |
MERGE (source)-[r:DIRECTED]->(target)
|
| 1310 |
SET r += %s
|
| 1311 |
RETURN r
|
|
@@ -1313,14 +1384,16 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1313 |
self.graph_name,
|
| 1314 |
src_label,
|
| 1315 |
tgt_label,
|
| 1316 |
-
|
| 1317 |
)
|
| 1318 |
|
| 1319 |
try:
|
| 1320 |
await self._query(query, readonly=False, upsert=True)
|
| 1321 |
|
| 1322 |
-
except Exception
|
| 1323 |
-
logger.error(
|
|
|
|
|
|
|
| 1324 |
raise
|
| 1325 |
|
| 1326 |
async def _node2vec_embed(self):
|
|
@@ -1333,10 +1406,10 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1333 |
Args:
|
| 1334 |
node_id (str): The ID of the node to delete.
|
| 1335 |
"""
|
| 1336 |
-
label =
|
| 1337 |
|
| 1338 |
query = """SELECT * FROM cypher('%s', $$
|
| 1339 |
-
MATCH (n:
|
| 1340 |
DETACH DELETE n
|
| 1341 |
$$) AS (n agtype)""" % (self.graph_name, label)
|
| 1342 |
|
|
@@ -1353,14 +1426,12 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1353 |
Args:
|
| 1354 |
node_ids (list[str]): A list of node IDs to remove.
|
| 1355 |
"""
|
| 1356 |
-
|
| 1357 |
-
|
| 1358 |
-
]
|
| 1359 |
-
node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
|
| 1360 |
|
| 1361 |
query = """SELECT * FROM cypher('%s', $$
|
| 1362 |
-
MATCH (n:
|
| 1363 |
-
WHERE n.
|
| 1364 |
DETACH DELETE n
|
| 1365 |
$$) AS (n agtype)""" % (self.graph_name, node_id_list)
|
| 1366 |
|
|
@@ -1377,26 +1448,21 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1377 |
Args:
|
| 1378 |
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
|
| 1379 |
"""
|
| 1380 |
-
|
| 1381 |
-
(
|
| 1382 |
-
|
| 1383 |
-
self._encode_graph_label(tgt.strip('"')),
|
| 1384 |
-
)
|
| 1385 |
-
for src, tgt in edges
|
| 1386 |
-
]
|
| 1387 |
-
edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges])
|
| 1388 |
|
| 1389 |
-
|
| 1390 |
-
|
| 1391 |
-
|
| 1392 |
-
|
| 1393 |
-
$$) AS (r agtype)""" % (self.graph_name, edge_list)
|
| 1394 |
|
| 1395 |
-
|
| 1396 |
-
|
| 1397 |
-
|
| 1398 |
-
|
| 1399 |
-
|
|
|
|
| 1400 |
|
| 1401 |
async def get_all_labels(self) -> list[str]:
|
| 1402 |
"""
|
|
@@ -1407,15 +1473,16 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1407 |
"""
|
| 1408 |
query = (
|
| 1409 |
"""SELECT * FROM cypher('%s', $$
|
| 1410 |
-
MATCH (n:
|
| 1411 |
-
|
|
|
|
|
|
|
| 1412 |
$$) AS (label text)"""
|
| 1413 |
% self.graph_name
|
| 1414 |
)
|
| 1415 |
|
| 1416 |
results = await self._query(query)
|
| 1417 |
-
labels = [
|
| 1418 |
-
|
| 1419 |
return labels
|
| 1420 |
|
| 1421 |
async def embed_nodes(
|
|
@@ -1437,105 +1504,135 @@ class PGGraphStorage(BaseGraphStorage):
|
|
| 1437 |
return await embed_func()
|
| 1438 |
|
| 1439 |
async def get_knowledge_graph(
|
| 1440 |
-
self,
|
|
|
|
|
|
|
|
|
|
| 1441 |
) -> KnowledgeGraph:
|
| 1442 |
"""
|
| 1443 |
-
Retrieve a subgraph
|
| 1444 |
|
| 1445 |
Args:
|
| 1446 |
-
node_label
|
| 1447 |
-
max_depth
|
|
|
|
| 1448 |
|
| 1449 |
Returns:
|
| 1450 |
-
KnowledgeGraph
|
|
|
|
| 1451 |
"""
|
| 1452 |
-
|
| 1453 |
-
|
| 1454 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1455 |
if node_label == "*":
|
| 1456 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
| 1457 |
-
|
| 1458 |
-
|
| 1459 |
-
|
| 1460 |
-
|
| 1461 |
-
|
| 1462 |
else:
|
| 1463 |
-
|
| 1464 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
| 1465 |
-
|
| 1466 |
-
|
| 1467 |
-
|
| 1468 |
-
|
| 1469 |
-
|
| 1470 |
|
| 1471 |
results = await self._query(query)
|
| 1472 |
|
| 1473 |
-
|
| 1474 |
-
|
| 1475 |
-
|
| 1476 |
-
|
| 1477 |
-
|
| 1478 |
-
|
| 1479 |
-
|
| 1480 |
-
|
| 1481 |
-
|
| 1482 |
-
|
| 1483 |
-
|
| 1484 |
-
|
| 1485 |
-
edge_key = f"{src_id},{tgt_id}"
|
| 1486 |
-
if edge_key not in unique_edge_ids:
|
| 1487 |
-
unique_edge_ids.add(edge_key)
|
| 1488 |
-
edges.append(
|
| 1489 |
-
(
|
| 1490 |
-
edge_key,
|
| 1491 |
-
src_id,
|
| 1492 |
-
tgt_id,
|
| 1493 |
-
{"source": edge_data[0], "target": edge_data[2]},
|
| 1494 |
)
|
| 1495 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1496 |
|
| 1497 |
-
|
| 1498 |
-
|
| 1499 |
-
|
| 1500 |
-
if
|
| 1501 |
-
|
| 1502 |
-
|
| 1503 |
-
|
| 1504 |
-
|
| 1505 |
-
|
| 1506 |
-
|
| 1507 |
-
|
| 1508 |
-
|
| 1509 |
-
|
| 1510 |
-
for edge in result
|
| 1511 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1512 |
|
| 1513 |
-
# Construct and return the KnowledgeGraph
|
| 1514 |
kg = KnowledgeGraph(
|
| 1515 |
-
nodes=
|
| 1516 |
-
|
| 1517 |
-
|
| 1518 |
-
],
|
| 1519 |
-
edges=[
|
| 1520 |
-
KnowledgeGraphEdge(
|
| 1521 |
-
id=edge_id,
|
| 1522 |
-
type="DIRECTED",
|
| 1523 |
-
source=src,
|
| 1524 |
-
target=tgt,
|
| 1525 |
-
properties=props,
|
| 1526 |
-
)
|
| 1527 |
-
for edge_id, src, tgt, props in edges
|
| 1528 |
-
],
|
| 1529 |
)
|
| 1530 |
|
|
|
|
|
|
|
|
|
|
| 1531 |
return kg
|
| 1532 |
|
| 1533 |
-
async def drop(self) ->
|
| 1534 |
"""Drop the storage"""
|
| 1535 |
-
|
| 1536 |
-
|
| 1537 |
-
|
| 1538 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1539 |
|
| 1540 |
|
| 1541 |
NAMESPACE_TABLE_MAP = {
|
|
@@ -1648,7 +1745,7 @@ SQL_TEMPLATES = {
|
|
| 1648 |
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2
|
| 1649 |
""",
|
| 1650 |
"get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
|
| 1651 |
-
chunk_order_index, full_doc_id
|
| 1652 |
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
|
| 1653 |
""",
|
| 1654 |
"get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
|
|
@@ -1661,7 +1758,7 @@ SQL_TEMPLATES = {
|
|
| 1661 |
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
|
| 1662 |
""",
|
| 1663 |
"get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
|
| 1664 |
-
chunk_order_index, full_doc_id
|
| 1665 |
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
|
| 1666 |
""",
|
| 1667 |
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
|
|
@@ -1693,6 +1790,7 @@ SQL_TEMPLATES = {
|
|
| 1693 |
file_path=EXCLUDED.file_path,
|
| 1694 |
update_time = CURRENT_TIMESTAMP
|
| 1695 |
""",
|
|
|
|
| 1696 |
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
|
| 1697 |
content_vector, chunk_ids, file_path)
|
| 1698 |
VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7)
|
|
@@ -1716,45 +1814,6 @@ SQL_TEMPLATES = {
|
|
| 1716 |
file_path=EXCLUDED.file_path,
|
| 1717 |
update_time = CURRENT_TIMESTAMP
|
| 1718 |
""",
|
| 1719 |
-
# SQL for VectorStorage
|
| 1720 |
-
# "entities": """SELECT entity_name FROM
|
| 1721 |
-
# (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
| 1722 |
-
# FROM LIGHTRAG_VDB_ENTITY where workspace=$1)
|
| 1723 |
-
# WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
| 1724 |
-
# """,
|
| 1725 |
-
# "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
|
| 1726 |
-
# (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
| 1727 |
-
# FROM LIGHTRAG_VDB_RELATION where workspace=$1)
|
| 1728 |
-
# WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
| 1729 |
-
# """,
|
| 1730 |
-
# "chunks": """SELECT id FROM
|
| 1731 |
-
# (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
| 1732 |
-
# FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
|
| 1733 |
-
# WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
| 1734 |
-
# """,
|
| 1735 |
-
# DROP tables
|
| 1736 |
-
"drop_all": """
|
| 1737 |
-
DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
|
| 1738 |
-
DROP TABLE IF EXISTS LIGHTRAG_DOC_CHUNKS CASCADE;
|
| 1739 |
-
DROP TABLE IF EXISTS LIGHTRAG_LLM_CACHE CASCADE;
|
| 1740 |
-
DROP TABLE IF EXISTS LIGHTRAG_VDB_ENTITY CASCADE;
|
| 1741 |
-
DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
|
| 1742 |
-
""",
|
| 1743 |
-
"drop_doc_full": """
|
| 1744 |
-
DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
|
| 1745 |
-
""",
|
| 1746 |
-
"drop_doc_chunks": """
|
| 1747 |
-
DROP TABLE IF EXISTS LIGHTRAG_DOC_CHUNKS CASCADE;
|
| 1748 |
-
""",
|
| 1749 |
-
"drop_llm_cache": """
|
| 1750 |
-
DROP TABLE IF EXISTS LIGHTRAG_LLM_CACHE CASCADE;
|
| 1751 |
-
""",
|
| 1752 |
-
"drop_vdb_entity": """
|
| 1753 |
-
DROP TABLE IF EXISTS LIGHTRAG_VDB_ENTITY CASCADE;
|
| 1754 |
-
""",
|
| 1755 |
-
"drop_vdb_relation": """
|
| 1756 |
-
DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
|
| 1757 |
-
""",
|
| 1758 |
"relationships": """
|
| 1759 |
WITH relevant_chunks AS (
|
| 1760 |
SELECT id as chunk_id
|
|
@@ -1795,9 +1854,9 @@ SQL_TEMPLATES = {
|
|
| 1795 |
FROM LIGHTRAG_DOC_CHUNKS
|
| 1796 |
WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
|
| 1797 |
)
|
| 1798 |
-
SELECT id FROM
|
| 1799 |
(
|
| 1800 |
-
SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
| 1801 |
FROM LIGHTRAG_DOC_CHUNKS
|
| 1802 |
where workspace=$1
|
| 1803 |
AND id IN (SELECT chunk_id FROM relevant_chunks)
|
|
@@ -1806,4 +1865,8 @@ SQL_TEMPLATES = {
|
|
| 1806 |
ORDER BY distance DESC
|
| 1807 |
LIMIT $3
|
| 1808 |
""",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1809 |
}
|
|
|
|
| 9 |
|
| 10 |
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
| 11 |
|
|
|
|
| 12 |
from tenacity import (
|
| 13 |
retry,
|
| 14 |
retry_if_exception_type,
|
|
|
|
| 27 |
from ..namespace import NameSpace, is_namespace
|
| 28 |
from ..utils import logger
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
import pipmaster as pm
|
| 31 |
|
| 32 |
if not pm.is_installed("asyncpg"):
|
|
|
|
| 35 |
import asyncpg # type: ignore
|
| 36 |
from asyncpg import Pool # type: ignore
|
| 37 |
|
| 38 |
+
# Get maximum number of graph nodes from environment variable, default is 1000
|
| 39 |
+
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
|
| 40 |
+
|
| 41 |
|
| 42 |
class PostgreSQLDB:
|
| 43 |
def __init__(self, config: dict[str, Any], **kwargs: Any):
|
|
|
|
| 115 |
)
|
| 116 |
raise e
|
| 117 |
|
| 118 |
+
# Create index for id column in each table
|
| 119 |
+
try:
|
| 120 |
+
index_name = f"idx_{k.lower()}_id"
|
| 121 |
+
check_index_sql = f"""
|
| 122 |
+
SELECT 1 FROM pg_indexes
|
| 123 |
+
WHERE indexname = '{index_name}'
|
| 124 |
+
AND tablename = '{k.lower()}'
|
| 125 |
+
"""
|
| 126 |
+
index_exists = await self.query(check_index_sql)
|
| 127 |
+
|
| 128 |
+
if not index_exists:
|
| 129 |
+
create_index_sql = f"CREATE INDEX {index_name} ON {k}(id)"
|
| 130 |
+
logger.info(f"PostgreSQL, Creating index {index_name} on table {k}")
|
| 131 |
+
await self.execute(create_index_sql)
|
| 132 |
+
except Exception as e:
|
| 133 |
+
logger.error(
|
| 134 |
+
f"PostgreSQL, Failed to create index on table {k}, Got: {e}"
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
async def query(
|
| 138 |
self,
|
| 139 |
sql: str,
|
|
|
|
| 270 |
db: PostgreSQLDB = field(default=None)
|
| 271 |
|
| 272 |
def __post_init__(self):
|
|
|
|
|
|
|
| 273 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
| 274 |
|
| 275 |
async def initialize(self):
|
|
|
|
| 285 |
|
| 286 |
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
| 287 |
"""Get doc_full data by id."""
|
| 288 |
+
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
| 289 |
params = {"workspace": self.db.workspace, "id": id}
|
| 290 |
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
| 291 |
array_res = await self.db.query(sql, params, multirows=True)
|
|
|
|
| 299 |
|
| 300 |
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
| 301 |
"""Specifically for llm_response_cache."""
|
| 302 |
+
sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
|
| 303 |
params = {"workspace": self.db.workspace, mode: mode, "id": id}
|
| 304 |
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
| 305 |
array_res = await self.db.query(sql, params, multirows=True)
|
|
|
|
| 313 |
# Query by id
|
| 314 |
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
| 315 |
"""Get doc_chunks data by id"""
|
| 316 |
+
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
| 317 |
ids=",".join([f"'{id}'" for id in ids])
|
| 318 |
)
|
| 319 |
params = {"workspace": self.db.workspace}
|
|
|
|
| 334 |
|
| 335 |
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
| 336 |
"""Specifically for llm_response_cache."""
|
| 337 |
+
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
|
| 338 |
params = {"workspace": self.db.workspace, "status": status}
|
| 339 |
return await self.db.query(SQL, params, multirows=True)
|
| 340 |
|
|
|
|
| 394 |
# PG handles persistence automatically
|
| 395 |
pass
|
| 396 |
|
| 397 |
+
async def delete(self, ids: list[str]) -> None:
|
| 398 |
+
"""Delete specific records from storage by their IDs
|
| 399 |
+
|
| 400 |
+
Args:
|
| 401 |
+
ids (list[str]): List of document IDs to be deleted from storage
|
| 402 |
+
|
| 403 |
+
Returns:
|
| 404 |
+
None
|
| 405 |
+
"""
|
| 406 |
+
if not ids:
|
| 407 |
+
return
|
| 408 |
+
|
| 409 |
+
table_name = namespace_to_table_name(self.namespace)
|
| 410 |
+
if not table_name:
|
| 411 |
+
logger.error(f"Unknown namespace for deletion: {self.namespace}")
|
| 412 |
+
return
|
| 413 |
+
|
| 414 |
+
delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
|
| 415 |
+
|
| 416 |
+
try:
|
| 417 |
+
await self.db.execute(
|
| 418 |
+
delete_sql, {"workspace": self.db.workspace, "ids": ids}
|
| 419 |
+
)
|
| 420 |
+
logger.debug(
|
| 421 |
+
f"Successfully deleted {len(ids)} records from {self.namespace}"
|
| 422 |
+
)
|
| 423 |
+
except Exception as e:
|
| 424 |
+
logger.error(f"Error while deleting records from {self.namespace}: {e}")
|
| 425 |
+
|
| 426 |
+
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
|
| 427 |
+
"""Delete specific records from storage by cache mode
|
| 428 |
+
|
| 429 |
+
Args:
|
| 430 |
+
modes (list[str]): List of cache modes to be dropped from storage
|
| 431 |
+
|
| 432 |
+
Returns:
|
| 433 |
+
bool: True if successful, False otherwise
|
| 434 |
+
"""
|
| 435 |
+
if not modes:
|
| 436 |
+
return False
|
| 437 |
+
|
| 438 |
+
try:
|
| 439 |
+
table_name = namespace_to_table_name(self.namespace)
|
| 440 |
+
if not table_name:
|
| 441 |
+
return False
|
| 442 |
+
|
| 443 |
+
if table_name != "LIGHTRAG_LLM_CACHE":
|
| 444 |
+
return False
|
| 445 |
+
|
| 446 |
+
sql = f"""
|
| 447 |
+
DELETE FROM {table_name}
|
| 448 |
+
WHERE workspace = $1 AND mode = ANY($2)
|
| 449 |
+
"""
|
| 450 |
+
params = {"workspace": self.db.workspace, "modes": modes}
|
| 451 |
+
|
| 452 |
+
logger.info(f"Deleting cache by modes: {modes}")
|
| 453 |
+
await self.db.execute(sql, params)
|
| 454 |
+
return True
|
| 455 |
+
except Exception as e:
|
| 456 |
+
logger.error(f"Error deleting cache by modes {modes}: {e}")
|
| 457 |
+
return False
|
| 458 |
+
|
| 459 |
+
async def drop(self) -> dict[str, str]:
|
| 460 |
"""Drop the storage"""
|
| 461 |
+
try:
|
| 462 |
+
table_name = namespace_to_table_name(self.namespace)
|
| 463 |
+
if not table_name:
|
| 464 |
+
return {
|
| 465 |
+
"status": "error",
|
| 466 |
+
"message": f"Unknown namespace: {self.namespace}",
|
| 467 |
+
}
|
| 468 |
+
|
| 469 |
+
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
| 470 |
+
table_name=table_name
|
| 471 |
+
)
|
| 472 |
+
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
|
| 473 |
+
return {"status": "success", "message": "data dropped"}
|
| 474 |
+
except Exception as e:
|
| 475 |
+
return {"status": "error", "message": str(e)}
|
| 476 |
|
| 477 |
|
| 478 |
@final
|
|
|
|
| 482 |
|
| 483 |
def __post_init__(self):
|
| 484 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
|
|
|
|
|
|
| 485 |
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
| 486 |
cosine_threshold = config.get("cosine_better_than_threshold")
|
| 487 |
if cosine_threshold is None:
|
|
|
|
| 610 |
else:
|
| 611 |
formatted_ids = "NULL"
|
| 612 |
|
| 613 |
+
sql = SQL_TEMPLATES[self.namespace].format(
|
| 614 |
embedding_string=embedding_string, doc_ids=formatted_ids
|
| 615 |
)
|
| 616 |
params = {
|
|
|
|
| 639 |
logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
|
| 640 |
return
|
| 641 |
|
| 642 |
+
delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
|
|
|
|
|
|
|
|
|
|
| 643 |
|
| 644 |
try:
|
| 645 |
+
await self.db.execute(
|
| 646 |
+
delete_sql, {"workspace": self.db.workspace, "ids": ids}
|
| 647 |
+
)
|
| 648 |
logger.debug(
|
| 649 |
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
|
| 650 |
)
|
|
|
|
| 776 |
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
|
| 777 |
return []
|
| 778 |
|
| 779 |
+
async def drop(self) -> dict[str, str]:
|
| 780 |
+
"""Drop the storage"""
|
| 781 |
+
try:
|
| 782 |
+
table_name = namespace_to_table_name(self.namespace)
|
| 783 |
+
if not table_name:
|
| 784 |
+
return {
|
| 785 |
+
"status": "error",
|
| 786 |
+
"message": f"Unknown namespace: {self.namespace}",
|
| 787 |
+
}
|
| 788 |
+
|
| 789 |
+
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
| 790 |
+
table_name=table_name
|
| 791 |
+
)
|
| 792 |
+
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
|
| 793 |
+
return {"status": "success", "message": "data dropped"}
|
| 794 |
+
except Exception as e:
|
| 795 |
+
return {"status": "error", "message": str(e)}
|
| 796 |
+
|
| 797 |
|
| 798 |
@final
|
| 799 |
@dataclass
|
|
|
|
| 914 |
# PG handles persistence automatically
|
| 915 |
pass
|
| 916 |
|
| 917 |
+
async def delete(self, ids: list[str]) -> None:
|
| 918 |
+
"""Delete specific records from storage by their IDs
|
| 919 |
+
|
| 920 |
+
Args:
|
| 921 |
+
ids (list[str]): List of document IDs to be deleted from storage
|
| 922 |
+
|
| 923 |
+
Returns:
|
| 924 |
+
None
|
| 925 |
+
"""
|
| 926 |
+
if not ids:
|
| 927 |
+
return
|
| 928 |
+
|
| 929 |
+
table_name = namespace_to_table_name(self.namespace)
|
| 930 |
+
if not table_name:
|
| 931 |
+
logger.error(f"Unknown namespace for deletion: {self.namespace}")
|
| 932 |
+
return
|
| 933 |
+
|
| 934 |
+
delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
|
| 935 |
+
|
| 936 |
+
try:
|
| 937 |
+
await self.db.execute(
|
| 938 |
+
delete_sql, {"workspace": self.db.workspace, "ids": ids}
|
| 939 |
+
)
|
| 940 |
+
logger.debug(
|
| 941 |
+
f"Successfully deleted {len(ids)} records from {self.namespace}"
|
| 942 |
+
)
|
| 943 |
+
except Exception as e:
|
| 944 |
+
logger.error(f"Error while deleting records from {self.namespace}: {e}")
|
| 945 |
+
|
| 946 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
| 947 |
"""Update or insert document status
|
| 948 |
|
|
|
|
| 979 |
},
|
| 980 |
)
|
| 981 |
|
| 982 |
+
async def drop(self) -> dict[str, str]:
|
| 983 |
"""Drop the storage"""
|
| 984 |
+
try:
|
| 985 |
+
table_name = namespace_to_table_name(self.namespace)
|
| 986 |
+
if not table_name:
|
| 987 |
+
return {
|
| 988 |
+
"status": "error",
|
| 989 |
+
"message": f"Unknown namespace: {self.namespace}",
|
| 990 |
+
}
|
| 991 |
+
|
| 992 |
+
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
| 993 |
+
table_name=table_name
|
| 994 |
+
)
|
| 995 |
+
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
|
| 996 |
+
return {"status": "success", "message": "data dropped"}
|
| 997 |
+
except Exception as e:
|
| 998 |
+
return {"status": "error", "message": str(e)}
|
| 999 |
|
| 1000 |
|
| 1001 |
class PGGraphQueryException(Exception):
|
|
|
|
| 1083 |
if v.startswith("[") and v.endswith("]"):
|
| 1084 |
if "::vertex" in v:
|
| 1085 |
v = v.replace("::vertex", "")
|
| 1086 |
+
d[k] = json.loads(v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1087 |
|
| 1088 |
elif "::edge" in v:
|
| 1089 |
v = v.replace("::edge", "")
|
| 1090 |
+
d[k] = json.loads(v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1091 |
else:
|
| 1092 |
print("WARNING: unsupported type")
|
| 1093 |
continue
|
|
|
|
| 1096 |
dtype = v.split("::")[-1]
|
| 1097 |
v = v.split("::")[0]
|
| 1098 |
if dtype == "vertex":
|
| 1099 |
+
d[k] = json.loads(v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1100 |
elif dtype == "edge":
|
| 1101 |
+
d[k] = json.loads(v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1102 |
else:
|
| 1103 |
+
try:
|
| 1104 |
+
d[k] = (
|
| 1105 |
+
json.loads(v)
|
| 1106 |
+
if isinstance(v, str)
|
| 1107 |
+
and (v.startswith("{") or v.startswith("["))
|
| 1108 |
+
else v
|
| 1109 |
+
)
|
| 1110 |
+
except json.JSONDecodeError:
|
| 1111 |
+
d[k] = v
|
| 1112 |
|
| 1113 |
return d
|
| 1114 |
|
|
|
|
| 1138 |
)
|
| 1139 |
return "{" + ", ".join(props) + "}"
|
| 1140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1141 |
async def _query(
|
| 1142 |
self,
|
| 1143 |
query: str,
|
|
|
|
| 1188 |
return result
|
| 1189 |
|
| 1190 |
async def has_node(self, node_id: str) -> bool:
|
| 1191 |
+
entity_name_label = node_id.strip('"')
|
| 1192 |
|
| 1193 |
query = """SELECT * FROM cypher('%s', $$
|
| 1194 |
+
MATCH (n:base {entity_id: "%s"})
|
| 1195 |
RETURN count(n) > 0 AS node_exists
|
| 1196 |
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
|
| 1197 |
|
|
|
|
| 1200 |
return single_result["node_exists"]
|
| 1201 |
|
| 1202 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
| 1203 |
+
src_label = source_node_id.strip('"')
|
| 1204 |
+
tgt_label = target_node_id.strip('"')
|
| 1205 |
|
| 1206 |
query = """SELECT * FROM cypher('%s', $$
|
| 1207 |
+
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
|
| 1208 |
RETURN COUNT(r) > 0 AS edge_exists
|
| 1209 |
$$) AS (edge_exists bool)""" % (
|
| 1210 |
self.graph_name,
|
|
|
|
| 1217 |
return single_result["edge_exists"]
|
| 1218 |
|
| 1219 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
| 1220 |
+
"""Get node by its label identifier, return only node properties"""
|
| 1221 |
+
|
| 1222 |
+
label = node_id.strip('"')
|
| 1223 |
query = """SELECT * FROM cypher('%s', $$
|
| 1224 |
+
MATCH (n:base {entity_id: "%s"})
|
| 1225 |
RETURN n
|
| 1226 |
$$) AS (n agtype)""" % (self.graph_name, label)
|
| 1227 |
record = await self._query(query)
|
| 1228 |
if record:
|
| 1229 |
node = record[0]
|
| 1230 |
+
node_dict = node["n"]["properties"]
|
| 1231 |
|
| 1232 |
return node_dict
|
| 1233 |
return None
|
| 1234 |
|
| 1235 |
async def node_degree(self, node_id: str) -> int:
|
| 1236 |
+
label = node_id.strip('"')
|
| 1237 |
|
| 1238 |
query = """SELECT * FROM cypher('%s', $$
|
| 1239 |
+
MATCH (n:base {entity_id: "%s"})-[]-(x)
|
| 1240 |
RETURN count(x) AS total_edge_count
|
| 1241 |
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
|
| 1242 |
record = (await self._query(query))[0]
|
| 1243 |
if record:
|
| 1244 |
edge_count = int(record["total_edge_count"])
|
|
|
|
| 1245 |
return edge_count
|
| 1246 |
|
| 1247 |
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
|
|
|
| 1259 |
async def get_edge(
|
| 1260 |
self, source_node_id: str, target_node_id: str
|
| 1261 |
) -> dict[str, str] | None:
|
| 1262 |
+
"""Get edge properties between two nodes"""
|
| 1263 |
+
|
| 1264 |
+
src_label = source_node_id.strip('"')
|
| 1265 |
+
tgt_label = target_node_id.strip('"')
|
| 1266 |
|
| 1267 |
query = """SELECT * FROM cypher('%s', $$
|
| 1268 |
+
MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"})
|
| 1269 |
RETURN properties(r) as edge_properties
|
| 1270 |
LIMIT 1
|
| 1271 |
$$) AS (edge_properties agtype)""" % (
|
|
|
|
| 1284 |
Retrieves all edges (relationships) for a particular node identified by its label.
|
| 1285 |
:return: list of dictionaries containing edge information
|
| 1286 |
"""
|
| 1287 |
+
label = source_node_id.strip('"')
|
| 1288 |
|
| 1289 |
query = """SELECT * FROM cypher('%s', $$
|
| 1290 |
+
MATCH (n:base {entity_id: "%s"})
|
| 1291 |
+
OPTIONAL MATCH (n)-[]-(connected:base)
|
| 1292 |
RETURN n, connected
|
| 1293 |
$$) AS (n agtype, connected agtype)""" % (
|
| 1294 |
self.graph_name,
|
|
|
|
| 1301 |
source_node = record["n"] if record["n"] else None
|
| 1302 |
connected_node = record["connected"] if record["connected"] else None
|
| 1303 |
|
| 1304 |
+
if (
|
| 1305 |
+
source_node
|
| 1306 |
+
and connected_node
|
| 1307 |
+
and "properties" in source_node
|
| 1308 |
+
and "properties" in connected_node
|
| 1309 |
+
):
|
| 1310 |
+
source_label = source_node["properties"].get("entity_id")
|
| 1311 |
+
target_label = connected_node["properties"].get("entity_id")
|
|
|
|
|
|
|
| 1312 |
|
| 1313 |
+
if source_label and target_label:
|
| 1314 |
+
edges.append((source_label, target_label))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1315 |
|
| 1316 |
return edges
|
| 1317 |
|
|
|
|
| 1321 |
retry=retry_if_exception_type((PGGraphQueryException,)),
|
| 1322 |
)
|
| 1323 |
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
| 1324 |
+
"""
|
| 1325 |
+
Upsert a node in the Neo4j database.
|
| 1326 |
+
|
| 1327 |
+
Args:
|
| 1328 |
+
node_id: The unique identifier for the node (used as label)
|
| 1329 |
+
node_data: Dictionary of node properties
|
| 1330 |
+
"""
|
| 1331 |
+
if "entity_id" not in node_data:
|
| 1332 |
+
raise ValueError(
|
| 1333 |
+
"PostgreSQL: node properties must contain an 'entity_id' field"
|
| 1334 |
+
)
|
| 1335 |
+
|
| 1336 |
+
label = node_id.strip('"')
|
| 1337 |
+
properties = self._format_properties(node_data)
|
| 1338 |
|
| 1339 |
query = """SELECT * FROM cypher('%s', $$
|
| 1340 |
+
MERGE (n:base {entity_id: "%s"})
|
| 1341 |
SET n += %s
|
| 1342 |
RETURN n
|
| 1343 |
$$) AS (n agtype)""" % (
|
| 1344 |
self.graph_name,
|
| 1345 |
label,
|
| 1346 |
+
properties,
|
| 1347 |
)
|
| 1348 |
|
| 1349 |
try:
|
| 1350 |
await self._query(query, readonly=False, upsert=True)
|
| 1351 |
|
| 1352 |
+
except Exception:
|
| 1353 |
+
logger.error(f"POSTGRES, upsert_node error on node_id: `{node_id}`")
|
| 1354 |
raise
|
| 1355 |
|
| 1356 |
@retry(
|
|
|
|
| 1369 |
target_node_id (str): Label of the target node (used as identifier)
|
| 1370 |
edge_data (dict): dictionary of properties to set on the edge
|
| 1371 |
"""
|
| 1372 |
+
src_label = source_node_id.strip('"')
|
| 1373 |
+
tgt_label = target_node_id.strip('"')
|
| 1374 |
+
edge_properties = self._format_properties(edge_data)
|
| 1375 |
|
| 1376 |
query = """SELECT * FROM cypher('%s', $$
|
| 1377 |
+
MATCH (source:base {entity_id: "%s"})
|
| 1378 |
WITH source
|
| 1379 |
+
MATCH (target:base {entity_id: "%s"})
|
| 1380 |
MERGE (source)-[r:DIRECTED]->(target)
|
| 1381 |
SET r += %s
|
| 1382 |
RETURN r
|
|
|
|
| 1384 |
self.graph_name,
|
| 1385 |
src_label,
|
| 1386 |
tgt_label,
|
| 1387 |
+
edge_properties,
|
| 1388 |
)
|
| 1389 |
|
| 1390 |
try:
|
| 1391 |
await self._query(query, readonly=False, upsert=True)
|
| 1392 |
|
| 1393 |
+
except Exception:
|
| 1394 |
+
logger.error(
|
| 1395 |
+
f"POSTGRES, upsert_edge error on edge: `{source_node_id}`-`{target_node_id}`"
|
| 1396 |
+
)
|
| 1397 |
raise
|
| 1398 |
|
| 1399 |
async def _node2vec_embed(self):
|
|
|
|
| 1406 |
Args:
|
| 1407 |
node_id (str): The ID of the node to delete.
|
| 1408 |
"""
|
| 1409 |
+
label = node_id.strip('"')
|
| 1410 |
|
| 1411 |
query = """SELECT * FROM cypher('%s', $$
|
| 1412 |
+
MATCH (n:base {entity_id: "%s"})
|
| 1413 |
DETACH DELETE n
|
| 1414 |
$$) AS (n agtype)""" % (self.graph_name, label)
|
| 1415 |
|
|
|
|
| 1426 |
Args:
|
| 1427 |
node_ids (list[str]): A list of node IDs to remove.
|
| 1428 |
"""
|
| 1429 |
+
node_ids = [node_id.strip('"') for node_id in node_ids]
|
| 1430 |
+
node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])
|
|
|
|
|
|
|
| 1431 |
|
| 1432 |
query = """SELECT * FROM cypher('%s', $$
|
| 1433 |
+
MATCH (n:base)
|
| 1434 |
+
WHERE n.entity_id IN [%s]
|
| 1435 |
DETACH DELETE n
|
| 1436 |
$$) AS (n agtype)""" % (self.graph_name, node_id_list)
|
| 1437 |
|
|
|
|
| 1448 |
Args:
|
| 1449 |
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
|
| 1450 |
"""
|
| 1451 |
+
for source, target in edges:
|
| 1452 |
+
src_label = source.strip('"')
|
| 1453 |
+
tgt_label = target.strip('"')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1454 |
|
| 1455 |
+
query = """SELECT * FROM cypher('%s', $$
|
| 1456 |
+
MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"})
|
| 1457 |
+
DELETE r
|
| 1458 |
+
$$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
|
|
|
|
| 1459 |
|
| 1460 |
+
try:
|
| 1461 |
+
await self._query(query, readonly=False)
|
| 1462 |
+
logger.debug(f"Deleted edge from '{source}' to '{target}'")
|
| 1463 |
+
except Exception as e:
|
| 1464 |
+
logger.error(f"Error during edge deletion: {str(e)}")
|
| 1465 |
+
raise
|
| 1466 |
|
| 1467 |
async def get_all_labels(self) -> list[str]:
|
| 1468 |
"""
|
|
|
|
| 1473 |
"""
|
| 1474 |
query = (
|
| 1475 |
"""SELECT * FROM cypher('%s', $$
|
| 1476 |
+
MATCH (n:base)
|
| 1477 |
+
WHERE n.entity_id IS NOT NULL
|
| 1478 |
+
RETURN DISTINCT n.entity_id AS label
|
| 1479 |
+
ORDER BY n.entity_id
|
| 1480 |
$$) AS (label text)"""
|
| 1481 |
% self.graph_name
|
| 1482 |
)
|
| 1483 |
|
| 1484 |
results = await self._query(query)
|
| 1485 |
+
labels = [result["label"] for result in results]
|
|
|
|
| 1486 |
return labels
|
| 1487 |
|
| 1488 |
async def embed_nodes(
|
|
|
|
| 1504 |
return await embed_func()
|
| 1505 |
|
| 1506 |
async def get_knowledge_graph(
|
| 1507 |
+
self,
|
| 1508 |
+
node_label: str,
|
| 1509 |
+
max_depth: int = 3,
|
| 1510 |
+
max_nodes: int = MAX_GRAPH_NODES,
|
| 1511 |
) -> KnowledgeGraph:
|
| 1512 |
"""
|
| 1513 |
+
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
| 1514 |
|
| 1515 |
Args:
|
| 1516 |
+
node_label: Label of the starting node, * means all nodes
|
| 1517 |
+
max_depth: Maximum depth of the subgraph, Defaults to 3
|
| 1518 |
+
max_nodes: Maxiumu nodes to return, Defaults to 1000 (not BFS nor DFS garanteed)
|
| 1519 |
|
| 1520 |
Returns:
|
| 1521 |
+
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
| 1522 |
+
indicating whether the graph was truncated due to max_nodes limit
|
| 1523 |
"""
|
| 1524 |
+
# First, count the total number of nodes that would be returned without limit
|
| 1525 |
+
if node_label == "*":
|
| 1526 |
+
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
| 1527 |
+
MATCH (n:base)
|
| 1528 |
+
RETURN count(distinct n) AS total_nodes
|
| 1529 |
+
$$) AS (total_nodes bigint)"""
|
| 1530 |
+
else:
|
| 1531 |
+
strip_label = node_label.strip('"')
|
| 1532 |
+
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
| 1533 |
+
MATCH (n:base {{entity_id: "{strip_label}"}})
|
| 1534 |
+
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
|
| 1535 |
+
RETURN count(distinct m) AS total_nodes
|
| 1536 |
+
$$) AS (total_nodes bigint)"""
|
| 1537 |
+
|
| 1538 |
+
count_result = await self._query(count_query)
|
| 1539 |
+
total_nodes = count_result[0]["total_nodes"] if count_result else 0
|
| 1540 |
+
is_truncated = total_nodes > max_nodes
|
| 1541 |
+
|
| 1542 |
+
# Now get the actual data with limit
|
| 1543 |
if node_label == "*":
|
| 1544 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
| 1545 |
+
MATCH (n:base)
|
| 1546 |
+
OPTIONAL MATCH (n)-[r]->(target:base)
|
| 1547 |
+
RETURN collect(distinct n) AS n, collect(distinct r) AS r
|
| 1548 |
+
LIMIT {max_nodes}
|
| 1549 |
+
$$) AS (n agtype, r agtype)"""
|
| 1550 |
else:
|
| 1551 |
+
strip_label = node_label.strip('"')
|
| 1552 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
| 1553 |
+
MATCH (n:base {{entity_id: "{strip_label}"}})
|
| 1554 |
+
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
|
| 1555 |
+
RETURN nodes(p) AS n, relationships(p) AS r
|
| 1556 |
+
LIMIT {max_nodes}
|
| 1557 |
+
$$) AS (n agtype, r agtype)"""
|
| 1558 |
|
| 1559 |
results = await self._query(query)
|
| 1560 |
|
| 1561 |
+
# Process the query results with deduplication by node and edge IDs
|
| 1562 |
+
nodes_dict = {}
|
| 1563 |
+
edges_dict = {}
|
| 1564 |
+
for result in results:
|
| 1565 |
+
# Handle single node cases
|
| 1566 |
+
if result.get("n") and isinstance(result["n"], dict):
|
| 1567 |
+
node_id = str(result["n"]["id"])
|
| 1568 |
+
if node_id not in nodes_dict:
|
| 1569 |
+
nodes_dict[node_id] = KnowledgeGraphNode(
|
| 1570 |
+
id=node_id,
|
| 1571 |
+
labels=[result["n"]["properties"]["entity_id"]],
|
| 1572 |
+
properties=result["n"]["properties"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1573 |
)
|
| 1574 |
+
# Handle node list cases
|
| 1575 |
+
elif result.get("n") and isinstance(result["n"], list):
|
| 1576 |
+
for node in result["n"]:
|
| 1577 |
+
if isinstance(node, dict) and "id" in node:
|
| 1578 |
+
node_id = str(node["id"])
|
| 1579 |
+
if node_id not in nodes_dict and "properties" in node:
|
| 1580 |
+
nodes_dict[node_id] = KnowledgeGraphNode(
|
| 1581 |
+
id=node_id,
|
| 1582 |
+
labels=[node["properties"]["entity_id"]],
|
| 1583 |
+
properties=node["properties"],
|
| 1584 |
+
)
|
| 1585 |
|
| 1586 |
+
# Handle single edge cases
|
| 1587 |
+
if result.get("r") and isinstance(result["r"], dict):
|
| 1588 |
+
edge_id = str(result["r"]["id"])
|
| 1589 |
+
if edge_id not in edges_dict:
|
| 1590 |
+
edges_dict[edge_id] = KnowledgeGraphEdge(
|
| 1591 |
+
id=edge_id,
|
| 1592 |
+
type="DIRECTED",
|
| 1593 |
+
source=str(result["r"]["start_id"]),
|
| 1594 |
+
target=str(result["r"]["end_id"]),
|
| 1595 |
+
properties=result["r"]["properties"],
|
| 1596 |
+
)
|
| 1597 |
+
# Handle edge list cases
|
| 1598 |
+
elif result.get("r") and isinstance(result["r"], list):
|
| 1599 |
+
for edge in result["r"]:
|
| 1600 |
+
if isinstance(edge, dict) and "id" in edge:
|
| 1601 |
+
edge_id = str(edge["id"])
|
| 1602 |
+
if edge_id not in edges_dict:
|
| 1603 |
+
edges_dict[edge_id] = KnowledgeGraphEdge(
|
| 1604 |
+
id=edge_id,
|
| 1605 |
+
type="DIRECTED",
|
| 1606 |
+
source=str(edge["start_id"]),
|
| 1607 |
+
target=str(edge["end_id"]),
|
| 1608 |
+
properties=edge["properties"],
|
| 1609 |
+
)
|
| 1610 |
|
| 1611 |
+
# Construct and return the KnowledgeGraph with deduplicated nodes and edges
|
| 1612 |
kg = KnowledgeGraph(
|
| 1613 |
+
nodes=list(nodes_dict.values()),
|
| 1614 |
+
edges=list(edges_dict.values()),
|
| 1615 |
+
is_truncated=is_truncated,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1616 |
)
|
| 1617 |
|
| 1618 |
+
logger.info(
|
| 1619 |
+
f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
|
| 1620 |
+
)
|
| 1621 |
return kg
|
| 1622 |
|
| 1623 |
+
async def drop(self) -> dict[str, str]:
|
| 1624 |
"""Drop the storage"""
|
| 1625 |
+
try:
|
| 1626 |
+
drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
| 1627 |
+
MATCH (n)
|
| 1628 |
+
DETACH DELETE n
|
| 1629 |
+
$$) AS (result agtype)"""
|
| 1630 |
+
|
| 1631 |
+
await self._query(drop_query, readonly=False)
|
| 1632 |
+
return {"status": "success", "message": "graph data dropped"}
|
| 1633 |
+
except Exception as e:
|
| 1634 |
+
logger.error(f"Error dropping graph: {e}")
|
| 1635 |
+
return {"status": "error", "message": str(e)}
|
| 1636 |
|
| 1637 |
|
| 1638 |
NAMESPACE_TABLE_MAP = {
|
|
|
|
| 1745 |
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2
|
| 1746 |
""",
|
| 1747 |
"get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
|
| 1748 |
+
chunk_order_index, full_doc_id, file_path
|
| 1749 |
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
|
| 1750 |
""",
|
| 1751 |
"get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
|
|
|
|
| 1758 |
FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
|
| 1759 |
""",
|
| 1760 |
"get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
|
| 1761 |
+
chunk_order_index, full_doc_id, file_path
|
| 1762 |
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
|
| 1763 |
""",
|
| 1764 |
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
|
|
|
|
| 1790 |
file_path=EXCLUDED.file_path,
|
| 1791 |
update_time = CURRENT_TIMESTAMP
|
| 1792 |
""",
|
| 1793 |
+
# SQL for VectorStorage
|
| 1794 |
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
|
| 1795 |
content_vector, chunk_ids, file_path)
|
| 1796 |
VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7)
|
|
|
|
| 1814 |
file_path=EXCLUDED.file_path,
|
| 1815 |
update_time = CURRENT_TIMESTAMP
|
| 1816 |
""",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1817 |
"relationships": """
|
| 1818 |
WITH relevant_chunks AS (
|
| 1819 |
SELECT id as chunk_id
|
|
|
|
| 1854 |
FROM LIGHTRAG_DOC_CHUNKS
|
| 1855 |
WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
|
| 1856 |
)
|
| 1857 |
+
SELECT id, content, file_path FROM
|
| 1858 |
(
|
| 1859 |
+
SELECT id, content, file_path, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
| 1860 |
FROM LIGHTRAG_DOC_CHUNKS
|
| 1861 |
where workspace=$1
|
| 1862 |
AND id IN (SELECT chunk_id FROM relevant_chunks)
|
|
|
|
| 1865 |
ORDER BY distance DESC
|
| 1866 |
LIMIT $3
|
| 1867 |
""",
|
| 1868 |
+
# DROP tables
|
| 1869 |
+
"drop_specifiy_table_workspace": """
|
| 1870 |
+
DELETE FROM {table_name} WHERE workspace=$1
|
| 1871 |
+
""",
|
| 1872 |
}
|
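The PostgreSQL hunks above consistently replace string-built ID lists with a single bound array parameter (`id = ANY($2)`) and route per-storage cleanup through the new `drop_specifiy_table_workspace` template. The standalone sketch below (not part of the diff) illustrates the bound-array deletion pattern with plain asyncpg; the DSN, table name, and workspace value are assumptions for illustration only.

# Minimal sketch, assuming a table with (workspace, id) columns such as LIGHTRAG_LLM_CACHE.
import asyncio
import asyncpg

async def delete_records(dsn: str, table: str, workspace: str, ids: list[str]) -> None:
    conn = await asyncpg.connect(dsn)
    try:
        # $2 is bound to the whole Python list; Postgres matches ids with ANY()
        await conn.execute(
            f"DELETE FROM {table} WHERE workspace=$1 AND id = ANY($2)",
            workspace,
            ids,
        )
    finally:
        await conn.close()

# asyncio.run(delete_records("postgresql://user:pass@localhost/lightrag",
#                            "LIGHTRAG_LLM_CACHE", "default", ["doc-1", "doc-2"]))

Passing the list as one parameter avoids hand-building `IN (...)` clauses, keeps the statement parameterized, and works for any batch size.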
lightrag/kg/qdrant_impl.py
CHANGED
@@ -8,17 +8,15 @@ import uuid
 from ..utils import logger
 from ..base import BaseVectorStorage
 import configparser
-
-
-config = configparser.ConfigParser()
-config.read("config.ini", "utf-8")
-
 import pipmaster as pm

 if not pm.is_installed("qdrant-client"):
     pm.install("qdrant-client")

-from qdrant_client import QdrantClient, models
+from qdrant_client import QdrantClient, models  # type: ignore
+
+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")


 def compute_mdhash_id_for_qdrant(

@@ -275,3 +273,92 @@ class QdrantVectorDBStorage(BaseVectorStorage):
         except Exception as e:
             logger.error(f"Error searching for prefix '{prefix}': {e}")
             return []
+
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
+        """Get vector data by its ID
+
+        Args:
+            id: The unique identifier of the vector
+
+        Returns:
+            The vector data if found, or None if not found
+        """
+        try:
+            # Convert to Qdrant compatible ID
+            qdrant_id = compute_mdhash_id_for_qdrant(id)
+
+            # Retrieve the point by ID
+            result = self._client.retrieve(
+                collection_name=self.namespace,
+                ids=[qdrant_id],
+                with_payload=True,
+            )
+
+            if not result:
+                return None
+
+            return result[0].payload
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for ID {id}: {e}")
+            return None
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        """Get multiple vector data by their IDs
+
+        Args:
+            ids: List of unique identifiers
+
+        Returns:
+            List of vector data objects that were found
+        """
+        if not ids:
+            return []
+
+        try:
+            # Convert to Qdrant compatible IDs
+            qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
+
+            # Retrieve the points by IDs
+            results = self._client.retrieve(
+                collection_name=self.namespace,
+                ids=qdrant_ids,
+                with_payload=True,
+            )
+
+            return [point.payload for point in results]
+        except Exception as e:
+            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
+            return []
+
+    async def drop(self) -> dict[str, str]:
+        """Drop all vector data from storage and clean up resources
+
+        This method will delete all data from the Qdrant collection.
+
+        Returns:
+            dict[str, str]: Operation status and message
+            - On success: {"status": "success", "message": "data dropped"}
+            - On failure: {"status": "error", "message": "<error details>"}
+        """
+        try:
+            # Delete the collection and recreate it
+            if self._client.collection_exists(self.namespace):
+                self._client.delete_collection(self.namespace)
+
+            # Recreate the collection
+            QdrantVectorDBStorage.create_collection_if_not_exist(
+                self._client,
+                self.namespace,
+                vectors_config=models.VectorParams(
+                    size=self.embedding_func.embedding_dim,
+                    distance=models.Distance.COSINE,
+                ),
+            )
+
+            logger.info(
+                f"Process {os.getpid()} drop Qdrant collection {self.namespace}"
+            )
+            return {"status": "success", "message": "data dropped"}
+        except Exception as e:
+            logger.error(f"Error dropping Qdrant collection {self.namespace}: {e}")
+            return {"status": "error", "message": str(e)}
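The new QdrantVectorDBStorage.drop() clears a namespace by deleting the whole collection and recreating it with the same vector configuration. A minimal standalone sketch of that pattern with qdrant-client is shown below; the server URL, collection name, and vector dimension are assumptions for illustration, not values from the diff.

# Sketch of the drop-and-recreate pattern used above.
from qdrant_client import QdrantClient, models

def reset_collection(client: QdrantClient, name: str, dim: int) -> None:
    # Remove the existing collection (and all of its points), if any
    if client.collection_exists(name):
        client.delete_collection(name)
    # Recreate an empty collection with the same vector size and distance metric
    client.create_collection(
        collection_name=name,
        vectors_config=models.VectorParams(size=dim, distance=models.Distance.COSINE),
    )

# client = QdrantClient(url="http://localhost:6333")
# reset_collection(client, "chunks", 768)

Dropping the collection wholesale is cheaper than deleting points one by one and guarantees the schema (vector size and distance metric) is preserved for subsequent inserts.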
lightrag/kg/redis_impl.py
CHANGED
@@ -12,6 +12,7 @@ if not pm.is_installed("redis"):
 from redis.asyncio import Redis, ConnectionPool
 from redis.exceptions import RedisError, ConnectionError
 from lightrag.utils import logger, compute_mdhash_id
+
 from lightrag.base import BaseKVStorage
 import json

@@ -121,7 +122,11 @@ class RedisKVStorage(BaseKVStorage):
         except json.JSONEncodeError as e:
             logger.error(f"JSON encode error during upsert: {e}")
             raise
+
+    async def index_done_callback(self) -> None:
+        # Redis handles persistence automatically
+        pass
+
     async def delete(self, ids: list[str]) -> None:
         """Delete entries with specified IDs"""
         if not ids:

@@ -138,71 +143,52 @@ class RedisKVStorage(BaseKVStorage):
             f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
         )

    [removed here: the old delete_entity() and delete_entity_relation() implementations (key scanning plus batched redis.delete calls) and the trailing index_done_callback(); the deleted lines are truncated in this view]

+    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
+        """Delete specific records from storage by cache mode
+
+        Important notes for Redis storage:
+        1. This will immediately delete the specified cache modes from Redis
+
+        Args:
+            modes (list[str]): List of cache modes to be dropped from storage
+
+        Returns:
+            True: if the cache drop succeeded
+            False: if the cache drop failed
+        """
+        if not modes:
+            return False
+
+        try:
+            await self.delete(modes)
+            return True
+        except Exception:
+            return False
+
+    async def drop(self) -> dict[str, str]:
+        """Drop the storage by removing all keys under the current namespace.
+
+        Returns:
+            dict[str, str]: Status of the operation with keys 'status' and 'message'
+        """
+        async with self._get_redis_connection() as redis:
+            try:
+                keys = await redis.keys(f"{self.namespace}:*")
+
+                if keys:
+                    pipe = redis.pipeline()
+                    for key in keys:
+                        pipe.delete(key)
+                    results = await pipe.execute()
+                    deleted_count = sum(results)
+
+                    logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
+                    return {"status": "success", "message": f"{deleted_count} keys dropped"}
+                else:
+                    logger.info(f"No keys found to drop in {self.namespace}")
+                    return {"status": "success", "message": "no keys to drop"}
+
+            except Exception as e:
+                logger.error(f"Error dropping keys from {self.namespace}: {e}")
+                return {"status": "error", "message": str(e)}
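A minimal sketch of how the new Redis cache-drop methods might be called, assuming llm_cache is an initialized RedisKVStorage bound to the LLM response cache namespace (the variable name and mode values are illustrative):

    # Sketch only: drop selected cache modes first, fall back to clearing the namespace.
    dropped = await llm_cache.drop_cache_by_modes(["default", "local", "global"])
    if not dropped:
        status = await llm_cache.drop()   # removes every key under "<namespace>:*"
        print(status)                      # e.g. {"status": "success", "message": "42 keys dropped"}
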
lightrag/kg/tidb_impl.py
CHANGED
@@ -20,7 +20,7 @@ if not pm.is_installed("pymysql"):
 if not pm.is_installed("sqlalchemy"):
     pm.install("sqlalchemy")

-from sqlalchemy import create_engine, text
+from sqlalchemy import create_engine, text  # type: ignore


 class TiDB:

@@ -278,6 +278,86 @@ class TiDBKVStorage(BaseKVStorage):
         # Ti handles persistence automatically
         pass

+    async def delete(self, ids: list[str]) -> None:
+        """Delete records with specified IDs from the storage.
+
+        Args:
+            ids: List of record IDs to be deleted
+        """
+        if not ids:
+            return
+
+        try:
+            table_name = namespace_to_table_name(self.namespace)
+            id_field = namespace_to_id(self.namespace)
+
+            if not table_name or not id_field:
+                logger.error(f"Unknown namespace for deletion: {self.namespace}")
+                return
+
+            ids_list = ",".join([f"'{id}'" for id in ids])
+            delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
+
+            await self.db.execute(delete_sql, {"workspace": self.db.workspace})
+            logger.info(
+                f"Successfully deleted {len(ids)} records from {self.namespace}"
+            )
+        except Exception as e:
+            logger.error(f"Error deleting records from {self.namespace}: {e}")
+
+    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
+        """Delete specific records from storage by cache mode
+
+        Args:
+            modes (list[str]): List of cache modes to be dropped from storage
+
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        if not modes:
+            return False
+
+        try:
+            table_name = namespace_to_table_name(self.namespace)
+            if not table_name:
+                return False
+
+            if table_name != "LIGHTRAG_LLM_CACHE":
+                return False
+
+            # Build a MySQL-style IN clause
+            modes_list = ", ".join([f"'{mode}'" for mode in modes])
+            sql = f"""
+            DELETE FROM {table_name}
+            WHERE workspace = :workspace
+            AND mode IN ({modes_list})
+            """
+
+            logger.info(f"Deleting cache by modes: {modes}")
+            await self.db.execute(sql, {"workspace": self.db.workspace})
+            return True
+        except Exception as e:
+            logger.error(f"Error deleting cache by modes {modes}: {e}")
+            return False
+
+    async def drop(self) -> dict[str, str]:
+        """Drop the storage"""
+        try:
+            table_name = namespace_to_table_name(self.namespace)
+            if not table_name:
+                return {
+                    "status": "error",
+                    "message": f"Unknown namespace: {self.namespace}",
+                }
+
+            drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
+                table_name=table_name
+            )
+            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
+            return {"status": "success", "message": "data dropped"}
+        except Exception as e:
+            return {"status": "error", "message": str(e)}
+

 @final
 @dataclass

@@ -406,16 +486,91 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         params = {"workspace": self.db.workspace, "status": status}
         return await self.db.query(SQL, params, multirows=True)

+    async def delete(self, ids: list[str]) -> None:
+        """Delete vectors with specified IDs from the storage.
+
+        Args:
+            ids: List of vector IDs to be deleted
+        """
+        if not ids:
+            return
+
+        table_name = namespace_to_table_name(self.namespace)
+        id_field = namespace_to_id(self.namespace)
+
+        if not table_name or not id_field:
+            logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
+            return
+
+        ids_list = ",".join([f"'{id}'" for id in ids])
+        delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
+
+        try:
+            await self.db.execute(delete_sql, {"workspace": self.db.workspace})
+            logger.debug(
+                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
+            )
+        except Exception as e:
+            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
+
     async def delete_entity(self, entity_name: str) -> None:
-        [removed: previous implementation line, truncated in this view]
+        """Delete an entity by its name from the vector storage.
+
+        Args:
+            entity_name: The name of the entity to delete
+        """
+        try:
+            # Construct SQL to delete the entity
+            delete_sql = """DELETE FROM LIGHTRAG_GRAPH_NODES
+                        WHERE workspace = :workspace AND name = :entity_name"""
+
+            await self.db.execute(
+                delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
+            )
+            logger.debug(f"Successfully deleted entity {entity_name}")
+        except Exception as e:
+            logger.error(f"Error deleting entity {entity_name}: {e}")

     async def delete_entity_relation(self, entity_name: str) -> None:
-        [removed: previous implementation line, truncated in this view]
+        """Delete all relations associated with an entity.
+
+        Args:
+            entity_name: The name of the entity whose relations should be deleted
+        """
+        try:
+            # Delete relations where the entity is either the source or target
+            delete_sql = """DELETE FROM LIGHTRAG_GRAPH_EDGES
+                        WHERE workspace = :workspace AND (source_name = :entity_name OR target_name = :entity_name)"""
+
+            await self.db.execute(
+                delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
+            )
+            logger.debug(f"Successfully deleted relations for entity {entity_name}")
+        except Exception as e:
+            logger.error(f"Error deleting relations for entity {entity_name}: {e}")

     async def index_done_callback(self) -> None:
         # Ti handles persistence automatically
         pass

+    async def drop(self) -> dict[str, str]:
+        """Drop the storage"""
+        try:
+            table_name = namespace_to_table_name(self.namespace)
+            if not table_name:
+                return {
+                    "status": "error",
+                    "message": f"Unknown namespace: {self.namespace}",
+                }
+
+            drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
+                table_name=table_name
+            )
+            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
+            return {"status": "success", "message": "data dropped"}
+        except Exception as e:
+            return {"status": "error", "message": str(e)}
+
     async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
         """Search for records with IDs starting with a specific prefix.

@@ -710,6 +865,18 @@ class TiDBGraphStorage(BaseGraphStorage):
         # Ti handles persistence automatically
         pass

+    async def drop(self) -> dict[str, str]:
+        """Drop the storage"""
+        try:
+            drop_sql = """
+                DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace;
+                DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace;
+            """
+            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
+            return {"status": "success", "message": "graph data dropped"}
+        except Exception as e:
+            return {"status": "error", "message": str(e)}
+
     async def delete_node(self, node_id: str) -> None:
         """Delete a node and all its related edges

@@ -1129,4 +1296,6 @@ SQL_TEMPLATES = {
         FROM LIGHTRAG_DOC_CHUNKS
         WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
     """,
+    # Drop tables
+    "drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace",
 }
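The new per-workspace drop path in the TiDB backend is driven by a single SQL template. A short sketch of how that template expands; the table name shown is the cache table referenced in this diff, and the workspace value is purely illustrative:

    # Sketch only: mirrors how TiDBKVStorage.drop() formats the template.
    SQL_TEMPLATES = {
        "drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace",
    }

    drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
        table_name="LIGHTRAG_LLM_CACHE"
    )
    # -> DELETE FROM LIGHTRAG_LLM_CACHE WHERE workspace = :workspace
    # The :workspace bind parameter is supplied at execution time, e.g.
    # await db.execute(drop_sql, {"workspace": "default"})
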
lightrag/lightrag.py
CHANGED
@@ -13,7 +13,6 @@ import pandas as pd


 from lightrag.kg import (
-    STORAGE_ENV_REQUIREMENTS,
     STORAGES,
     verify_storage_implementation,
 )

@@ -230,6 +229,7 @@ class LightRAG:
     vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
     """Additional parameters for vector database storage."""

+    # TODO: deprecated, remove in the future, use WORKSPACE instead
     namespace_prefix: str = field(default="")
     """Prefix for namespacing stored data across different environments."""

@@ -510,36 +510,22 @@ class LightRAG:
         self,
         node_label: str,
         max_depth: int = 3,
-        inclusive: bool = False,
+        max_nodes: int = 1000,
     ) -> KnowledgeGraph:
         """Get knowledge graph for a given label

         Args:
             node_label (str): Label to get knowledge graph for
             max_depth (int): Maximum depth of graph
-            inclusive (bool, optional): Whether to use inclusive search mode. Defaults to False.
+            max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000.

         Returns:
             KnowledgeGraph: Knowledge graph containing nodes and edges
         """
-        # get params supported by get_knowledge_graph of specified storage
-        import inspect
-        [several deleted lines are truncated in this view]
-        kwargs = {"node_label": node_label, "max_depth": max_depth}
-        if "min_degree" in storage_params and min_degree > 0:
-            kwargs["min_degree"] = min_degree
-        if "inclusive" in storage_params:
-            kwargs["inclusive"] = inclusive
-        return await self.chunk_entity_relation_graph.get_knowledge_graph(**kwargs)
+        return await self.chunk_entity_relation_graph.get_knowledge_graph(
+            node_label, max_depth, max_nodes
+        )

     def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
         import_path = STORAGES[storage_name]

@@ -1449,6 +1435,7 @@ class LightRAG:
         loop = always_get_an_event_loop()
         return loop.run_until_complete(self.adelete_by_entity(entity_name))

+    # TODO: Lock all KG relative DB to ensure consistency across multiple processes
     async def adelete_by_entity(self, entity_name: str) -> None:
         try:
             await self.entities_vdb.delete_entity(entity_name)

@@ -1486,6 +1473,7 @@ class LightRAG:
             self.adelete_by_relation(source_entity, target_entity)
         )

+    # TODO: Lock all KG relative DB to ensure consistency across multiple processes
     async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None:
         """Asynchronously delete a relation between two entities.

@@ -1494,6 +1482,7 @@ class LightRAG:
             target_entity: Name of the target entity
         """
         try:
+            # TODO: check if has_edge function works on reverse relation
             # Check if the relation exists
             edge_exists = await self.chunk_entity_relation_graph.has_edge(
                 source_entity, target_entity

@@ -1554,6 +1543,7 @@ class LightRAG:
         """
         return await self.doc_status.get_docs_by_status(status)

+    # TODO: Lock all KG relative DB to ensure consistency across multiple processes
     async def adelete_by_doc_id(self, doc_id: str) -> None:
         """Delete a document and all its related data

@@ -1586,6 +1576,8 @@ class LightRAG:
             chunk_ids = set(related_chunks.keys())
             logger.debug(f"Found {len(chunk_ids)} chunks to delete")

+            # TODO: self.entities_vdb.client_storage only works for local storage, need to fix this
+
             # 3. Before deleting, check the related entities and relationships for these chunks
             for chunk_id in chunk_ids:
                 # Check entities

@@ -1857,24 +1849,6 @@ class LightRAG:

         return result

-    def check_storage_env_vars(self, storage_name: str) -> None:
-        """Check if all required environment variables for storage implementation exist
-
-        Args:
-            storage_name: Storage implementation name
-
-        Raises:
-            ValueError: If required environment variables are missing
-        """
-        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
-        missing_vars = [var for var in required_vars if var not in os.environ]
-
-        if missing_vars:
-            raise ValueError(
-                f"Storage implementation '{storage_name}' requires the following "
-                f"environment variables: {', '.join(missing_vars)}"
-            )
-
     async def aclear_cache(self, modes: list[str] | None = None) -> None:
         """Clear cache data from the LLM response cache storage.

@@ -1906,12 +1880,18 @@ class LightRAG:
         try:
             # Reset the cache storage for specified mode
             if modes:
-                await self.llm_response_cache.
+                success = await self.llm_response_cache.drop_cache_by_modes(modes)
+                if success:
+                    logger.info(f"Cleared cache for modes: {modes}")
+                else:
+                    logger.warning(f"Failed to clear cache for modes: {modes}")
             else:
                 # Clear all modes
-                await self.llm_response_cache.
+                success = await self.llm_response_cache.drop_cache_by_modes(valid_modes)
+                if success:
+                    logger.info("Cleared all cache")
+                else:
+                    logger.warning("Failed to clear all cache")

             await self.llm_response_cache.index_done_callback()

@@ -1922,6 +1902,7 @@ class LightRAG:
         """Synchronous version of aclear_cache."""
         return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))

+    # TODO: Lock all KG relative DB to ensure consistency across multiple processes
     async def aedit_entity(
         self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
     ) -> dict[str, Any]:

@@ -2134,6 +2115,7 @@ class LightRAG:
             ]
         )

+    # TODO: Lock all KG relative DB to ensure consistency across multiple processes
     async def aedit_relation(
         self, source_entity: str, target_entity: str, updated_data: dict[str, Any]
     ) -> dict[str, Any]:

@@ -2448,6 +2430,7 @@ class LightRAG:
             self.acreate_relation(source_entity, target_entity, relation_data)
         )

+    # TODO: Lock all KG relative DB to ensure consistency across multiple processes
     async def amerge_entities(
         self,
         source_entities: list[str],
lightrag/llm/openai.py
CHANGED
@@ -44,6 +44,47 @@ class InvalidResponseError(Exception):
     pass


+def create_openai_async_client(
+    api_key: str | None = None,
+    base_url: str | None = None,
+    client_configs: dict[str, Any] = None,
+) -> AsyncOpenAI:
+    """Create an AsyncOpenAI client with the given configuration.
+
+    Args:
+        api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
+        client_configs: Additional configuration options for the AsyncOpenAI client.
+            These will override any default configurations but will be overridden by
+            explicit parameters (api_key, base_url).
+
+    Returns:
+        An AsyncOpenAI client instance.
+    """
+    if not api_key:
+        api_key = os.environ["OPENAI_API_KEY"]
+
+    default_headers = {
+        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
+        "Content-Type": "application/json",
+    }
+
+    if client_configs is None:
+        client_configs = {}
+
+    # Create a merged config dict with precedence: explicit params > client_configs > defaults
+    merged_configs = {
+        **client_configs,
+        "default_headers": default_headers,
+        "api_key": api_key,
+    }
+
+    if base_url is not None:
+        merged_configs["base_url"] = base_url
+
+    return AsyncOpenAI(**merged_configs)
+
+
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),

@@ -61,29 +102,52 @@ async def openai_complete_if_cache(
     token_tracker: Any | None = None,
     **kwargs: Any,
 ) -> str:
+    """Complete a prompt using OpenAI's API with caching support.
+
+    Args:
+        model: The OpenAI model to use.
+        prompt: The prompt to complete.
+        system_prompt: Optional system prompt to include.
+        history_messages: Optional list of previous messages in the conversation.
+        base_url: Optional base URL for the OpenAI API.
+        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        **kwargs: Additional keyword arguments to pass to the OpenAI API.
+            Special kwargs:
+            - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
+                These will be passed to the client constructor but will be overridden by
+                explicit parameters (api_key, base_url).
+            - hashing_kv: Will be removed from kwargs before passing to OpenAI.
+            - keyword_extraction: Will be removed from kwargs before passing to OpenAI.
+
+    Returns:
+        The completed text or an async iterator of text chunks if streaming.
+
+    Raises:
+        InvalidResponseError: If the response from OpenAI is invalid or empty.
+        APIConnectionError: If there is a connection error with the OpenAI API.
+        RateLimitError: If the OpenAI API rate limit is exceeded.
+        APITimeoutError: If the OpenAI API request times out.
+    """
     if history_messages is None:
         history_messages = []
-    if not api_key:
-        api_key = os.environ["OPENAI_API_KEY"]
-
-    default_headers = {
-        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
-        "Content-Type": "application/json",
-    }

     # Set openai logger level to INFO when VERBOSE_DEBUG is off
     if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
         logging.getLogger("openai").setLevel(logging.INFO)

-    [removed: the old inline AsyncOpenAI construction; the deleted lines are truncated in this view]
+    # Extract client configuration options
+    client_configs = kwargs.pop("openai_client_configs", {})
+
+    # Create the OpenAI client
+    openai_async_client = create_openai_async_client(
+        api_key=api_key, base_url=base_url, client_configs=client_configs
+    )
+
+    # Remove special kwargs that shouldn't be passed to OpenAI
     kwargs.pop("hashing_kv", None)
     kwargs.pop("keyword_extraction", None)
+
+    # Prepare messages
     messages: list[dict[str, Any]] = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})

@@ -272,21 +336,32 @@ async def openai_embed(
     model: str = "text-embedding-3-small",
     base_url: str = None,
     api_key: str = None,
+    client_configs: dict[str, Any] = None,
 ) -> np.ndarray:
+    """Generate embeddings for a list of texts using OpenAI's API.
+
+    Args:
+        texts: List of texts to embed.
+        model: The OpenAI embedding model to use.
+        base_url: Optional base URL for the OpenAI API.
+        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        client_configs: Additional configuration options for the AsyncOpenAI client.
+            These will override any default configurations but will be overridden by
+            explicit parameters (api_key, base_url).
+
+    Returns:
+        A numpy array of embeddings, one per input text.
+
+    Raises:
+        APIConnectionError: If there is a connection error with the OpenAI API.
+        RateLimitError: If the OpenAI API rate limit is exceeded.
+        APITimeoutError: If the OpenAI API request times out.
+    """
-    [removed: the old inline api_key / default_headers / AsyncOpenAI setup; the deleted lines are truncated in this view]
+    # Create the OpenAI client
+    openai_async_client = create_openai_async_client(
+        api_key=api_key, base_url=base_url, client_configs=client_configs
+    )
+
     response = await openai_async_client.embeddings.create(
         model=model, input=texts, encoding_format="float"
    )
lightrag/operate.py
CHANGED
@@ -26,7 +26,6 @@ from .utils import (
     CacheData,
     statistic_data,
     get_conversation_turns,
-    verbose_debug,
 )
 from .base import (
     BaseGraphStorage,

@@ -442,6 +441,13 @@ async def extract_entities(

     processed_chunks = 0
     total_chunks = len(ordered_chunks)
+    total_entities_count = 0
+    total_relations_count = 0
+
+    # Get lock manager from shared storage
+    from .kg.shared_storage import get_graph_db_lock
+
+    graph_db_lock = get_graph_db_lock(enable_logging=False)

     async def _user_llm_func_with_cache(
         input_text: str, history_messages: list[dict[str, str]] = None

@@ -540,7 +546,7 @@ async def extract_entities(
             chunk_key_dp (tuple[str, TextChunkSchema]):
                 ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
         """
-        nonlocal processed_chunks
+        nonlocal processed_chunks, total_entities_count, total_relations_count
         chunk_key = chunk_key_dp[0]
         chunk_dp = chunk_key_dp[1]
         content = chunk_dp["content"]

@@ -598,102 +604,74 @@ async def extract_entities(
             async with pipeline_status_lock:
                 pipeline_status["latest_message"] = log_message
                 pipeline_status["history_messages"].append(log_message)
-        return dict(maybe_nodes), dict(maybe_edges)

    [removed here: the old post-processing block that gathered all per-chunk results, merged nodes and edges in a single pass under the graph DB lock (asyncio.gather over _merge_nodes_then_upsert / _merge_edges_then_upsert), logged empty-result and "Extracted ..." messages, called verbose_debug on the merged data, and upserted all_entities_data / all_relationships_data into the vector databases; many of the deleted lines are truncated in this view]

+        # Use graph database lock to ensure atomic merges and updates
+        chunk_entities_data = []
+        chunk_relationships_data = []
+
+        async with graph_db_lock:
+            # Process and update entities
+            for entity_name, entities in maybe_nodes.items():
+                entity_data = await _merge_nodes_then_upsert(
+                    entity_name, entities, knowledge_graph_inst, global_config
+                )
+                chunk_entities_data.append(entity_data)
+
+            # Process and update relationships
+            for edge_key, edges in maybe_edges.items():
+                # Ensure edge direction consistency
+                sorted_edge_key = tuple(sorted(edge_key))
+                edge_data = await _merge_edges_then_upsert(
+                    sorted_edge_key[0],
+                    sorted_edge_key[1],
+                    edges,
+                    knowledge_graph_inst,
+                    global_config,
+                )
+                chunk_relationships_data.append(edge_data)
+
+            # Update vector database (within the same lock to ensure atomicity)
+            if entity_vdb is not None and chunk_entities_data:
+                data_for_vdb = {
+                    compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
+                        "entity_name": dp["entity_name"],
+                        "entity_type": dp["entity_type"],
+                        "content": f"{dp['entity_name']}\n{dp['description']}",
+                        "source_id": dp["source_id"],
+                        "file_path": dp.get("file_path", "unknown_source"),
+                    }
+                    for dp in chunk_entities_data
+                }
+                await entity_vdb.upsert(data_for_vdb)
+
+            if relationships_vdb is not None and chunk_relationships_data:
+                data_for_vdb = {
+                    compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
+                        "src_id": dp["src_id"],
+                        "tgt_id": dp["tgt_id"],
+                        "keywords": dp["keywords"],
+                        "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
+                        "source_id": dp["source_id"],
+                        "file_path": dp.get("file_path", "unknown_source"),
+                    }
+                    for dp in chunk_relationships_data
+                }
+                await relationships_vdb.upsert(data_for_vdb)
+
+            # Update counters
+            total_entities_count += len(chunk_entities_data)
+            total_relations_count += len(chunk_relationships_data)
+
+    # Handle all chunks in parallel
+    tasks = [_process_single_content(c) for c in ordered_chunks]
+    await asyncio.gather(*tasks)

+    log_message = f"Extracted {total_entities_count} entities + {total_relations_count} relationships (total)"
     logger.info(log_message)
     if pipeline_status is not None:
         async with pipeline_status_lock:
             pipeline_status["latest_message"] = log_message
             pipeline_status["history_messages"].append(log_message)


 async def kg_query(

@@ -720,8 +698,7 @@ async def kg_query(
     if cached_response is not None:
         return cached_response

-    hl_keywords, ll_keywords = await extract_keywords_only(
+    hl_keywords, ll_keywords = await get_keywords_from_query(
         query, query_param, global_config, hashing_kv
     )

@@ -817,6 +794,38 @@ async def kg_query(
     return response


+async def get_keywords_from_query(
+    query: str,
+    query_param: QueryParam,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+) -> tuple[list[str], list[str]]:
+    """
+    Retrieves high-level and low-level keywords for RAG operations.
+
+    This function checks if keywords are already provided in query parameters,
+    and if not, extracts them from the query text using LLM.
+
+    Args:
+        query: The user's query text
+        query_param: Query parameters that may contain pre-defined keywords
+        global_config: Global configuration dictionary
+        hashing_kv: Optional key-value storage for caching results
+
+    Returns:
+        A tuple containing (high_level_keywords, low_level_keywords)
+    """
+    # Check if pre-defined keywords are already provided
+    if query_param.hl_keywords or query_param.ll_keywords:
+        return query_param.hl_keywords, query_param.ll_keywords
+
+    # Extract keywords using extract_keywords_only function which already supports conversation history
+    hl_keywords, ll_keywords = await extract_keywords_only(
+        query, query_param, global_config, hashing_kv
+    )
+    return hl_keywords, ll_keywords
+
+
 async def extract_keywords_only(
     text: str,
     param: QueryParam,

@@ -957,8 +966,7 @@ async def mix_kg_vector_query(
     # 2. Execute knowledge graph and vector searches in parallel
     async def get_kg_context():
         try:
-            hl_keywords, ll_keywords = await extract_keywords_only(
+            hl_keywords, ll_keywords = await get_keywords_from_query(
                 query, query_param, global_config, hashing_kv
             )

@@ -1339,7 +1347,9 @@ async def _get_node_data(

     text_units_section_list = [["id", "content", "file_path"]]
     for i, t in enumerate(use_text_units):
-        text_units_section_list.append(
+        text_units_section_list.append(
+            [i, t["content"], t.get("file_path", "unknown_source")]
+        )
     text_units_context = list_of_list_to_csv(text_units_section_list)
     return entities_context, relations_context, text_units_context

@@ -2043,16 +2053,13 @@ async def query_with_keywords(
         Query response or async iterator
     """
     # Extract keywords
-    hl_keywords, ll_keywords = await
+    hl_keywords, ll_keywords = await get_keywords_from_query(
+        query=query,
+        query_param=param,
         global_config=global_config,
         hashing_kv=hashing_kv,
     )

-    param.hl_keywords = hl_keywords
-    param.ll_keywords = ll_keywords
-
     # Create a new string with the prompt and the keywords
     ll_keywords_str = ", ".join(ll_keywords)
     hl_keywords_str = ", ".join(hl_keywords)
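get_keywords_from_query() is now the single entry point for keyword resolution, so callers can short-circuit the LLM step by pre-filling the query parameters. A minimal sketch, assuming QueryParam can be constructed with defaults and that global_config / hashing_kv come from the running pipeline (both assumptions; they are not shown in this diff):

    # Sketch only: pre-defined keywords skip LLM extraction entirely.
    from lightrag import QueryParam
    from lightrag.operate import get_keywords_from_query

    param = QueryParam()
    param.hl_keywords = ["graph indexing"]
    param.ll_keywords = ["entity extraction", "relationship merging"]

    hl, ll = await get_keywords_from_query(
        "How does LightRAG build its knowledge graph?",
        param,
        global_config,   # assumed: the config dict LightRAG passes internally
        hashing_kv=None, # optional LLM-cache storage
    )
    # hl == param.hl_keywords and ll == param.ll_keywords; no LLM call was made.
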
lightrag/types.py
CHANGED
@@ -26,3 +26,4 @@ class KnowledgeGraphEdge(BaseModel):
 class KnowledgeGraph(BaseModel):
     nodes: list[KnowledgeGraphNode] = []
     edges: list[KnowledgeGraphEdge] = []
+    is_truncated: bool = False
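The new is_truncated flag lets callers tell a complete subgraph from one that was cut off by the max_nodes limit. A short sketch of how a consumer might use it; the retry logic is illustrative only:

    # Sketch only: `kg` is a KnowledgeGraph returned by get_knowledge_graph(...).
    if kg.is_truncated:
        # The storage stopped adding nodes at the max_nodes limit; re-query with a larger budget.
        kg = await rag.get_knowledge_graph(node_label=label, max_depth=3, max_nodes=5000)
    print(f"{len(kg.nodes)} nodes / {len(kg.edges)} edges, truncated={kg.is_truncated}")
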
lightrag_webui/src/App.tsx
CHANGED
@@ -3,12 +3,13 @@ import ThemeProvider from '@/components/ThemeProvider'
 import TabVisibilityProvider from '@/contexts/TabVisibilityProvider'
 import ApiKeyAlert from '@/components/ApiKeyAlert'
 import StatusIndicator from '@/components/status/StatusIndicator'
-import { healthCheckInterval } from '@/lib/constants'
+import { healthCheckInterval, SiteInfo, webuiPrefix } from '@/lib/constants'
 import { useBackendState, useAuthStore } from '@/stores/state'
 import { useSettingsStore } from '@/stores/settings'
 import { getAuthStatus } from '@/api/lightrag'
 import SiteHeader from '@/features/SiteHeader'
 import { InvalidApiKeyError, RequireApiKeError } from '@/api/lightrag'
+import { ZapIcon } from 'lucide-react'

 import GraphViewer from '@/features/GraphViewer'
 import DocumentManager from '@/features/DocumentManager'

@@ -22,6 +23,7 @@ function App() {
   const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
   const currentTab = useSettingsStore.use.currentTab()
   const [apiKeyAlertOpen, setApiKeyAlertOpen] = useState(false)
+  const [initializing, setInitializing] = useState(true) // Add initializing state
   const versionCheckRef = useRef(false); // Prevent duplicate calls in Vite dev mode

   const handleApiKeyAlertOpenChange = useCallback((open: boolean) => {

@@ -55,29 +57,48 @@ function App() {

     // Check if version info was already obtained in login page
     const versionCheckedFromLogin = sessionStorage.getItem('VERSION_CHECKED_FROM_LOGIN') === 'true';
-    if (versionCheckedFromLogin)
-    [removed: the old early-return and token lookup; the deleted lines are truncated in this view]
-    if (!token) return;
+    if (versionCheckedFromLogin) {
+      setInitializing(false); // Skip initialization if already checked
+      return;
+    }

     try {
+      setInitializing(true); // Start initialization
+
+      // Get version info
+      const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
       const status = await getAuthStatus();
-      [removed: the old guard around the version update; the deleted lines are truncated in this view]
+
+      // If auth is not configured and a new token is returned, use the new token
+      if (!status.auth_configured && status.access_token) {
+        useAuthStore.getState().login(
+          status.access_token, // Use the new token
+          true, // Guest mode
+          status.core_version,
+          status.api_version,
+          status.webui_title || null,
+          status.webui_description || null
+        );
+      } else if (token && (status.core_version || status.api_version || status.webui_title || status.webui_description)) {
+        // Otherwise use the old token (if it exists)
         const isGuestMode = status.auth_mode === 'disabled' || useAuthStore.getState().isGuestMode;
-        // Update version info while maintaining login state
         useAuthStore.getState().login(
           token,
           isGuestMode,
           status.core_version,
-          status.api_version
+          status.api_version,
+          status.webui_title || null,
+          status.webui_description || null
         );
-        // Set flag to indicate version info has been checked
-        sessionStorage.setItem('VERSION_CHECKED_FROM_LOGIN', 'true');
       }
+
+      // Set flag to indicate version info has been checked
+      sessionStorage.setItem('VERSION_CHECKED_FROM_LOGIN', 'true');
     } catch (error) {
       console.error('Failed to get version info:', error);
+    } finally {
+      // Ensure initializing is set to false even if there's an error
+      setInitializing(false);
     }
   };

@@ -101,31 +122,63 @@ function App() {
   return (
     <ThemeProvider>
       <TabVisibilityProvider>
-        [removed: the previous always-rendered <main>/<Tabs> layout; the deleted lines are truncated in this view]
+        {initializing ? (
+          // Loading state while initializing with simplified header
+          <div className="flex h-screen w-screen flex-col">
+            {/* Simplified header during initialization - matches SiteHeader structure */}
+            <header className="border-border/40 bg-background/95 supports-[backdrop-filter]:bg-background/60 sticky top-0 z-50 flex h-10 w-full border-b px-4 backdrop-blur">
+              <div className="min-w-[200px] w-auto flex items-center">
+                <a href={webuiPrefix} className="flex items-center gap-2">
+                  <ZapIcon className="size-4 text-emerald-400" aria-hidden="true" />
+                  <span className="font-bold md:inline-block">{SiteInfo.name}</span>
+                </a>
+              </div>
+
+              {/* Empty middle section to maintain layout */}
+              <div className="flex h-10 flex-1 items-center justify-center">
+              </div>
+
+              {/* Empty right section to maintain layout */}
+              <nav className="w-[200px] flex items-center justify-end">
+              </nav>
+            </header>
+
+            {/* Loading indicator in content area */}
+            <div className="flex flex-1 items-center justify-center">
+              <div className="text-center">
+                <div className="mb-2 h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent"></div>
+                <p>Initializing...</p>
+              </div>
+            </div>
+          </div>
+        ) : (
+          // Main content after initialization
+          <main className="flex h-screen w-screen overflow-hidden">
+            <Tabs
+              defaultValue={currentTab}
+              className="!m-0 flex grow flex-col !p-0 overflow-hidden"
+              onValueChange={handleTabChange}
+            >
+              <SiteHeader />
+              <div className="relative grow">
+                <TabsContent value="documents" className="absolute top-0 right-0 bottom-0 left-0 overflow-auto">
+                  <DocumentManager />
+                </TabsContent>
+                <TabsContent value="knowledge-graph" className="absolute top-0 right-0 bottom-0 left-0 overflow-hidden">
+                  <GraphViewer />
+                </TabsContent>
+                <TabsContent value="retrieval" className="absolute top-0 right-0 bottom-0 left-0 overflow-hidden">
+                  <RetrievalTesting />
+                </TabsContent>
+                <TabsContent value="api" className="absolute top-0 right-0 bottom-0 left-0 overflow-hidden">
+                  <ApiSite />
+                </TabsContent>
+              </div>
+            </Tabs>
+            {enableHealthCheck && <StatusIndicator />}
+            <ApiKeyAlert open={apiKeyAlertOpen} onOpenChange={handleApiKeyAlertOpenChange} />
+          </main>
+        )}
       </TabVisibilityProvider>
     </ThemeProvider>
   )
lightrag_webui/src/AppRouter.tsx
CHANGED
@@ -80,7 +80,12 @@ const AppRouter = () => {
     <ThemeProvider>
       <Router>
         <AppContent />
-        <Toaster
+        <Toaster
+          position="bottom-center"
+          theme="system"
+          closeButton
+          richColors
+        />
       </Router>
     </ThemeProvider>
   )
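A small sketch of what the new `Toaster` settings mean for callers: any component can keep firing notifications through sonner's `toast` helpers (as `ClearDocumentsDialog` does below), and they now render bottom-center with rich colors and a close button. The function name here is illustrative only.

```typescript
import { toast } from 'sonner'

// Hypothetical notifier: the placement and styling come from the Toaster
// props mounted in AppRouter above, not from these calls.
export function notifyPipelineResult(ok: boolean, message: string): void {
  if (ok) {
    toast.success(message) // green variant via richColors, dismissible
  } else {
    toast.error(message)   // red variant, same bottom-center position
  }
}
```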
lightrag_webui/src/api/lightrag.ts
CHANGED
@@ -46,6 +46,8 @@ export type LightragStatus = {
   api_version?: string
   auth_mode?: 'enabled' | 'disabled'
   pipeline_busy: boolean
+  webui_title?: string
+  webui_description?: string
 }

 export type LightragDocumentsScanProgress = {
@@ -140,6 +142,8 @@ export type AuthStatusResponse = {
   message?: string
   core_version?: string
   api_version?: string
+  webui_title?: string
+  webui_description?: string
 }

 export type PipelineStatusResponse = {
@@ -163,6 +167,8 @@ export type LoginResponse = {
   message?: string // Optional message
   core_version?: string
   api_version?: string
+  webui_title?: string
+  webui_description?: string
 }

 export const InvalidApiKeyError = 'Invalid API Key'
@@ -221,9 +227,9 @@ axiosInstance.interceptors.response.use(
 export const queryGraphs = async (
   label: string,
   maxDepth: number,
-
+  maxNodes: number
 ): Promise<LightragGraphType> => {
-  const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&
+  const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&max_nodes=${maxNodes}`)
   return response.data
 }

@@ -382,6 +388,14 @@ export const clearDocuments = async (): Promise<DocActionResponse> => {
   return response.data
 }

+export const clearCache = async (modes?: string[]): Promise<{
+  status: 'success' | 'fail'
+  message: string
+}> => {
+  const response = await axiosInstance.post('/documents/clear_cache', { modes })
+  return response.data
+}
+
 export const getAuthStatus = async (): Promise<AuthStatusResponse> => {
   try {
     // Add a timeout to the request to prevent hanging
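A usage sketch for the two client additions above: `queryGraphs` now takes a node cap as its third argument, and `clearCache` posts to `/documents/clear_cache`. The label, depth, and node limit values are illustrative; omitting `modes` presumably clears every cache type, while passing a list such as `['local']` narrows the scope.

```typescript
import { queryGraphs, clearCache } from '@/api/lightrag'

// Illustrative helper, not part of the diff.
export async function refreshGraphAndDropQueryCache(label: string): Promise<void> {
  // New third argument: cap the returned subgraph at 1000 nodes.
  const graph = await queryGraphs(label, 3, 1000)
  console.log('graph payload received:', graph)

  // Clear only the "local" query cache on the server.
  const result = await clearCache(['local'])
  if (result.status !== 'success') {
    console.error('clear_cache failed:', result.message)
  }
}
```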
lightrag_webui/src/components/documents/ClearDocumentsDialog.tsx
CHANGED
@@ -1,4 +1,4 @@
-import { useState, useCallback } from 'react'
+import { useState, useCallback, useEffect } from 'react'
 import Button from '@/components/ui/Button'
 import {
   Dialog,
@@ -6,32 +6,88 @@ import {
   DialogDescription,
   DialogHeader,
   DialogTitle,
-  DialogTrigger
+  DialogTrigger,
+  DialogFooter
 } from '@/components/ui/Dialog'
+import Input from '@/components/ui/Input'
+import Checkbox from '@/components/ui/Checkbox'
 import { toast } from 'sonner'
 import { errorMessage } from '@/lib/utils'
-import { clearDocuments } from '@/api/lightrag'
+import { clearDocuments, clearCache } from '@/api/lightrag'

-import { EraserIcon } from 'lucide-react'
+import { EraserIcon, AlertTriangleIcon } from 'lucide-react'
 import { useTranslation } from 'react-i18next'

-
+// Simple Label component
+const Label = ({
+  htmlFor,
+  className,
+  children,
+  ...props
+}: React.LabelHTMLAttributes<HTMLLabelElement>) => (
+  <label
+    htmlFor={htmlFor}
+    className={className}
+    {...props}
+  >
+    {children}
+  </label>
+)
+
+interface ClearDocumentsDialogProps {
+  onDocumentsCleared?: () => Promise<void>
+}
+
+export default function ClearDocumentsDialog({ onDocumentsCleared }: ClearDocumentsDialogProps) {
   const { t } = useTranslation()
   const [open, setOpen] = useState(false)
+  const [confirmText, setConfirmText] = useState('')
+  const [clearCacheOption, setClearCacheOption] = useState(false)
+  const isConfirmEnabled = confirmText.toLowerCase() === 'yes'
+
+  // Reset state when the dialog is closed
+  useEffect(() => {
+    if (!open) {
+      setConfirmText('')
+      setClearCacheOption(false)
+    }
+  }, [open])

   const handleClear = useCallback(async () => {
+    if (!isConfirmEnabled) return
+
     try {
       const result = await clearDocuments()
-
-
-        setOpen(false)
-      } else {
+
+      if (result.status !== 'success') {
         toast.error(t('documentPanel.clearDocuments.failed', { message: result.message }))
+        setConfirmText('')
+        return
+      }
+
+      toast.success(t('documentPanel.clearDocuments.success'))
+
+      if (clearCacheOption) {
+        try {
+          await clearCache()
+          toast.success(t('documentPanel.clearDocuments.cacheCleared'))
+        } catch (cacheErr) {
+          toast.error(t('documentPanel.clearDocuments.cacheClearFailed', { error: errorMessage(cacheErr) }))
+        }
+      }
+
+      // Refresh document list if provided
+      if (onDocumentsCleared) {
+        onDocumentsCleared().catch(console.error)
       }
+
+      // Close the dialog after all operations have succeeded
+      setOpen(false)
     } catch (err) {
       toast.error(t('documentPanel.clearDocuments.error', { error: errorMessage(err) }))
+      setConfirmText('')
     }
-  }, [setOpen, t])
+  }, [isConfirmEnabled, clearCacheOption, setOpen, t, onDocumentsCleared])

   return (
     <Dialog open={open} onOpenChange={setOpen}>
@@ -42,12 +98,60 @@ export default function ClearDocumentsDialog() {
       </DialogTrigger>
       <DialogContent className="sm:max-w-xl" onCloseAutoFocus={(e) => e.preventDefault()}>
         <DialogHeader>
-          <DialogTitle>
-
+          <DialogTitle className="flex items-center gap-2 text-red-500 dark:text-red-400 font-bold">
+            <AlertTriangleIcon className="h-5 w-5" />
+            {t('documentPanel.clearDocuments.title')}
+          </DialogTitle>
+          <DialogDescription className="pt-2">
+            {t('documentPanel.clearDocuments.description')}
+          </DialogDescription>
         </DialogHeader>
-
-
-
+
+        <div className="text-red-500 dark:text-red-400 font-semibold mb-4">
+          {t('documentPanel.clearDocuments.warning')}
+        </div>
+        <div className="mb-4">
+          {t('documentPanel.clearDocuments.confirm')}
+        </div>
+
+        <div className="space-y-4">
+          <div className="space-y-2">
+            <Label htmlFor="confirm-text" className="text-sm font-medium">
+              {t('documentPanel.clearDocuments.confirmPrompt')}
+            </Label>
+            <Input
+              id="confirm-text"
+              value={confirmText}
+              onChange={(e: React.ChangeEvent<HTMLInputElement>) => setConfirmText(e.target.value)}
+              placeholder={t('documentPanel.clearDocuments.confirmPlaceholder')}
+              className="w-full"
+            />
+          </div>
+
+          <div className="flex items-center space-x-2">
+            <Checkbox
+              id="clear-cache"
+              checked={clearCacheOption}
+              onCheckedChange={(checked: boolean | 'indeterminate') => setClearCacheOption(checked === true)}
+            />
+            <Label htmlFor="clear-cache" className="text-sm font-medium cursor-pointer">
+              {t('documentPanel.clearDocuments.clearCache')}
+            </Label>
+          </div>
+        </div>
+
+        <DialogFooter>
+          <Button variant="outline" onClick={() => setOpen(false)}>
+            {t('common.cancel')}
+          </Button>
+          <Button
+            variant="destructive"
+            onClick={handleClear}
+            disabled={!isConfirmEnabled}
+          >
+            {t('documentPanel.clearDocuments.confirmButton')}
+          </Button>
+        </DialogFooter>
       </DialogContent>
     </Dialog>
   )
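A minimal sketch of how a parent panel might use the new `onDocumentsCleared` prop to refresh its listing after a wipe; the `DocumentToolbar` component and its `refetch` prop are hypothetical, only the dialog's export and prop shape come from the diff above.

```typescript
import ClearDocumentsDialog from '@/components/documents/ClearDocumentsDialog'

// Hypothetical wrapper: `refetch` would typically reload the document list.
export default function DocumentToolbar({ refetch }: { refetch: () => Promise<void> }) {
  return (
    <div className="flex items-center gap-2">
      {/* The dialog enables its destructive button only after the user types "yes". */}
      <ClearDocumentsDialog onDocumentsCleared={refetch} />
    </div>
  )
}
```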