czhang17 committed
Commit de6166a · unverified · 2 Parent(s): 561b7e9 ae19f4a

Merge branch 'main' into main

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. README-zh.md +52 -4
  2. README.md +54 -5
  3. config.ini.example +0 -17
  4. env.example +20 -27
  5. examples/lightrag_api_ollama_demo.py +0 -188
  6. examples/lightrag_api_openai_compatible_demo.py +0 -204
  7. examples/lightrag_api_oracle_demo.py +0 -267
  8. examples/lightrag_ollama_gremlin_demo.py +4 -0
  9. examples/lightrag_oracle_demo.py +0 -141
  10. examples/lightrag_tidb_demo.py +4 -0
  11. lightrag/api/README-zh.md +4 -12
  12. lightrag/api/README.md +5 -13
  13. lightrag/api/__init__.py +1 -1
  14. lightrag/api/auth.py +9 -8
  15. lightrag/api/config.py +335 -0
  16. lightrag/api/lightrag_server.py +39 -19
  17. lightrag/api/routers/document_routes.py +465 -52
  18. lightrag/api/routers/graph_routes.py +10 -14
  19. lightrag/api/run_with_gunicorn.py +31 -25
  20. lightrag/api/utils_api.py +19 -354
  21. lightrag/api/webui/assets/index-CD5HxTy1.css +0 -0
  22. lightrag/api/webui/assets/{index-raheqJeu.js → index-Cma7xY0-.js} +0 -0
  23. lightrag/api/webui/assets/index-QU59h9JG.css +0 -0
  24. lightrag/api/webui/index.html +0 -0
  25. lightrag/base.py +122 -9
  26. lightrag/kg/__init__.py +15 -40
  27. lightrag/kg/age_impl.py +21 -3
  28. lightrag/kg/chroma_impl.py +28 -2
  29. lightrag/kg/faiss_impl.py +64 -5
  30. lightrag/kg/gremlin_impl.py +24 -3
  31. lightrag/kg/json_doc_status_impl.py +49 -10
  32. lightrag/kg/json_kv_impl.py +73 -3
  33. lightrag/kg/milvus_impl.py +31 -1
  34. lightrag/kg/mongo_impl.py +130 -3
  35. lightrag/kg/nano_vector_db_impl.py +66 -0
  36. lightrag/kg/neo4j_impl.py +373 -246
  37. lightrag/kg/networkx_impl.py +122 -97
  38. lightrag/kg/oracle_impl.py +0 -1346
  39. lightrag/kg/postgres_impl.py +382 -319
  40. lightrag/kg/qdrant_impl.py +93 -6
  41. lightrag/kg/redis_impl.py +46 -60
  42. lightrag/kg/tidb_impl.py +172 -3
  43. lightrag/lightrag.py +25 -42
  44. lightrag/llm/openai.py +101 -26
  45. lightrag/operate.py +105 -98
  46. lightrag/types.py +1 -0
  47. lightrag_webui/src/App.tsx +89 -36
  48. lightrag_webui/src/AppRouter.tsx +6 -1
  49. lightrag_webui/src/api/lightrag.ts +16 -2
  50. lightrag_webui/src/components/documents/ClearDocumentsDialog.tsx +119 -15
README-zh.md CHANGED
@@ -11,7 +11,6 @@
11
  - [X] [2024.12.31]🎯📢LightRAG现在支持[通过文档ID删除](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete)。
12
  - [X] [2024.11.25]🎯📢LightRAG现在支持无缝集成[自定义知识图谱](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg),使用户能够用自己的领域专业知识增强系统。
13
  - [X] [2024.11.19]🎯📢LightRAG的综合指南现已在[LearnOpenCV](https://learnopencv.com/lightrag)上发布。非常感谢博客作者。
14
- - [X] [2024.11.12]🎯📢LightRAG现在支持[Oracle Database 23ai的所有存储类型(KV、向量和图)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py)。
15
  - [X] [2024.11.11]🎯📢LightRAG现在支持[通过实体名称删除实体](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete)。
16
  - [X] [2024.11.09]🎯📢推出[LightRAG Gui](https://lightrag-gui.streamlit.app),允许您插入、查询、可视化和下载LightRAG知识。
17
  - [X] [2024.11.04]🎯📢现在您可以[使用Neo4J进行存储](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage)。
@@ -410,6 +409,54 @@ if __name__ == "__main__":
410
 
411
  </details>
412
 
413
  ### 对话历史
414
 
415
  LightRAG现在通过对话历史功能支持多轮对话。以下是使用方法:
@@ -1037,9 +1084,10 @@ rag.clear_cache(modes=["local"])
1037
  | **参数** | **类型** | **说明** | **默认值** |
1038
  |--------------|----------|-----------------|-------------|
1039
  | **working_dir** | `str` | 存储缓存的目录 | `lightrag_cache+timestamp` |
1040
- | **kv_storage** | `str` | 文档和文本块的存储类型。支持的类型:`JsonKVStorage`、`OracleKVStorage` | `JsonKVStorage` |
1041
- | **vector_storage** | `str` | 嵌入向量的存储类型。支持的类型:`NanoVectorDBStorage`、`OracleVectorDBStorage` | `NanoVectorDBStorage` |
1042
- | **graph_storage** | `str` | 图边和节点的存储类型。支持的类型:`NetworkXStorage`、`Neo4JStorage`、`OracleGraphStorage` | `NetworkXStorage` |
 
1043
  | **chunk_token_size** | `int` | 拆分文档时每个块的最大令牌大小 | `1200` |
1044
  | **chunk_overlap_token_size** | `int` | 拆分文档时两个块之间的重叠令牌大小 | `100` |
1045
  | **tiktoken_model_name** | `str` | 用于计算令牌数的Tiktoken编码器的模型名称 | `gpt-4o-mini` |
 
11
  - [X] [2024.12.31]🎯📢LightRAG现在支持[通过文档ID删除](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete)。
12
  - [X] [2024.11.25]🎯📢LightRAG现在支持无缝集成[自定义知识图谱](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg),使用户能够用自己的领域专业知识增强系统。
13
  - [X] [2024.11.19]🎯📢LightRAG的综合指南现已在[LearnOpenCV](https://learnopencv.com/lightrag)上发布。非常感谢博客作者。
 
14
  - [X] [2024.11.11]🎯📢LightRAG现在支持[通过实体名称删除实体](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete)。
15
  - [X] [2024.11.09]🎯📢推出[LightRAG Gui](https://lightrag-gui.streamlit.app),允许您插入、查询、可视化和下载LightRAG知识。
16
  - [X] [2024.11.04]🎯📢现在您可以[使用Neo4J进行存储](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage)。
 
409
 
410
  </details>
411
 
412
+ ### Token统计功能
413
+ <details>
414
+ <summary> <b>概述和使用</b> </summary>
415
+
416
+ LightRAG提供了TokenTracker工具来跟踪和管理大模型的token消耗。这个功能对于控制API成本和优化性能特别有用。
417
+
418
+ #### 使用方法
419
+
420
+ ```python
421
+ from lightrag.utils import TokenTracker
422
+
423
+ # 创建TokenTracker实例
424
+ token_tracker = TokenTracker()
425
+
426
+ # 方法1:使用上下文管理器(推荐)
427
+ # 适用于需要自动跟踪token使用的场景
428
+ with token_tracker:
429
+ result1 = await llm_model_func("你的问题1")
430
+ result2 = await llm_model_func("你的问题2")
431
+
432
+ # 方法2:手动添加token使用记录
433
+ # 适用于需要更精细控制token统计的场景
434
+ token_tracker.reset()
435
+
436
+ rag.insert()
437
+
438
+ rag.query("你的问题1", param=QueryParam(mode="naive"))
439
+ rag.query("你的问题2", param=QueryParam(mode="mix"))
440
+
441
+ # 显示总token使用量(包含插入和查询操作)
442
+ print("Token usage:", token_tracker.get_usage())
443
+ ```
444
+
445
+ #### 使用建议
446
+ - 在长会话或批量操作中使用上下文管理器,可以自动跟踪所有token消耗
447
+ - 对于需要分段统计的场景,使用手动模式并适时调用reset()
448
+ - 定期检查token使用情况,有助于及时发现异常消耗
449
+ - 在开发测试阶段积极使用此功能,以便优化生产环境的成本
450
+
451
+ #### 实际应用示例
452
+ 您可以参考以下示例来实现token统计:
453
+ - `examples/lightrag_gemini_track_token_demo.py`:使用Google Gemini模型的token统计示例
454
+ - `examples/lightrag_siliconcloud_track_token_demo.py`:使用SiliconCloud模型的token统计示例
455
+
456
+ 这些示例展示了如何在不同模型和场景下有效地使用TokenTracker功能。
457
+
458
+ </details>
459
+
460
  ### 对话历史
461
 
462
  LightRAG现在通过对话历史功能支持多轮对话。以下是使用方法:
 
1084
  | **参数** | **类型** | **说明** | **默认值** |
1085
  |--------------|----------|-----------------|-------------|
1086
  | **working_dir** | `str` | 存储缓存的目录 | `lightrag_cache+timestamp` |
1087
+ | **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
1088
+ | **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
1089
+ | **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` |
1090
+ | **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
1091
  | **chunk_token_size** | `int` | 拆分文档时每个块的最大令牌大小 | `1200` |
1092
  | **chunk_overlap_token_size** | `int` | 拆分文档时两个块之间的重叠令牌大小 | `100` |
1093
  | **tiktoken_model_name** | `str` | 用于计算令牌数的Tiktoken编码器的模型名称 | `gpt-4o-mini` |
README.md CHANGED
@@ -41,7 +41,6 @@
41
  - [X] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
42
  - [X] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise.
43
  - [X] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author.
44
- - [X] [2024.11.12]🎯📢LightRAG now supports [Oracle Database 23ai for all storage types (KV, vector, and graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py).
45
  - [X] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
46
  - [X] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which allows you to insert, query, visualize, and download LightRAG knowledge.
47
  - [X] [2024.11.04]🎯📢You can now [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage).
@@ -443,6 +442,55 @@ if __name__ == "__main__":
443
 
444
  </details>
445
 
446
  ### Conversation History Support
447
 
448
 
@@ -607,7 +655,7 @@ The `apipeline_enqueue_documents` and `apipeline_process_enqueue_documents` func
607
 
608
  This is useful for scenarios where you want to process documents in the background while still allowing the main thread to continue executing.
609
 
610
- And using a routine to process news documents.
611
 
612
  ```python
613
  rag = LightRAG(..)
@@ -1096,9 +1144,10 @@ Valid modes are:
1096
  | **Parameter** | **Type** | **Explanation** | **Default** |
1097
  |--------------|----------|-----------------|-------------|
1098
  | **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
1099
- | **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` |
1100
- | **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` |
1101
- | **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` |
 
1102
  | **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
1103
  | **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
1104
  | **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
 
41
  - [X] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
42
  - [X] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise.
43
  - [X] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author.
 
44
  - [X] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
45
  - [X] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which allows you to insert, query, visualize, and download LightRAG knowledge.
46
  - [X] [2024.11.04]🎯📢You can now [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage).
 
442
 
443
  </details>
444
 
445
+ ### Token Usage Tracking
446
+
447
+ <details>
448
+ <summary> <b>Overview and Usage</b> </summary>
449
+
450
+ LightRAG provides a TokenTracker tool to monitor and manage token consumption by large language models. This feature is particularly useful for controlling API costs and optimizing performance.
451
+
452
+ #### Usage
453
+
454
+ ```python
455
+ from lightrag.utils import TokenTracker
456
+
457
+ # Create TokenTracker instance
458
+ token_tracker = TokenTracker()
459
+
460
+ # Method 1: Using context manager (Recommended)
461
+ # Suitable for scenarios requiring automatic token usage tracking
462
+ with token_tracker:
463
+ result1 = await llm_model_func("your question 1")
464
+ result2 = await llm_model_func("your question 2")
465
+
466
+ # Method 2: Manually adding token usage records
467
+ # Suitable for scenarios requiring more granular control over token statistics
468
+ token_tracker.reset()
469
+
470
+ rag.insert()
471
+
472
+ rag.query("your question 1", param=QueryParam(mode="naive"))
473
+ rag.query("your question 2", param=QueryParam(mode="mix"))
474
+
475
+ # Display total token usage (including insert and query operations)
476
+ print("Token usage:", token_tracker.get_usage())
477
+ ```
478
+
479
+ #### Usage Tips
480
+ - Use context managers for long sessions or batch operations to automatically track all token consumption
481
+ - For scenarios requiring segmented statistics, use manual mode and call reset() when appropriate
482
+ - Regular checking of token usage helps detect abnormal consumption early
483
+ - Actively use this feature during development and testing to optimize production costs
484
+
485
+ #### Practical Examples
486
+ You can refer to these examples for implementing token tracking:
487
+ - `examples/lightrag_gemini_track_token_demo.py`: Token tracking example using Google Gemini model
488
+ - `examples/lightrag_siliconcloud_track_token_demo.py`: Token tracking example using SiliconCloud model
489
+
490
+ These examples demonstrate how to effectively use the TokenTracker feature with different models and scenarios.
491
+
492
+ </details>
493
+
494
  ### Conversation History Support
495
 
496
 
 
655
 
656
  This is useful for scenarios where you want to process documents in the background while still allowing the main thread to continue executing.
657
 
658
+ And using a routine to process new documents.
659
 
660
  ```python
661
  rag = LightRAG(..)
 
1144
  | **Parameter** | **Type** | **Explanation** | **Default** |
1145
  |--------------|----------|-----------------|-------------|
1146
  | **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
1147
+ | **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
1148
+ | **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
1149
+ | **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` |
1150
+ | **doc_status_storage** | `str` | Storage type for documents process status. Supported types: `JsonDocStatusStorage`,`PGDocStatusStorage`,`MongoDocStatusStorage` | `JsonDocStatusStorage` |
1151
  | **chunk_token_size** | `int` | Maximum token size per chunk when splitting documents | `1200` |
1152
  | **chunk_overlap_token_size** | `int` | Overlap token size between two chunks when splitting documents | `100` |
1153
  | **tiktoken_model_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` |
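
Below is a minimal sketch (not part of this diff) of how the four storage parameters in the updated table above could be passed to `LightRAG`, reusing helpers that appear elsewhere in this commit (`EmbeddingFunc`, `openai_complete_if_cache`, `openai_embed`, `initialize_storages`, `initialize_pipeline_status`). The model names, the 3072 embedding dimension, and the assumption that PostgreSQL is reachable through the `POSTGRES_*` variables from `env.example` are placeholders, not values taken from the commit.

```python
import asyncio

import numpy as np

from lightrag import LightRAG
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc


async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    # Placeholder model name; api_key/base_url fall back to the OpenAI defaults.
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


async def embedding_func(texts: list[str]) -> np.ndarray:
    # Placeholder embedding model; 3072 dimensions matches text-embedding-3-large.
    return await openai_embed(texts, model="text-embedding-3-large")


async def initialize_rag() -> LightRAG:
    rag = LightRAG(
        working_dir="./rag_storage",
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=3072, max_token_size=8192, func=embedding_func
        ),
        # Storage selections from the table above; PostgreSQL connection details
        # are assumed to come from the POSTGRES_* variables shown in env.example.
        kv_storage="PGKVStorage",
        vector_storage="PGVectorStorage",
        graph_storage="PGGraphStorage",
        doc_status_storage="PGDocStatusStorage",
    )
    await rag.initialize_storages()
    await initialize_pipeline_status()
    return rag


if __name__ == "__main__":
    rag = asyncio.run(initialize_rag())
    print("storages initialized")
```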
config.ini.example CHANGED
@@ -13,23 +13,6 @@ uri=redis://localhost:6379/1
13
  [qdrant]
14
  uri = http://localhost:16333
15
 
16
- [oracle]
17
- dsn = localhost:1521/XEPDB1
18
- user = your_username
19
- password = your_password
20
- config_dir = /path/to/oracle/config
21
- wallet_location = /path/to/wallet # 可选
22
- wallet_password = your_wallet_password # 可选
23
- workspace = default # 可选,默认为default
24
-
25
- [tidb]
26
- host = localhost
27
- port = 4000
28
- user = your_username
29
- password = your_password
30
- database = your_database
31
- workspace = default # 可选,默认为default
32
-
33
  [postgres]
34
  host = localhost
35
  port = 5432
 
13
  [qdrant]
14
  uri = http://localhost:16333
15
 
16
  [postgres]
17
  host = localhost
18
  port = 5432
env.example CHANGED
@@ -4,11 +4,9 @@
4
  # HOST=0.0.0.0
5
  # PORT=9621
6
  # WORKERS=2
7
- ### separating data from difference Lightrag instances
8
- # NAMESPACE_PREFIX=lightrag
9
- ### Max nodes return from grap retrieval
10
- # MAX_GRAPH_NODES=1000
11
  # CORS_ORIGINS=http://localhost:3000,http://localhost:8080
 
 
12
 
13
  ### Optional SSL Configuration
14
  # SSL=true
@@ -22,6 +20,9 @@
22
  ### Ollama Emulating Model Tag
23
  # OLLAMA_EMULATING_MODEL_TAG=latest
24
 
25
  ### Logging level
26
  # LOG_LEVEL=INFO
27
  # VERBOSE=False
@@ -110,24 +111,14 @@ LIGHTRAG_VECTOR_STORAGE=NanoVectorDBStorage
110
  LIGHTRAG_GRAPH_STORAGE=NetworkXStorage
111
  LIGHTRAG_DOC_STATUS_STORAGE=JsonDocStatusStorage
112
 
113
- ### Oracle Database Configuration
114
- ORACLE_DSN=localhost:1521/XEPDB1
115
- ORACLE_USER=your_username
116
- ORACLE_PASSWORD='your_password'
117
- ORACLE_CONFIG_DIR=/path/to/oracle/config
118
- #ORACLE_WALLET_LOCATION=/path/to/wallet
119
- #ORACLE_WALLET_PASSWORD='your_password'
120
- ### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future)
121
- #ORACLE_WORKSPACE=default
122
-
123
- ### TiDB Configuration
124
- TIDB_HOST=localhost
125
- TIDB_PORT=4000
126
- TIDB_USER=your_username
127
- TIDB_PASSWORD='your_password'
128
- TIDB_DATABASE=your_database
129
- ### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future)
130
- #TIDB_WORKSPACE=default
131
 
132
  ### PostgreSQL Configuration
133
  POSTGRES_HOST=localhost
@@ -135,8 +126,8 @@ POSTGRES_PORT=5432
135
  POSTGRES_USER=your_username
136
  POSTGRES_PASSWORD='your_password'
137
  POSTGRES_DATABASE=your_database
138
- ### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future)
139
- #POSTGRES_WORKSPACE=default
140
 
141
  ### Independent AGM Configuration(not for AMG embedded in PostreSQL)
142
  AGE_POSTGRES_DB=
@@ -145,8 +136,8 @@ AGE_POSTGRES_PASSWORD=
145
  AGE_POSTGRES_HOST=
146
  # AGE_POSTGRES_PORT=8529
147
 
148
- ### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future)
149
  # AGE Graph Name(apply to PostgreSQL and independent AGM)
 
150
  # AGE_GRAPH_NAME=lightrag
151
 
152
  ### Neo4j Configuration
@@ -157,7 +148,7 @@ NEO4J_PASSWORD='your_password'
157
  ### MongoDB Configuration
158
  MONGO_URI=mongodb://root:root@localhost:27017/
159
  MONGO_DATABASE=LightRAG
160
- ### separating all data from difference Lightrag instances(deprecating, use NAMESPACE_PREFIX in future)
161
  # MONGODB_GRAPH=false
162
 
163
  ### Milvus Configuration
@@ -177,7 +168,9 @@ REDIS_URI=redis://localhost:6379
177
  ### For JWT Auth
178
  # AUTH_ACCOUNTS='admin:admin123,user1:pass456'
179
  # TOKEN_SECRET=Your-Key-For-LightRAG-API-Server
180
- # TOKEN_EXPIRE_HOURS=4
 
 
181
 
182
  ### API-Key to access LightRAG Server API
183
  # LIGHTRAG_API_KEY=your-secure-api-key-here
 
4
  # HOST=0.0.0.0
5
  # PORT=9621
6
  # WORKERS=2
7
  # CORS_ORIGINS=http://localhost:3000,http://localhost:8080
8
+ WEBUI_TITLE='Graph RAG Engine'
9
+ WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
10
 
11
  ### Optional SSL Configuration
12
  # SSL=true
 
20
  ### Ollama Emulating Model Tag
21
  # OLLAMA_EMULATING_MODEL_TAG=latest
22
 
23
+ ### Max nodes returned from graph retrieval
24
+ # MAX_GRAPH_NODES=1000
25
+
26
  ### Logging level
27
  # LOG_LEVEL=INFO
28
  # VERBOSE=False
 
111
  LIGHTRAG_GRAPH_STORAGE=NetworkXStorage
112
  LIGHTRAG_DOC_STATUS_STORAGE=JsonDocStatusStorage
113
 
114
+ ### TiDB Configuration (Deprecated)
115
+ # TIDB_HOST=localhost
116
+ # TIDB_PORT=4000
117
+ # TIDB_USER=your_username
118
+ # TIDB_PASSWORD='your_password'
119
+ # TIDB_DATABASE=your_database
120
+ ### Separating all data from different LightRAG instances (being deprecated)
121
+ # TIDB_WORKSPACE=default
122
 
123
  ### PostgreSQL Configuration
124
  POSTGRES_HOST=localhost
 
126
  POSTGRES_USER=your_username
127
  POSTGRES_PASSWORD='your_password'
128
  POSTGRES_DATABASE=your_database
129
+ ### Separating all data from different LightRAG instances (being deprecated)
130
+ # POSTGRES_WORKSPACE=default
131
 
132
  ### Independent AGM Configuration(not for AMG embedded in PostreSQL)
133
  AGE_POSTGRES_DB=
 
136
  AGE_POSTGRES_HOST=
137
  # AGE_POSTGRES_PORT=8529
138
 
 
139
  # AGE Graph Name(apply to PostgreSQL and independent AGM)
140
+ ### AGE_GRAPH_NAME is deprecated
141
  # AGE_GRAPH_NAME=lightrag
142
 
143
  ### Neo4j Configuration
 
148
  ### MongoDB Configuration
149
  MONGO_URI=mongodb://root:root@localhost:27017/
150
  MONGO_DATABASE=LightRAG
151
+ ### Separating all data from different LightRAG instances (being deprecated)
152
  # MONGODB_GRAPH=false
153
 
154
  ### Milvus Configuration
 
168
  ### For JWT Auth
169
  # AUTH_ACCOUNTS='admin:admin123,user1:pass456'
170
  # TOKEN_SECRET=Your-Key-For-LightRAG-API-Server
171
+ # TOKEN_EXPIRE_HOURS=48
172
+ # GUEST_TOKEN_EXPIRE_HOURS=24
173
+ # JWT_ALGORITHM=HS256
174
 
175
  ### API-Key to access LightRAG Server API
176
  # LIGHTRAG_API_KEY=your-secure-api-key-here
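
The JWT settings added above (`TOKEN_SECRET`, `TOKEN_EXPIRE_HOURS`, `GUEST_TOKEN_EXPIRE_HOURS`, `JWT_ALGORITHM`) are consumed by `lightrag/api/auth.py`, which imports PyJWT as `jwt` (see its diff further below). The following is a small sketch of what those values mean in PyJWT terms, using the placeholder secret and the sample `admin` account from this example file rather than real credentials.

```python
from datetime import datetime, timedelta, timezone

import jwt  # PyJWT, the library imported by lightrag/api/auth.py

# Placeholder values mirroring the commented-out entries above.
TOKEN_SECRET = "Your-Key-For-LightRAG-API-Server"
JWT_ALGORITHM = "HS256"
TOKEN_EXPIRE_HOURS = 48

# Issue a token for an account such as the 'admin' entry in AUTH_ACCOUNTS.
expire = datetime.now(timezone.utc) + timedelta(hours=TOKEN_EXPIRE_HOURS)
token = jwt.encode({"sub": "admin", "exp": expire}, TOKEN_SECRET, algorithm=JWT_ALGORITHM)

# Validation: jwt.decode raises ExpiredSignatureError once "exp" has passed.
payload = jwt.decode(token, TOKEN_SECRET, algorithms=[JWT_ALGORITHM])
print(payload["sub"])  # -> admin
```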
examples/lightrag_api_ollama_demo.py DELETED
@@ -1,188 +0,0 @@
1
- from fastapi import FastAPI, HTTPException, File, UploadFile
2
- from contextlib import asynccontextmanager
3
- from pydantic import BaseModel
4
- import os
5
- from lightrag import LightRAG, QueryParam
6
- from lightrag.llm.ollama import ollama_embed, ollama_model_complete
7
- from lightrag.utils import EmbeddingFunc
8
- from typing import Optional
9
- import asyncio
10
- import nest_asyncio
11
- import aiofiles
12
- from lightrag.kg.shared_storage import initialize_pipeline_status
13
-
14
- # Apply nest_asyncio to solve event loop issues
15
- nest_asyncio.apply()
16
-
17
- DEFAULT_RAG_DIR = "index_default"
18
-
19
- DEFAULT_INPUT_FILE = "book.txt"
20
- INPUT_FILE = os.environ.get("INPUT_FILE", f"{DEFAULT_INPUT_FILE}")
21
- print(f"INPUT_FILE: {INPUT_FILE}")
22
-
23
- # Configure working directory
24
- WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
25
- print(f"WORKING_DIR: {WORKING_DIR}")
26
-
27
-
28
- if not os.path.exists(WORKING_DIR):
29
- os.mkdir(WORKING_DIR)
30
-
31
-
32
- async def init():
33
- rag = LightRAG(
34
- working_dir=WORKING_DIR,
35
- llm_model_func=ollama_model_complete,
36
- llm_model_name="gemma2:9b",
37
- llm_model_max_async=4,
38
- llm_model_max_token_size=8192,
39
- llm_model_kwargs={
40
- "host": "http://localhost:11434",
41
- "options": {"num_ctx": 8192},
42
- },
43
- embedding_func=EmbeddingFunc(
44
- embedding_dim=768,
45
- max_token_size=8192,
46
- func=lambda texts: ollama_embed(
47
- texts, embed_model="nomic-embed-text", host="http://localhost:11434"
48
- ),
49
- ),
50
- )
51
-
52
- # Add initialization code
53
- await rag.initialize_storages()
54
- await initialize_pipeline_status()
55
-
56
- return rag
57
-
58
-
59
- @asynccontextmanager
60
- async def lifespan(app: FastAPI):
61
- global rag
62
- rag = await init()
63
- print("done!")
64
- yield
65
-
66
-
67
- app = FastAPI(
68
- title="LightRAG API", description="API for RAG operations", lifespan=lifespan
69
- )
70
-
71
-
72
- # Data models
73
- class QueryRequest(BaseModel):
74
- query: str
75
- mode: str = "hybrid"
76
- only_need_context: bool = False
77
-
78
-
79
- class InsertRequest(BaseModel):
80
- text: str
81
-
82
-
83
- class Response(BaseModel):
84
- status: str
85
- data: Optional[str] = None
86
- message: Optional[str] = None
87
-
88
-
89
- # API routes
90
- @app.post("/query", response_model=Response)
91
- async def query_endpoint(request: QueryRequest):
92
- try:
93
- loop = asyncio.get_event_loop()
94
- result = await loop.run_in_executor(
95
- None,
96
- lambda: rag.query(
97
- request.query,
98
- param=QueryParam(
99
- mode=request.mode, only_need_context=request.only_need_context
100
- ),
101
- ),
102
- )
103
- return Response(status="success", data=result)
104
- except Exception as e:
105
- raise HTTPException(status_code=500, detail=str(e))
106
-
107
-
108
- # insert by text
109
- @app.post("/insert", response_model=Response)
110
- async def insert_endpoint(request: InsertRequest):
111
- try:
112
- loop = asyncio.get_event_loop()
113
- await loop.run_in_executor(None, lambda: rag.insert(request.text))
114
- return Response(status="success", message="Text inserted successfully")
115
- except Exception as e:
116
- raise HTTPException(status_code=500, detail=str(e))
117
-
118
-
119
- # insert by file in payload
120
- @app.post("/insert_file", response_model=Response)
121
- async def insert_file(file: UploadFile = File(...)):
122
- try:
123
- file_content = await file.read()
124
- # Read file content
125
- try:
126
- content = file_content.decode("utf-8")
127
- except UnicodeDecodeError:
128
- # If UTF-8 decoding fails, try other encodings
129
- content = file_content.decode("gbk")
130
- # Insert file content
131
- loop = asyncio.get_event_loop()
132
- await loop.run_in_executor(None, lambda: rag.insert(content))
133
-
134
- return Response(
135
- status="success",
136
- message=f"File content from {file.filename} inserted successfully",
137
- )
138
- except Exception as e:
139
- raise HTTPException(status_code=500, detail=str(e))
140
-
141
-
142
- # insert by local default file
143
- @app.post("/insert_default_file", response_model=Response)
144
- @app.get("/insert_default_file", response_model=Response)
145
- async def insert_default_file():
146
- try:
147
- # Read file content from book.txt
148
- async with aiofiles.open(INPUT_FILE, "r", encoding="utf-8") as file:
149
- content = await file.read()
150
- print(f"read input file {INPUT_FILE} successfully")
151
- # Insert file content
152
- loop = asyncio.get_event_loop()
153
- await loop.run_in_executor(None, lambda: rag.insert(content))
154
-
155
- return Response(
156
- status="success",
157
- message=f"File content from {INPUT_FILE} inserted successfully",
158
- )
159
- except Exception as e:
160
- raise HTTPException(status_code=500, detail=str(e))
161
-
162
-
163
- @app.get("/health")
164
- async def health_check():
165
- return {"status": "healthy"}
166
-
167
-
168
- if __name__ == "__main__":
169
- import uvicorn
170
-
171
- uvicorn.run(app, host="0.0.0.0", port=8020)
172
-
173
- # Usage example
174
- # To run the server, use the following command in your terminal:
175
- # python lightrag_api_openai_compatible_demo.py
176
-
177
- # Example requests:
178
- # 1. Query:
179
- # curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
180
-
181
- # 2. Insert text:
182
- # curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
183
-
184
- # 3. Insert file:
185
- # curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
186
-
187
- # 4. Health check:
188
- # curl -X GET "http://127.0.0.1:8020/health"
examples/lightrag_api_openai_compatible_demo.py DELETED
@@ -1,204 +0,0 @@
1
- from fastapi import FastAPI, HTTPException, File, UploadFile
2
- from contextlib import asynccontextmanager
3
- from pydantic import BaseModel
4
- import os
5
- from lightrag import LightRAG, QueryParam
6
- from lightrag.llm.openai import openai_complete_if_cache, openai_embed
7
- from lightrag.utils import EmbeddingFunc
8
- import numpy as np
9
- from typing import Optional
10
- import asyncio
11
- import nest_asyncio
12
- from lightrag.kg.shared_storage import initialize_pipeline_status
13
-
14
- # Apply nest_asyncio to solve event loop issues
15
- nest_asyncio.apply()
16
-
17
- DEFAULT_RAG_DIR = "index_default"
18
- app = FastAPI(title="LightRAG API", description="API for RAG operations")
19
-
20
- # Configure working directory
21
- WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
22
- print(f"WORKING_DIR: {WORKING_DIR}")
23
- LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini")
24
- print(f"LLM_MODEL: {LLM_MODEL}")
25
- EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
26
- print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
27
- EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
28
- print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
29
- BASE_URL = os.environ.get("BASE_URL", "https://api.openai.com/v1")
30
- print(f"BASE_URL: {BASE_URL}")
31
- API_KEY = os.environ.get("API_KEY", "xxxxxxxx")
32
- print(f"API_KEY: {API_KEY}")
33
-
34
- if not os.path.exists(WORKING_DIR):
35
- os.mkdir(WORKING_DIR)
36
-
37
-
38
- # LLM model function
39
-
40
-
41
- async def llm_model_func(
42
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
43
- ) -> str:
44
- return await openai_complete_if_cache(
45
- model=LLM_MODEL,
46
- prompt=prompt,
47
- system_prompt=system_prompt,
48
- history_messages=history_messages,
49
- base_url=BASE_URL,
50
- api_key=API_KEY,
51
- **kwargs,
52
- )
53
-
54
-
55
- # Embedding function
56
-
57
-
58
- async def embedding_func(texts: list[str]) -> np.ndarray:
59
- return await openai_embed(
60
- texts=texts,
61
- model=EMBEDDING_MODEL,
62
- base_url=BASE_URL,
63
- api_key=API_KEY,
64
- )
65
-
66
-
67
- async def get_embedding_dim():
68
- test_text = ["This is a test sentence."]
69
- embedding = await embedding_func(test_text)
70
- embedding_dim = embedding.shape[1]
71
- print(f"{embedding_dim=}")
72
- return embedding_dim
73
-
74
-
75
- # Initialize RAG instance
76
- async def init():
77
- embedding_dimension = await get_embedding_dim()
78
-
79
- rag = LightRAG(
80
- working_dir=WORKING_DIR,
81
- llm_model_func=llm_model_func,
82
- embedding_func=EmbeddingFunc(
83
- embedding_dim=embedding_dimension,
84
- max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
85
- func=embedding_func,
86
- ),
87
- )
88
-
89
- await rag.initialize_storages()
90
- await initialize_pipeline_status()
91
-
92
- return rag
93
-
94
-
95
- @asynccontextmanager
96
- async def lifespan(app: FastAPI):
97
- global rag
98
- rag = await init()
99
- print("done!")
100
- yield
101
-
102
-
103
- app = FastAPI(
104
- title="LightRAG API", description="API for RAG operations", lifespan=lifespan
105
- )
106
-
107
- # Data models
108
-
109
-
110
- class QueryRequest(BaseModel):
111
- query: str
112
- mode: str = "hybrid"
113
- only_need_context: bool = False
114
-
115
-
116
- class InsertRequest(BaseModel):
117
- text: str
118
-
119
-
120
- class Response(BaseModel):
121
- status: str
122
- data: Optional[str] = None
123
- message: Optional[str] = None
124
-
125
-
126
- # API routes
127
-
128
-
129
- @app.post("/query", response_model=Response)
130
- async def query_endpoint(request: QueryRequest):
131
- try:
132
- loop = asyncio.get_event_loop()
133
- result = await loop.run_in_executor(
134
- None,
135
- lambda: rag.query(
136
- request.query,
137
- param=QueryParam(
138
- mode=request.mode, only_need_context=request.only_need_context
139
- ),
140
- ),
141
- )
142
- return Response(status="success", data=result)
143
- except Exception as e:
144
- raise HTTPException(status_code=500, detail=str(e))
145
-
146
-
147
- @app.post("/insert", response_model=Response)
148
- async def insert_endpoint(request: InsertRequest):
149
- try:
150
- loop = asyncio.get_event_loop()
151
- await loop.run_in_executor(None, lambda: rag.insert(request.text))
152
- return Response(status="success", message="Text inserted successfully")
153
- except Exception as e:
154
- raise HTTPException(status_code=500, detail=str(e))
155
-
156
-
157
- @app.post("/insert_file", response_model=Response)
158
- async def insert_file(file: UploadFile = File(...)):
159
- try:
160
- file_content = await file.read()
161
- # Read file content
162
- try:
163
- content = file_content.decode("utf-8")
164
- except UnicodeDecodeError:
165
- # If UTF-8 decoding fails, try other encodings
166
- content = file_content.decode("gbk")
167
- # Insert file content
168
- loop = asyncio.get_event_loop()
169
- await loop.run_in_executor(None, lambda: rag.insert(content))
170
-
171
- return Response(
172
- status="success",
173
- message=f"File content from {file.filename} inserted successfully",
174
- )
175
- except Exception as e:
176
- raise HTTPException(status_code=500, detail=str(e))
177
-
178
-
179
- @app.get("/health")
180
- async def health_check():
181
- return {"status": "healthy"}
182
-
183
-
184
- if __name__ == "__main__":
185
- import uvicorn
186
-
187
- uvicorn.run(app, host="0.0.0.0", port=8020)
188
-
189
- # Usage example
190
- # To run the server, use the following command in your terminal:
191
- # python lightrag_api_openai_compatible_demo.py
192
-
193
- # Example requests:
194
- # 1. Query:
195
- # curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
196
-
197
- # 2. Insert text:
198
- # curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
199
-
200
- # 3. Insert file:
201
- # curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
202
-
203
- # 4. Health check:
204
- # curl -X GET "http://127.0.0.1:8020/health"
examples/lightrag_api_oracle_demo.py DELETED
@@ -1,267 +0,0 @@
1
- from fastapi import FastAPI, HTTPException, File, UploadFile
2
- from fastapi import Query
3
- from contextlib import asynccontextmanager
4
- from pydantic import BaseModel
5
- from typing import Optional, Any
6
-
7
- import sys
8
- import os
9
-
10
-
11
- from pathlib import Path
12
-
13
- import asyncio
14
- import nest_asyncio
15
- from lightrag import LightRAG, QueryParam
16
- from lightrag.llm.openai import openai_complete_if_cache, openai_embed
17
- from lightrag.utils import EmbeddingFunc
18
- import numpy as np
19
- from lightrag.kg.shared_storage import initialize_pipeline_status
20
-
21
-
22
- print(os.getcwd())
23
- script_directory = Path(__file__).resolve().parent.parent
24
- sys.path.append(os.path.abspath(script_directory))
25
-
26
-
27
- # Apply nest_asyncio to solve event loop issues
28
- nest_asyncio.apply()
29
-
30
- DEFAULT_RAG_DIR = "index_default"
31
-
32
-
33
- # We use OpenAI compatible API to call LLM on Oracle Cloud
34
- # More docs here https://github.com/jin38324/OCI_GenAI_access_gateway
35
- BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
36
- APIKEY = "ocigenerativeai"
37
-
38
- # Configure working directory
39
- WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
40
- print(f"WORKING_DIR: {WORKING_DIR}")
41
- LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus-08-2024")
42
- print(f"LLM_MODEL: {LLM_MODEL}")
43
- EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "cohere.embed-multilingual-v3.0")
44
- print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
45
- EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 512))
46
- print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
47
-
48
- if not os.path.exists(WORKING_DIR):
49
- os.mkdir(WORKING_DIR)
50
-
51
- os.environ["ORACLE_USER"] = ""
52
- os.environ["ORACLE_PASSWORD"] = ""
53
- os.environ["ORACLE_DSN"] = ""
54
- os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
55
- os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
56
- os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
57
- os.environ["ORACLE_WORKSPACE"] = "company"
58
-
59
-
60
- async def llm_model_func(
61
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
62
- ) -> str:
63
- return await openai_complete_if_cache(
64
- LLM_MODEL,
65
- prompt,
66
- system_prompt=system_prompt,
67
- history_messages=history_messages,
68
- api_key=APIKEY,
69
- base_url=BASE_URL,
70
- **kwargs,
71
- )
72
-
73
-
74
- async def embedding_func(texts: list[str]) -> np.ndarray:
75
- return await openai_embed(
76
- texts,
77
- model=EMBEDDING_MODEL,
78
- api_key=APIKEY,
79
- base_url=BASE_URL,
80
- )
81
-
82
-
83
- async def get_embedding_dim():
84
- test_text = ["This is a test sentence."]
85
- embedding = await embedding_func(test_text)
86
- embedding_dim = embedding.shape[1]
87
- return embedding_dim
88
-
89
-
90
- async def init():
91
- # Detect embedding dimension
92
- embedding_dimension = await get_embedding_dim()
93
- print(f"Detected embedding dimension: {embedding_dimension}")
94
- # Create Oracle DB connection
95
- # The `config` parameter is the connection configuration of Oracle DB
96
- # More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
97
- # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
98
- # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
99
-
100
- # Initialize LightRAG
101
- # We use Oracle DB as the KV/vector/graph storage
102
- rag = LightRAG(
103
- enable_llm_cache=False,
104
- working_dir=WORKING_DIR,
105
- chunk_token_size=512,
106
- llm_model_func=llm_model_func,
107
- embedding_func=EmbeddingFunc(
108
- embedding_dim=embedding_dimension,
109
- max_token_size=512,
110
- func=embedding_func,
111
- ),
112
- graph_storage="OracleGraphStorage",
113
- kv_storage="OracleKVStorage",
114
- vector_storage="OracleVectorDBStorage",
115
- )
116
-
117
- await rag.initialize_storages()
118
- await initialize_pipeline_status()
119
-
120
- return rag
121
-
122
-
123
- # Extract and Insert into LightRAG storage
124
- # with open("./dickens/book.txt", "r", encoding="utf-8") as f:
125
- # await rag.ainsert(f.read())
126
-
127
- # # Perform search in different modes
128
- # modes = ["naive", "local", "global", "hybrid"]
129
- # for mode in modes:
130
- # print("="*20, mode, "="*20)
131
- # print(await rag.aquery("这篇文档是关于什么内容的?", param=QueryParam(mode=mode)))
132
- # print("-"*100, "\n")
133
-
134
- # Data models
135
-
136
-
137
- class QueryRequest(BaseModel):
138
- query: str
139
- mode: str = "hybrid"
140
- only_need_context: bool = False
141
- only_need_prompt: bool = False
142
-
143
-
144
- class DataRequest(BaseModel):
145
- limit: int = 100
146
-
147
-
148
- class InsertRequest(BaseModel):
149
- text: str
150
-
151
-
152
- class Response(BaseModel):
153
- status: str
154
- data: Optional[Any] = None
155
- message: Optional[str] = None
156
-
157
-
158
- # API routes
159
-
160
- rag = None
161
-
162
-
163
- @asynccontextmanager
164
- async def lifespan(app: FastAPI):
165
- global rag
166
- rag = await init()
167
- print("done!")
168
- yield
169
-
170
-
171
- app = FastAPI(
172
- title="LightRAG API", description="API for RAG operations", lifespan=lifespan
173
- )
174
-
175
-
176
- @app.post("/query", response_model=Response)
177
- async def query_endpoint(request: QueryRequest):
178
- # try:
179
- # loop = asyncio.get_event_loop()
180
- if request.mode == "naive":
181
- top_k = 3
182
- else:
183
- top_k = 60
184
- result = await rag.aquery(
185
- request.query,
186
- param=QueryParam(
187
- mode=request.mode,
188
- only_need_context=request.only_need_context,
189
- only_need_prompt=request.only_need_prompt,
190
- top_k=top_k,
191
- ),
192
- )
193
- return Response(status="success", data=result)
194
- # except Exception as e:
195
- # raise HTTPException(status_code=500, detail=str(e))
196
-
197
-
198
- @app.get("/data", response_model=Response)
199
- async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)):
200
- if type == "nodes":
201
- result = await rag.chunk_entity_relation_graph.get_all_nodes(limit=limit)
202
- elif type == "edges":
203
- result = await rag.chunk_entity_relation_graph.get_all_edges(limit=limit)
204
- elif type == "statistics":
205
- result = await rag.chunk_entity_relation_graph.get_statistics()
206
- return Response(status="success", data=result)
207
-
208
-
209
- @app.post("/insert", response_model=Response)
210
- async def insert_endpoint(request: InsertRequest):
211
- try:
212
- loop = asyncio.get_event_loop()
213
- await loop.run_in_executor(None, lambda: rag.insert(request.text))
214
- return Response(status="success", message="Text inserted successfully")
215
- except Exception as e:
216
- raise HTTPException(status_code=500, detail=str(e))
217
-
218
-
219
- @app.post("/insert_file", response_model=Response)
220
- async def insert_file(file: UploadFile = File(...)):
221
- try:
222
- file_content = await file.read()
223
- # Read file content
224
- try:
225
- content = file_content.decode("utf-8")
226
- except UnicodeDecodeError:
227
- # If UTF-8 decoding fails, try other encodings
228
- content = file_content.decode("gbk")
229
- # Insert file content
230
- loop = asyncio.get_event_loop()
231
- await loop.run_in_executor(None, lambda: rag.insert(content))
232
-
233
- return Response(
234
- status="success",
235
- message=f"File content from {file.filename} inserted successfully",
236
- )
237
- except Exception as e:
238
- raise HTTPException(status_code=500, detail=str(e))
239
-
240
-
241
- @app.get("/health")
242
- async def health_check():
243
- return {"status": "healthy"}
244
-
245
-
246
- if __name__ == "__main__":
247
- import uvicorn
248
-
249
- uvicorn.run(app, host="127.0.0.1", port=8020)
250
-
251
- # Usage example
252
- # To run the server, use the following command in your terminal:
253
- # python lightrag_api_openai_compatible_demo.py
254
-
255
- # Example requests:
256
- # 1. Query:
257
- # curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
258
-
259
- # 2. Insert text:
260
- # curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
261
-
262
- # 3. Insert file:
263
- # curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: multipart/form-data" -F "file=@path/to/your/file.txt"
264
-
265
-
266
- # 4. Health check:
267
- # curl -X GET "http://127.0.0.1:8020/health"
examples/lightrag_ollama_gremlin_demo.py CHANGED
@@ -1,3 +1,7 @@
1
  import asyncio
2
  import inspect
3
  import os
 
1
+ ##############################################
2
+ # Gremlin storage implementation is deprecated
3
+ ##############################################
4
+
5
  import asyncio
6
  import inspect
7
  import os
examples/lightrag_oracle_demo.py DELETED
@@ -1,141 +0,0 @@
1
- import sys
2
- import os
3
- from pathlib import Path
4
- import asyncio
5
- from lightrag import LightRAG, QueryParam
6
- from lightrag.llm.openai import openai_complete_if_cache, openai_embed
7
- from lightrag.utils import EmbeddingFunc
8
- import numpy as np
9
- from lightrag.kg.shared_storage import initialize_pipeline_status
10
-
11
- print(os.getcwd())
12
- script_directory = Path(__file__).resolve().parent.parent
13
- sys.path.append(os.path.abspath(script_directory))
14
-
15
- WORKING_DIR = "./dickens"
16
-
17
- # We use OpenAI compatible API to call LLM on Oracle Cloud
18
- # More docs here https://github.com/jin38324/OCI_GenAI_access_gateway
19
- BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
20
- APIKEY = "ocigenerativeai"
21
- CHATMODEL = "cohere.command-r-plus"
22
- EMBEDMODEL = "cohere.embed-multilingual-v3.0"
23
- CHUNK_TOKEN_SIZE = 1024
24
- MAX_TOKENS = 4000
25
-
26
- if not os.path.exists(WORKING_DIR):
27
- os.mkdir(WORKING_DIR)
28
-
29
- os.environ["ORACLE_USER"] = "username"
30
- os.environ["ORACLE_PASSWORD"] = "xxxxxxxxx"
31
- os.environ["ORACLE_DSN"] = "xxxxxxx_medium"
32
- os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
33
- os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
34
- os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
35
- os.environ["ORACLE_WORKSPACE"] = "company"
36
-
37
-
38
- async def llm_model_func(
39
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
40
- ) -> str:
41
- return await openai_complete_if_cache(
42
- CHATMODEL,
43
- prompt,
44
- system_prompt=system_prompt,
45
- history_messages=history_messages,
46
- api_key=APIKEY,
47
- base_url=BASE_URL,
48
- **kwargs,
49
- )
50
-
51
-
52
- async def embedding_func(texts: list[str]) -> np.ndarray:
53
- return await openai_embed(
54
- texts,
55
- model=EMBEDMODEL,
56
- api_key=APIKEY,
57
- base_url=BASE_URL,
58
- )
59
-
60
-
61
- async def get_embedding_dim():
62
- test_text = ["This is a test sentence."]
63
- embedding = await embedding_func(test_text)
64
- embedding_dim = embedding.shape[1]
65
- return embedding_dim
66
-
67
-
68
- async def initialize_rag():
69
- # Detect embedding dimension
70
- embedding_dimension = await get_embedding_dim()
71
- print(f"Detected embedding dimension: {embedding_dimension}")
72
-
73
- # Initialize LightRAG
74
- # We use Oracle DB as the KV/vector/graph storage
75
- # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
76
- rag = LightRAG(
77
- # log_level="DEBUG",
78
- working_dir=WORKING_DIR,
79
- entity_extract_max_gleaning=1,
80
- enable_llm_cache=True,
81
- enable_llm_cache_for_entity_extract=True,
82
- embedding_cache_config=None, # {"enabled": True,"similarity_threshold": 0.90},
83
- chunk_token_size=CHUNK_TOKEN_SIZE,
84
- llm_model_max_token_size=MAX_TOKENS,
85
- llm_model_func=llm_model_func,
86
- embedding_func=EmbeddingFunc(
87
- embedding_dim=embedding_dimension,
88
- max_token_size=500,
89
- func=embedding_func,
90
- ),
91
- graph_storage="OracleGraphStorage",
92
- kv_storage="OracleKVStorage",
93
- vector_storage="OracleVectorDBStorage",
94
- addon_params={
95
- "example_number": 1,
96
- "language": "Simplfied Chinese",
97
- "entity_types": ["organization", "person", "geo", "event"],
98
- "insert_batch_size": 2,
99
- },
100
- )
101
- await rag.initialize_storages()
102
- await initialize_pipeline_status()
103
-
104
- return rag
105
-
106
-
107
- async def main():
108
- try:
109
- # Initialize RAG instance
110
- rag = await initialize_rag()
111
-
112
- # Extract and Insert into LightRAG storage
113
- with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
114
- all_text = f.read()
115
- texts = [x for x in all_text.split("\n") if x]
116
-
117
- # New mode use pipeline
118
- await rag.apipeline_enqueue_documents(texts)
119
- await rag.apipeline_process_enqueue_documents()
120
-
121
- # Old method use ainsert
122
- # await rag.ainsert(texts)
123
-
124
- # Perform search in different modes
125
- modes = ["naive", "local", "global", "hybrid"]
126
- for mode in modes:
127
- print("=" * 20, mode, "=" * 20)
128
- print(
129
- await rag.aquery(
130
- "What are the top themes in this story?",
131
- param=QueryParam(mode=mode),
132
- )
133
- )
134
- print("-" * 100, "\n")
135
-
136
- except Exception as e:
137
- print(f"An error occurred: {e}")
138
-
139
-
140
- if __name__ == "__main__":
141
- asyncio.run(main())
examples/lightrag_tidb_demo.py CHANGED
@@ -1,3 +1,7 @@
1
  import asyncio
2
  import os
3
 
 
1
+ ###########################################
2
+ # TiDB storage implementation is deprecated
3
+ ###########################################
4
+
5
  import asyncio
6
  import os
7
 
lightrag/api/README-zh.md CHANGED
@@ -291,11 +291,9 @@ LightRAG 使用 4 种类型的存储用于不同目的:
291
 
292
  ```
293
  JsonKVStorage JsonFile(默认)
294
- MongoKVStorage MogonDB
295
- RedisKVStorage Redis
296
- TiDBKVStorage TiDB
297
  PGKVStorage Postgres
298
- OracleKVStorage Oracle
 
299
  ```
300
 
301
  * GRAPH_STORAGE 支持的实现名称
@@ -303,25 +301,19 @@ OracleKVStorage Oracle
303
  ```
304
  NetworkXStorage NetworkX(默认)
305
  Neo4JStorage Neo4J
306
- MongoGraphStorage MongoDB
307
- TiDBGraphStorage TiDB
308
- AGEStorage AGE
309
- GremlinStorage Gremlin
310
  PGGraphStorage Postgres
311
- OracleGraphStorage Postgres
312
  ```
313
 
314
  * VECTOR_STORAGE 支持的实现名称
315
 
316
  ```
317
  NanoVectorDBStorage NanoVector(默认)
 
318
  MilvusVectorDBStorge Milvus
319
  ChromaVectorDBStorage Chroma
320
- TiDBVectorDBStorage TiDB
321
- PGVectorStorage Postgres
322
  FaissVectorDBStorage Faiss
323
  QdrantVectorDBStorage Qdrant
324
- OracleVectorDBStorage Oracle
325
  MongoVectorDBStorage MongoDB
326
  ```
327
 
 
291
 
292
  ```
293
  JsonKVStorage JsonFile(默认)
294
  PGKVStorage Postgres
295
+ RedisKVStorage Redis
296
+ MongoKVStorage MongoDB
297
  ```
298
 
299
  * GRAPH_STORAGE 支持的实现名称
 
301
  ```
302
  NetworkXStorage NetworkX(默认)
303
  Neo4JStorage Neo4J
304
  PGGraphStorage Postgres
305
+ AGEStorage AGE
306
  ```
307
 
308
  * VECTOR_STORAGE 支持的实现名称
309
 
310
  ```
311
  NanoVectorDBStorage NanoVector(默认)
312
+ PGVectorStorage Postgres
313
  MilvusVectorDBStorge Milvus
314
  ChromaVectorDBStorage Chroma
 
 
315
  FaissVectorDBStorage Faiss
316
  QdrantVectorDBStorage Qdrant
 
317
  MongoVectorDBStorage MongoDB
318
  ```
319
 
lightrag/api/README.md CHANGED
@@ -302,11 +302,9 @@ Each storage type have servals implementations:
302
 
303
  ```
304
  JsonKVStorage JsonFile(default)
305
- MongoKVStorage MogonDB
306
- RedisKVStorage Redis
307
- TiDBKVStorage TiDB
308
  PGKVStorage Postgres
309
- OracleKVStorage Oracle
 
310
  ```
311
 
312
  * GRAPH_STORAGE supported implement-name
@@ -314,25 +312,19 @@ OracleKVStorage Oracle
314
  ```
315
  NetworkXStorage NetworkX(defualt)
316
  Neo4JStorage Neo4J
317
- MongoGraphStorage MongoDB
318
- TiDBGraphStorage TiDB
319
- AGEStorage AGE
320
- GremlinStorage Gremlin
321
  PGGraphStorage Postgres
322
- OracleGraphStorage Postgres
323
  ```
324
 
325
  * VECTOR_STORAGE supported implement-name
326
 
327
  ```
328
  NanoVectorDBStorage NanoVector(default)
329
- MilvusVectorDBStorage Milvus
330
- ChromaVectorDBStorage Chroma
331
- TiDBVectorDBStorage TiDB
332
  PGVectorStorage Postgres
 
 
333
  FaissVectorDBStorage Faiss
334
  QdrantVectorDBStorage Qdrant
335
- OracleVectorDBStorage Oracle
336
  MongoVectorDBStorage MongoDB
337
  ```
338
 
 
302
 
303
  ```
304
  JsonKVStorage JsonFile(default)
305
  PGKVStorage Postgres
306
+ RedisKVStorage Redis
307
+ MongoKVStorage MongoDB
308
  ```
309
 
310
  * GRAPH_STORAGE supported implement-name
 
312
  ```
313
  NetworkXStorage NetworkX(defualt)
314
  Neo4JStorage Neo4J
315
  PGGraphStorage Postgres
316
+ AGEStorage AGE
317
  ```
318
 
319
  * VECTOR_STORAGE supported implement-name
320
 
321
  ```
322
  NanoVectorDBStorage NanoVector(default)
323
  PGVectorStorage Postgres
324
+ MilvusVectorDBStorage Milvus
325
+ ChromaVectorDBStorage Chroma
326
  FaissVectorDBStorage Faiss
327
  QdrantVectorDBStorage Qdrant
 
328
  MongoVectorDBStorage MongoDB
329
  ```
330
 
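
The storage implementation names listed above map onto the `LIGHTRAG_*_STORAGE` variables documented in `env.example` earlier in this diff. A minimal sketch (not part of the commit) of selecting backends programmatically before launching the server; the PostgreSQL choices are illustrative only:

```python
import os

# Implementation names taken from the lists above; the variable names come
# from env.example (LIGHTRAG_KV_STORAGE, LIGHTRAG_VECTOR_STORAGE, ...).
os.environ.setdefault("LIGHTRAG_KV_STORAGE", "PGKVStorage")
os.environ.setdefault("LIGHTRAG_VECTOR_STORAGE", "PGVectorStorage")
os.environ.setdefault("LIGHTRAG_GRAPH_STORAGE", "PGGraphStorage")
os.environ.setdefault("LIGHTRAG_DOC_STATUS_STORAGE", "JsonDocStatusStorage")
```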
lightrag/api/__init__.py CHANGED
@@ -1 +1 @@
1
- __api_version__ = "1.2.8"
 
1
+ __api_version__ = "0136"
lightrag/api/auth.py CHANGED
@@ -1,9 +1,11 @@
1
- import os
2
  from datetime import datetime, timedelta
 
3
  import jwt
 
4
  from fastapi import HTTPException, status
5
  from pydantic import BaseModel
6
- from dotenv import load_dotenv
 
7
 
8
  # use the .env that is inside the current folder
9
  # allows to use different .env file for each lightrag instance
@@ -20,13 +22,12 @@ class TokenPayload(BaseModel):
20
 
21
  class AuthHandler:
22
  def __init__(self):
23
- self.secret = os.getenv("TOKEN_SECRET", "4f85ds4f56dsf46")
24
- self.algorithm = "HS256"
25
- self.expire_hours = int(os.getenv("TOKEN_EXPIRE_HOURS", 4))
26
- self.guest_expire_hours = int(os.getenv("GUEST_TOKEN_EXPIRE_HOURS", 2))
27
-
28
  self.accounts = {}
29
- auth_accounts = os.getenv("AUTH_ACCOUNTS")
30
  if auth_accounts:
31
  for account in auth_accounts.split(","):
32
  username, password = account.split(":", 1)
 
 
1
  from datetime import datetime, timedelta
2
+
3
  import jwt
4
+ from dotenv import load_dotenv
5
  from fastapi import HTTPException, status
6
  from pydantic import BaseModel
7
+
8
+ from .config import global_args
9
 
10
  # use the .env that is inside the current folder
11
  # allows to use different .env file for each lightrag instance
 
22
 
23
  class AuthHandler:
24
  def __init__(self):
25
+ self.secret = global_args.token_secret
26
+ self.algorithm = global_args.jwt_algorithm
27
+ self.expire_hours = global_args.token_expire_hours
28
+ self.guest_expire_hours = global_args.guest_token_expire_hours
 
29
  self.accounts = {}
30
+ auth_accounts = global_args.auth_accounts
31
  if auth_accounts:
32
  for account in auth_accounts.split(","):
33
  username, password = account.split(":", 1)
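
The refactored `AuthHandler` above now reads its settings from `global_args` (defined in the new `lightrag/api/config.py` below) instead of calling `os.getenv` directly, while the `AUTH_ACCOUNTS` parsing loop is unchanged. A standalone sketch of that parsing logic, using the placeholder value from `env.example`:

```python
# Standalone sketch of the AUTH_ACCOUNTS parsing shown in the diff above; the
# variable format comes from env.example: AUTH_ACCOUNTS='admin:admin123,user1:pass456'.
def parse_auth_accounts(auth_accounts: str | None) -> dict[str, str]:
    accounts: dict[str, str] = {}
    if auth_accounts:
        for account in auth_accounts.split(","):
            # split(":", 1) keeps any ":" characters inside the password intact
            username, password = account.split(":", 1)
            accounts[username] = password
    return accounts


print(parse_auth_accounts("admin:admin123,user1:pass456"))
# {'admin': 'admin123', 'user1': 'pass456'}
```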
lightrag/api/config.py ADDED
@@ -0,0 +1,335 @@
1
+ """
2
+ Configs for the LightRAG API.
3
+ """
4
+
5
+ import os
6
+ import argparse
7
+ import logging
8
+ from dotenv import load_dotenv
9
+
10
+ # use the .env that is inside the current folder
11
+ # allows to use different .env file for each lightrag instance
12
+ # the OS environment variables take precedence over the .env file
13
+ load_dotenv(dotenv_path=".env", override=False)
14
+
15
+
16
+ class OllamaServerInfos:
17
+ # Constants for emulated Ollama model information
18
+ LIGHTRAG_NAME = "lightrag"
19
+ LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
20
+ LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
21
+ LIGHTRAG_SIZE = 7365960935 # it's a dummy value
22
+ LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
23
+ LIGHTRAG_DIGEST = "sha256:lightrag"
24
+
25
+
26
+ ollama_server_infos = OllamaServerInfos()
27
+
28
+
29
+ class DefaultRAGStorageConfig:
30
+ KV_STORAGE = "JsonKVStorage"
31
+ VECTOR_STORAGE = "NanoVectorDBStorage"
32
+ GRAPH_STORAGE = "NetworkXStorage"
33
+ DOC_STATUS_STORAGE = "JsonDocStatusStorage"
34
+
35
+
36
+ def get_default_host(binding_type: str) -> str:
37
+ default_hosts = {
38
+ "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
39
+ "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
40
+ "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
41
+ "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
42
+ }
43
+ return default_hosts.get(
44
+ binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
45
+ ) # fallback to ollama if unknown
46
+
47
+
48
+ def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
49
+ """
50
+ Get value from environment variable with type conversion
51
+
52
+ Args:
53
+ env_key (str): Environment variable key
54
+ default (any): Default value if env variable is not set
55
+ value_type (type): Type to convert the value to
56
+
57
+ Returns:
58
+ any: Converted value from environment or default
59
+ """
60
+ value = os.getenv(env_key)
61
+ if value is None:
62
+ return default
63
+
64
+ if value_type is bool:
65
+ return value.lower() in ("true", "1", "yes", "t", "on")
66
+ try:
67
+ return value_type(value)
68
+ except ValueError:
69
+ return default
70
+
71
+
72
+ def parse_args() -> argparse.Namespace:
73
+ """
74
+ Parse command line arguments with environment variable fallback
75
+
76
+ Args:
77
+ None (settings are resolved from CLI flags, environment variables and .env defaults)
78
+
79
+ Returns:
80
+ argparse.Namespace: Parsed arguments
81
+ """
82
+
83
+ parser = argparse.ArgumentParser(
84
+ description="LightRAG FastAPI Server with separate working and input directories"
85
+ )
86
+
87
+ # Server configuration
88
+ parser.add_argument(
89
+ "--host",
90
+ default=get_env_value("HOST", "0.0.0.0"),
91
+ help="Server host (default: from env or 0.0.0.0)",
92
+ )
93
+ parser.add_argument(
94
+ "--port",
95
+ type=int,
96
+ default=get_env_value("PORT", 9621, int),
97
+ help="Server port (default: from env or 9621)",
98
+ )
99
+
100
+ # Directory configuration
101
+ parser.add_argument(
102
+ "--working-dir",
103
+ default=get_env_value("WORKING_DIR", "./rag_storage"),
104
+ help="Working directory for RAG storage (default: from env or ./rag_storage)",
105
+ )
106
+ parser.add_argument(
107
+ "--input-dir",
108
+ default=get_env_value("INPUT_DIR", "./inputs"),
109
+ help="Directory containing input documents (default: from env or ./inputs)",
110
+ )
111
+
112
+ def timeout_type(value):
113
+ if value is None:
114
+ return 150
115
+ if value is None or value == "None":
116
+ return None
117
+ return int(value)
118
+
119
+ parser.add_argument(
120
+ "--timeout",
121
+ default=get_env_value("TIMEOUT", None, timeout_type),
122
+ type=timeout_type,
123
+ help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
124
+ )
125
+
126
+ # RAG configuration
127
+ parser.add_argument(
128
+ "--max-async",
129
+ type=int,
130
+ default=get_env_value("MAX_ASYNC", 4, int),
131
+ help="Maximum async operations (default: from env or 4)",
132
+ )
133
+ parser.add_argument(
134
+ "--max-tokens",
135
+ type=int,
136
+ default=get_env_value("MAX_TOKENS", 32768, int),
137
+ help="Maximum token size (default: from env or 32768)",
138
+ )
139
+
140
+ # Logging configuration
141
+ parser.add_argument(
142
+ "--log-level",
143
+ default=get_env_value("LOG_LEVEL", "INFO"),
144
+ choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
145
+ help="Logging level (default: from env or INFO)",
146
+ )
147
+ parser.add_argument(
148
+ "--verbose",
149
+ action="store_true",
150
+ default=get_env_value("VERBOSE", False, bool),
151
+ help="Enable verbose debug output (only valid for DEBUG log-level)",
152
+ )
153
+
154
+ parser.add_argument(
155
+ "--key",
156
+ type=str,
157
+ default=get_env_value("LIGHTRAG_API_KEY", None),
158
+ help="API key for authentication. This protects lightrag server against unauthorized access",
159
+ )
160
+
161
+ # Optional https parameters
162
+ parser.add_argument(
163
+ "--ssl",
164
+ action="store_true",
165
+ default=get_env_value("SSL", False, bool),
166
+ help="Enable HTTPS (default: from env or False)",
167
+ )
168
+ parser.add_argument(
169
+ "--ssl-certfile",
170
+ default=get_env_value("SSL_CERTFILE", None),
171
+ help="Path to SSL certificate file (required if --ssl is enabled)",
172
+ )
173
+ parser.add_argument(
174
+ "--ssl-keyfile",
175
+ default=get_env_value("SSL_KEYFILE", None),
176
+ help="Path to SSL private key file (required if --ssl is enabled)",
177
+ )
178
+
179
+ parser.add_argument(
180
+ "--history-turns",
181
+ type=int,
182
+ default=get_env_value("HISTORY_TURNS", 3, int),
183
+ help="Number of conversation history turns to include (default: from env or 3)",
184
+ )
185
+
186
+ # Search parameters
187
+ parser.add_argument(
188
+ "--top-k",
189
+ type=int,
190
+ default=get_env_value("TOP_K", 60, int),
191
+ help="Number of most similar results to return (default: from env or 60)",
192
+ )
193
+ parser.add_argument(
194
+ "--cosine-threshold",
195
+ type=float,
196
+ default=get_env_value("COSINE_THRESHOLD", 0.2, float),
197
+ help="Cosine similarity threshold (default: from env or 0.2)",
198
+ )
199
+
200
+ # Ollama model name
201
+ parser.add_argument(
202
+ "--simulated-model-name",
203
+ type=str,
204
+ default=get_env_value(
205
+ "SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
206
+ ),
207
+ help="Name of the simulated Ollama model (default: from env or lightrag:latest)",
208
+ )
209
+
210
+ # Namespace
211
+ parser.add_argument(
212
+ "--namespace-prefix",
213
+ type=str,
214
+ default=get_env_value("NAMESPACE_PREFIX", ""),
215
+ help="Prefix of the namespace",
216
+ )
217
+
218
+ parser.add_argument(
219
+ "--auto-scan-at-startup",
220
+ action="store_true",
221
+ default=False,
222
+ help="Enable automatic scanning when the program starts",
223
+ )
224
+
225
+ # Server workers configuration
226
+ parser.add_argument(
227
+ "--workers",
228
+ type=int,
229
+ default=get_env_value("WORKERS", 1, int),
230
+ help="Number of worker processes (default: from env or 1)",
231
+ )
232
+
233
+ # LLM and embedding bindings
234
+ parser.add_argument(
235
+ "--llm-binding",
236
+ type=str,
237
+ default=get_env_value("LLM_BINDING", "ollama"),
238
+ choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
239
+ help="LLM binding type (default: from env or ollama)",
240
+ )
241
+ parser.add_argument(
242
+ "--embedding-binding",
243
+ type=str,
244
+ default=get_env_value("EMBEDDING_BINDING", "ollama"),
245
+ choices=["lollms", "ollama", "openai", "azure_openai"],
246
+ help="Embedding binding type (default: from env or ollama)",
247
+ )
248
+
249
+ args = parser.parse_args()
250
+
251
+ # convert relative path to absolute path
252
+ args.working_dir = os.path.abspath(args.working_dir)
253
+ args.input_dir = os.path.abspath(args.input_dir)
254
+
255
+ # Inject storage configuration from environment variables
256
+ args.kv_storage = get_env_value(
257
+ "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
258
+ )
259
+ args.doc_status_storage = get_env_value(
260
+ "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
261
+ )
262
+ args.graph_storage = get_env_value(
263
+ "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
264
+ )
265
+ args.vector_storage = get_env_value(
266
+ "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
267
+ )
268
+
269
+ # Get MAX_PARALLEL_INSERT from environment
270
+ args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
271
+
272
+ # Handle openai-ollama special case
273
+ if args.llm_binding == "openai-ollama":
274
+ args.llm_binding = "openai"
275
+ args.embedding_binding = "ollama"
276
+
277
+ args.llm_binding_host = get_env_value(
278
+ "LLM_BINDING_HOST", get_default_host(args.llm_binding)
279
+ )
280
+ args.embedding_binding_host = get_env_value(
281
+ "EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
282
+ )
283
+ args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
284
+ args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
285
+
286
+ # Inject model configuration
287
+ args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
288
+ args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
289
+ args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
290
+ args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int)
291
+
292
+ # Inject chunk configuration
293
+ args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
294
+ args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
295
+
296
+ # Inject LLM cache configuration
297
+ args.enable_llm_cache_for_extract = get_env_value(
298
+ "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
299
+ )
300
+
301
+ # Inject LLM temperature configuration
302
+ args.temperature = get_env_value("TEMPERATURE", 0.5, float)
303
+
304
+ # Select Document loading tool (DOCLING, DEFAULT)
305
+ args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
306
+
307
+ # Add environment variables that were previously read directly
308
+ args.cors_origins = get_env_value("CORS_ORIGINS", "*")
309
+ args.summary_language = get_env_value("SUMMARY_LANGUAGE", "en")
310
+ args.whitelist_paths = get_env_value("WHITELIST_PATHS", "/health,/api/*")
311
+
312
+ # For JWT Auth
313
+ args.auth_accounts = get_env_value("AUTH_ACCOUNTS", "")
314
+ args.token_secret = get_env_value("TOKEN_SECRET", "lightrag-jwt-default-secret")
315
+ args.token_expire_hours = get_env_value("TOKEN_EXPIRE_HOURS", 48, int)
316
+ args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
317
+ args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
318
+
319
+ ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
320
+
321
+ return args
322
+
323
+
324
+ def update_uvicorn_mode_config():
325
+ # If in uvicorn mode and workers > 1, force it to 1 and log warning
326
+ if global_args.workers > 1:
327
+ original_workers = global_args.workers
328
+ global_args.workers = 1
329
+ # Log warning directly here
330
+ logging.warning(
331
+ f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
332
+ )
333
+
334
+
335
+ global_args = parse_args()
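Because global_args = parse_args() runs at import time, settings are fixed when lightrag.api.config is first imported. A minimal usage sketch, assuming it is run as a standalone script with no unrecognized CLI flags and with the variables set before the import; precedence is CLI flag > OS environment > .env file > built-in default:

```python
# Hedged sketch: how the new config module resolves values.
import os

os.environ.setdefault("PORT", "9700")            # overrides the built-in 9621 default
os.environ.setdefault("LLM_BINDING", "openai")   # selects the OpenAI default host

from lightrag.api.config import global_args

print(global_args.port)               # 9700, given the values above
print(global_args.llm_binding)        # "openai"
print(global_args.llm_binding_host)   # https://api.openai.com/v1 unless LLM_BINDING_HOST is set
```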
lightrag/api/lightrag_server.py CHANGED
@@ -19,11 +19,14 @@ from contextlib import asynccontextmanager
19
  from dotenv import load_dotenv
20
  from lightrag.api.utils_api import (
21
  get_combined_auth_dependency,
22
- parse_args,
23
- get_default_host,
24
  display_splash_screen,
25
  check_env_file,
26
  )
 
 
 
 
 
27
  import sys
28
  from lightrag import LightRAG, __version__ as core_version
29
  from lightrag.api import __api_version__
@@ -52,6 +55,10 @@ from lightrag.api.auth import auth_handler
52
  # the OS environment variables take precedence over the .env file
53
  load_dotenv(dotenv_path=".env", override=False)
54
 
 
 
 
 
55
  # Initialize config parser
56
  config = configparser.ConfigParser()
57
  config.read("config.ini")
@@ -164,10 +171,10 @@ def create_app(args):
164
  app = FastAPI(**app_kwargs)
165
 
166
  def get_cors_origins():
167
- """Get allowed origins from environment variable
168
  Returns a list of allowed origins, defaults to ["*"] if not set
169
  """
170
- origins_str = os.getenv("CORS_ORIGINS", "*")
171
  if origins_str == "*":
172
  return ["*"]
173
  return [origin.strip() for origin in origins_str.split(",")]
@@ -315,9 +322,10 @@ def create_app(args):
315
  "similarity_threshold": 0.95,
316
  "use_llm_check": False,
317
  },
318
- namespace_prefix=args.namespace_prefix,
319
  auto_manage_storages_states=False,
320
  max_parallel_insert=args.max_parallel_insert,
 
321
  )
322
  else: # azure_openai
323
  rag = LightRAG(
@@ -345,9 +353,10 @@ def create_app(args):
345
  "similarity_threshold": 0.95,
346
  "use_llm_check": False,
347
  },
348
- namespace_prefix=args.namespace_prefix,
349
  auto_manage_storages_states=False,
350
  max_parallel_insert=args.max_parallel_insert,
 
351
  )
352
 
353
  # Add routes
@@ -381,6 +390,8 @@ def create_app(args):
381
  "message": "Authentication is disabled. Using guest access.",
382
  "core_version": core_version,
383
  "api_version": __api_version__,
 
 
384
  }
385
 
386
  return {
@@ -388,6 +399,8 @@ def create_app(args):
388
  "auth_mode": "enabled",
389
  "core_version": core_version,
390
  "api_version": __api_version__,
 
 
391
  }
392
 
393
  @app.post("/login")
@@ -404,6 +417,8 @@ def create_app(args):
404
  "message": "Authentication is disabled. Using guest access.",
405
  "core_version": core_version,
406
  "api_version": __api_version__,
 
 
407
  }
408
  username = form_data.username
409
  if auth_handler.accounts.get(username) != form_data.password:
@@ -421,6 +436,8 @@ def create_app(args):
421
  "auth_mode": "enabled",
422
  "core_version": core_version,
423
  "api_version": __api_version__,
 
 
424
  }
425
 
426
  @app.get("/health", dependencies=[Depends(combined_auth)])
@@ -454,10 +471,12 @@ def create_app(args):
454
  "vector_storage": args.vector_storage,
455
  "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
456
  },
457
- "core_version": core_version,
458
- "api_version": __api_version__,
459
  "auth_mode": auth_mode,
460
  "pipeline_busy": pipeline_status.get("busy", False),
 
 
 
 
461
  }
462
  except Exception as e:
463
  logger.error(f"Error getting health status: {str(e)}")
@@ -490,7 +509,7 @@ def create_app(args):
490
  def get_application(args=None):
491
  """Factory function for creating the FastAPI application"""
492
  if args is None:
493
- args = parse_args()
494
  return create_app(args)
495
 
496
 
@@ -611,30 +630,31 @@ def main():
611
 
612
  # Configure logging before parsing args
613
  configure_logging()
614
-
615
- args = parse_args(is_uvicorn_mode=True)
616
- display_splash_screen(args)
617
 
618
  # Create application instance directly instead of using factory function
619
- app = create_app(args)
620
 
621
  # Start Uvicorn in single process mode
622
  uvicorn_config = {
623
  "app": app, # Pass application instance directly instead of string path
624
- "host": args.host,
625
- "port": args.port,
626
  "log_config": None, # Disable default config
627
  }
628
 
629
- if args.ssl:
630
  uvicorn_config.update(
631
  {
632
- "ssl_certfile": args.ssl_certfile,
633
- "ssl_keyfile": args.ssl_keyfile,
634
  }
635
  )
636
 
637
- print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}")
 
 
638
  uvicorn.run(**uvicorn_config)
639
 
640
 
 
19
  from dotenv import load_dotenv
20
  from lightrag.api.utils_api import (
21
  get_combined_auth_dependency,
 
 
22
  display_splash_screen,
23
  check_env_file,
24
  )
25
+ from .config import (
26
+ global_args,
27
+ update_uvicorn_mode_config,
28
+ get_default_host,
29
+ )
30
  import sys
31
  from lightrag import LightRAG, __version__ as core_version
32
  from lightrag.api import __api_version__
 
55
  # the OS environment variables take precedence over the .env file
56
  load_dotenv(dotenv_path=".env", override=False)
57
 
58
+
59
+ webui_title = os.getenv("WEBUI_TITLE")
60
+ webui_description = os.getenv("WEBUI_DESCRIPTION")
61
+
62
  # Initialize config parser
63
  config = configparser.ConfigParser()
64
  config.read("config.ini")
 
171
  app = FastAPI(**app_kwargs)
172
 
173
  def get_cors_origins():
174
+ """Get allowed origins from global_args
175
  Returns a list of allowed origins, defaults to ["*"] if not set
176
  """
177
+ origins_str = global_args.cors_origins
178
  if origins_str == "*":
179
  return ["*"]
180
  return [origin.strip() for origin in origins_str.split(",")]
 
322
  "similarity_threshold": 0.95,
323
  "use_llm_check": False,
324
  },
325
+ # namespace_prefix=args.namespace_prefix,
326
  auto_manage_storages_states=False,
327
  max_parallel_insert=args.max_parallel_insert,
328
+ addon_params={"language": args.summary_language},
329
  )
330
  else: # azure_openai
331
  rag = LightRAG(
 
353
  "similarity_threshold": 0.95,
354
  "use_llm_check": False,
355
  },
356
+ # namespace_prefix=args.namespace_prefix,
357
  auto_manage_storages_states=False,
358
  max_parallel_insert=args.max_parallel_insert,
359
+ addon_params={"language": args.summary_language},
360
  )
361
 
362
  # Add routes
 
390
  "message": "Authentication is disabled. Using guest access.",
391
  "core_version": core_version,
392
  "api_version": __api_version__,
393
+ "webui_title": webui_title,
394
+ "webui_description": webui_description,
395
  }
396
 
397
  return {
 
399
  "auth_mode": "enabled",
400
  "core_version": core_version,
401
  "api_version": __api_version__,
402
+ "webui_title": webui_title,
403
+ "webui_description": webui_description,
404
  }
405
 
406
  @app.post("/login")
 
417
  "message": "Authentication is disabled. Using guest access.",
418
  "core_version": core_version,
419
  "api_version": __api_version__,
420
+ "webui_title": webui_title,
421
+ "webui_description": webui_description,
422
  }
423
  username = form_data.username
424
  if auth_handler.accounts.get(username) != form_data.password:
 
436
  "auth_mode": "enabled",
437
  "core_version": core_version,
438
  "api_version": __api_version__,
439
+ "webui_title": webui_title,
440
+ "webui_description": webui_description,
441
  }
442
 
443
  @app.get("/health", dependencies=[Depends(combined_auth)])
 
471
  "vector_storage": args.vector_storage,
472
  "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
473
  },
 
 
474
  "auth_mode": auth_mode,
475
  "pipeline_busy": pipeline_status.get("busy", False),
476
+ "core_version": core_version,
477
+ "api_version": __api_version__,
478
+ "webui_title": webui_title,
479
+ "webui_description": webui_description,
480
  }
481
  except Exception as e:
482
  logger.error(f"Error getting health status: {str(e)}")
 
509
  def get_application(args=None):
510
  """Factory function for creating the FastAPI application"""
511
  if args is None:
512
+ args = global_args
513
  return create_app(args)
514
 
515
 
 
630
 
631
  # Configure logging before parsing args
632
  configure_logging()
633
+ update_uvicorn_mode_config()
634
+ display_splash_screen(global_args)
 
635
 
636
  # Create application instance directly instead of using factory function
637
+ app = create_app(global_args)
638
 
639
  # Start Uvicorn in single process mode
640
  uvicorn_config = {
641
  "app": app, # Pass application instance directly instead of string path
642
+ "host": global_args.host,
643
+ "port": global_args.port,
644
  "log_config": None, # Disable default config
645
  }
646
 
647
+ if global_args.ssl:
648
  uvicorn_config.update(
649
  {
650
+ "ssl_certfile": global_args.ssl_certfile,
651
+ "ssl_keyfile": global_args.ssl_keyfile,
652
  }
653
  )
654
 
655
+ print(
656
+ f"Starting Uvicorn server in single-process mode on {global_args.host}:{global_args.port}"
657
+ )
658
  uvicorn.run(**uvicorn_config)
659
 
660
 
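The /health, /auth-status and /login responses above now also carry webui_title and webui_description, taken from the WEBUI_TITLE and WEBUI_DESCRIPTION environment variables. A quick way to inspect the enriched payload, assuming a server on localhost:9621 with neither an API key nor auth accounts configured:

```python
# Hedged sketch: reading the extended /health payload.
import requests

info = requests.get("http://localhost:9621/health", timeout=10).json()
print(info["core_version"], info["api_version"])
print(info.get("webui_title"), info.get("webui_description"))  # None unless the env vars are set
print(info["pipeline_busy"])
```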
lightrag/api/routers/document_routes.py CHANGED
@@ -10,16 +10,14 @@ import traceback
10
  import pipmaster as pm
11
  from datetime import datetime
12
  from pathlib import Path
13
- from typing import Dict, List, Optional, Any
14
  from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
15
  from pydantic import BaseModel, Field, field_validator
16
 
17
  from lightrag import LightRAG
18
  from lightrag.base import DocProcessingStatus, DocStatus
19
- from lightrag.api.utils_api import (
20
- get_combined_auth_dependency,
21
- global_args,
22
- )
23
 
24
  router = APIRouter(
25
  prefix="/documents",
@@ -30,7 +28,37 @@ router = APIRouter(
30
  temp_prefix = "__tmp__"
31
 
32
 
 
33
  class InsertTextRequest(BaseModel):
 
 
 
 
 
 
34
  text: str = Field(
35
  min_length=1,
36
  description="The text to insert",
@@ -41,8 +69,21 @@ class InsertTextRequest(BaseModel):
41
  def strip_after(cls, text: str) -> str:
42
  return text.strip()
43
 
 
 
 
 
 
 
 
44
 
45
  class InsertTextsRequest(BaseModel):
 
 
 
 
 
 
46
  texts: list[str] = Field(
47
  min_length=1,
48
  description="The texts to insert",
@@ -53,11 +94,116 @@ class InsertTextsRequest(BaseModel):
53
  def strip_after(cls, texts: list[str]) -> list[str]:
54
  return [text.strip() for text in texts]
55
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  class InsertResponse(BaseModel):
58
- status: str = Field(description="Status of the operation")
 
 
 
 
 
 
 
 
 
59
  message: str = Field(description="Message describing the operation result")
60
 
 
61
 
62
  class DocStatusResponse(BaseModel):
63
  @staticmethod
@@ -68,34 +214,82 @@ class DocStatusResponse(BaseModel):
68
  return dt
69
  return dt.isoformat()
70
 
71
- """Response model for document status
 
72
 
73
  Attributes:
74
- id: Document identifier
75
- content_summary: Summary of document content
76
- content_length: Length of document content
77
- status: Current processing status
78
- created_at: Creation timestamp (ISO format string)
79
- updated_at: Last update timestamp (ISO format string)
80
- chunks_count: Number of chunks (optional)
81
- error: Error message if any (optional)
82
- metadata: Additional metadata (optional)
83
  """
84
 
85
- id: str
86
- content_summary: str
87
- content_length: int
88
- status: DocStatus
89
- created_at: str
90
- updated_at: str
91
- chunks_count: Optional[int] = None
92
- error: Optional[str] = None
93
- metadata: Optional[dict[str, Any]] = None
94
- file_path: str
95
-
96
 
97
- class DocsStatusesResponse(BaseModel):
98
- statuses: Dict[DocStatus, List[DocStatusResponse]] = {}
 
 
99
 
100
 
101
  class PipelineStatusResponse(BaseModel):
@@ -276,7 +470,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
276
  )
277
  return False
278
  case ".pdf":
279
- if global_args["main_args"].document_loading_engine == "DOCLING":
280
  if not pm.is_installed("docling"): # type: ignore
281
  pm.install("docling")
282
  from docling.document_converter import DocumentConverter # type: ignore
@@ -295,7 +489,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
295
  for page in reader.pages:
296
  content += page.extract_text() + "\n"
297
  case ".docx":
298
- if global_args["main_args"].document_loading_engine == "DOCLING":
299
  if not pm.is_installed("docling"): # type: ignore
300
  pm.install("docling")
301
  from docling.document_converter import DocumentConverter # type: ignore
@@ -315,7 +509,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
315
  [paragraph.text for paragraph in doc.paragraphs]
316
  )
317
  case ".pptx":
318
- if global_args["main_args"].document_loading_engine == "DOCLING":
319
  if not pm.is_installed("docling"): # type: ignore
320
  pm.install("docling")
321
  from docling.document_converter import DocumentConverter # type: ignore
@@ -336,7 +530,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
336
  if hasattr(shape, "text"):
337
  content += shape.text + "\n"
338
  case ".xlsx":
339
- if global_args["main_args"].document_loading_engine == "DOCLING":
340
  if not pm.is_installed("docling"): # type: ignore
341
  pm.install("docling")
342
  from docling.document_converter import DocumentConverter # type: ignore
@@ -443,6 +637,7 @@ async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
443
  await rag.apipeline_process_enqueue_documents()
444
 
445
 
 
446
  async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
447
  """Save the uploaded file to a temporary location
448
 
@@ -476,8 +671,8 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
476
  if not new_files:
477
  return
478
 
479
- # Get MAX_PARALLEL_INSERT from global_args["main_args"]
480
- max_parallel = global_args["main_args"].max_parallel_insert
481
  # Calculate batch size as 2 * MAX_PARALLEL_INSERT
482
  batch_size = 2 * max_parallel
483
 
@@ -509,7 +704,9 @@ def create_document_routes(
509
  # Create combined auth dependency for document routes
510
  combined_auth = get_combined_auth_dependency(api_key)
511
 
512
- @router.post("/scan", dependencies=[Depends(combined_auth)])
 
 
513
  async def scan_for_new_documents(background_tasks: BackgroundTasks):
514
  """
515
  Trigger the scanning process for new documents.
@@ -519,13 +716,18 @@ def create_document_routes(
519
  that fact.
520
 
521
  Returns:
522
- dict: A dictionary containing the scanning status
523
  """
524
  # Start the scanning process in the background
525
  background_tasks.add_task(run_scanning_process, rag, doc_manager)
526
- return {"status": "scanning_started"}
 
 
 
527
 
528
- @router.post("/upload", dependencies=[Depends(combined_auth)])
 
 
529
  async def upload_to_input_dir(
530
  background_tasks: BackgroundTasks, file: UploadFile = File(...)
531
  ):
@@ -645,6 +847,7 @@ def create_document_routes(
645
  logger.error(traceback.format_exc())
646
  raise HTTPException(status_code=500, detail=str(e))
647
 
 
648
  @router.post(
649
  "/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
650
  )
@@ -688,6 +891,7 @@ def create_document_routes(
688
  logger.error(traceback.format_exc())
689
  raise HTTPException(status_code=500, detail=str(e))
690
 
 
691
  @router.post(
692
  "/file_batch",
693
  response_model=InsertResponse,
@@ -752,32 +956,186 @@ def create_document_routes(
752
  raise HTTPException(status_code=500, detail=str(e))
753
 
754
  @router.delete(
755
- "", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
756
  )
757
  async def clear_documents():
758
  """
759
  Clear all documents from the RAG system.
760
 
761
- This endpoint deletes all text chunks, entities vector database, and relationships
762
- vector database, effectively clearing all documents from the RAG system.
 
763
 
764
  Returns:
765
- InsertResponse: A response object containing the status and message.
 
 
 
 
 
 
766
 
767
  Raises:
768
- HTTPException: If an error occurs during the clearing process (500).
 
769
  """
770
- try:
771
- rag.text_chunks = []
772
- rag.entities_vdb = None
773
- rag.relationships_vdb = None
774
- return InsertResponse(
775
- status="success", message="All documents cleared successfully"
 
 
776
  )
 
 
777
  except Exception as e:
778
- logger.error(f"Error DELETE /documents: {str(e)}")
 
779
  logger.error(traceback.format_exc())
 
 
780
  raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
781
 
782
  @router.get(
783
  "/pipeline_status",
@@ -850,7 +1208,9 @@ def create_document_routes(
850
  logger.error(traceback.format_exc())
851
  raise HTTPException(status_code=500, detail=str(e))
852
 
853
- @router.get("", dependencies=[Depends(combined_auth)])
 
 
854
  async def documents() -> DocsStatusesResponse:
855
  """
856
  Get the status of all documents in the system.
@@ -908,4 +1268,57 @@ def create_document_routes(
908
  logger.error(traceback.format_exc())
909
  raise HTTPException(status_code=500, detail=str(e))
910
 
 
911
  return router
 
10
  import pipmaster as pm
11
  from datetime import datetime
12
  from pathlib import Path
13
+ from typing import Dict, List, Optional, Any, Literal
14
  from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
15
  from pydantic import BaseModel, Field, field_validator
16
 
17
  from lightrag import LightRAG
18
  from lightrag.base import DocProcessingStatus, DocStatus
19
+ from lightrag.api.utils_api import get_combined_auth_dependency
20
+ from ..config import global_args
 
 
21
 
22
  router = APIRouter(
23
  prefix="/documents",
 
28
  temp_prefix = "__tmp__"
29
 
30
 
31
+ class ScanResponse(BaseModel):
32
+ """Response model for document scanning operation
33
+
34
+ Attributes:
35
+ status: Status of the scanning operation
36
+ message: Optional message with additional details
37
+ """
38
+
39
+ status: Literal["scanning_started"] = Field(
40
+ description="Status of the scanning operation"
41
+ )
42
+ message: Optional[str] = Field(
43
+ default=None, description="Additional details about the scanning operation"
44
+ )
45
+
46
+ class Config:
47
+ json_schema_extra = {
48
+ "example": {
49
+ "status": "scanning_started",
50
+ "message": "Scanning process has been initiated in the background",
51
+ }
52
+ }
53
+
54
+
55
  class InsertTextRequest(BaseModel):
56
+ """Request model for inserting a single text document
57
+
58
+ Attributes:
59
+ text: The text content to be inserted into the RAG system
60
+ """
61
+
62
  text: str = Field(
63
  min_length=1,
64
  description="The text to insert",
 
69
  def strip_after(cls, text: str) -> str:
70
  return text.strip()
71
 
72
+ class Config:
73
+ json_schema_extra = {
74
+ "example": {
75
+ "text": "This is a sample text to be inserted into the RAG system."
76
+ }
77
+ }
78
+
79
 
80
  class InsertTextsRequest(BaseModel):
81
+ """Request model for inserting multiple text documents
82
+
83
+ Attributes:
84
+ texts: List of text contents to be inserted into the RAG system
85
+ """
86
+
87
  texts: list[str] = Field(
88
  min_length=1,
89
  description="The texts to insert",
 
94
  def strip_after(cls, texts: list[str]) -> list[str]:
95
  return [text.strip() for text in texts]
96
 
97
+ class Config:
98
+ json_schema_extra = {
99
+ "example": {
100
+ "texts": [
101
+ "This is the first text to be inserted.",
102
+ "This is the second text to be inserted.",
103
+ ]
104
+ }
105
+ }
106
+
107
 
108
  class InsertResponse(BaseModel):
109
+ """Response model for document insertion operations
110
+
111
+ Attributes:
112
+ status: Status of the operation (success, duplicated, partial_success, failure)
113
+ message: Detailed message describing the operation result
114
+ """
115
+
116
+ status: Literal["success", "duplicated", "partial_success", "failure"] = Field(
117
+ description="Status of the operation"
118
+ )
119
  message: str = Field(description="Message describing the operation result")
120
 
121
+ class Config:
122
+ json_schema_extra = {
123
+ "example": {
124
+ "status": "success",
125
+ "message": "File 'document.pdf' uploaded successfully. Processing will continue in background.",
126
+ }
127
+ }
128
+
129
+
130
+ class ClearDocumentsResponse(BaseModel):
131
+ """Response model for document clearing operation
132
+
133
+ Attributes:
134
+ status: Status of the clear operation
135
+ message: Detailed message describing the operation result
136
+ """
137
+
138
+ status: Literal["success", "partial_success", "busy", "fail"] = Field(
139
+ description="Status of the clear operation"
140
+ )
141
+ message: str = Field(description="Message describing the operation result")
142
+
143
+ class Config:
144
+ json_schema_extra = {
145
+ "example": {
146
+ "status": "success",
147
+ "message": "All documents cleared successfully. Deleted 15 files.",
148
+ }
149
+ }
150
+
151
+
152
+ class ClearCacheRequest(BaseModel):
153
+ """Request model for clearing cache
154
+
155
+ Attributes:
156
+ modes: Optional list of cache modes to clear
157
+ """
158
+
159
+ modes: Optional[
160
+ List[Literal["default", "naive", "local", "global", "hybrid", "mix"]]
161
+ ] = Field(
162
+ default=None,
163
+ description="Modes of cache to clear. If None, clears all cache.",
164
+ )
165
+
166
+ class Config:
167
+ json_schema_extra = {"example": {"modes": ["default", "naive"]}}
168
+
169
+
170
+ class ClearCacheResponse(BaseModel):
171
+ """Response model for cache clearing operation
172
+
173
+ Attributes:
174
+ status: Status of the clear operation
175
+ message: Detailed message describing the operation result
176
+ """
177
+
178
+ status: Literal["success", "fail"] = Field(
179
+ description="Status of the clear operation"
180
+ )
181
+ message: str = Field(description="Message describing the operation result")
182
+
183
+ class Config:
184
+ json_schema_extra = {
185
+ "example": {
186
+ "status": "success",
187
+ "message": "Successfully cleared cache for modes: ['default', 'naive']",
188
+ }
189
+ }
190
+
191
+
192
+ """Response model for document status
193
+
194
+ Attributes:
195
+ id: Document identifier
196
+ content_summary: Summary of document content
197
+ content_length: Length of document content
198
+ status: Current processing status
199
+ created_at: Creation timestamp (ISO format string)
200
+ updated_at: Last update timestamp (ISO format string)
201
+ chunks_count: Number of chunks (optional)
202
+ error: Error message if any (optional)
203
+ metadata: Additional metadata (optional)
204
+ file_path: Path to the document file
205
+ """
206
+
207
 
208
  class DocStatusResponse(BaseModel):
209
  @staticmethod
 
214
  return dt
215
  return dt.isoformat()
216
 
217
+ id: str = Field(description="Document identifier")
218
+ content_summary: str = Field(description="Summary of document content")
219
+ content_length: int = Field(description="Length of document content in characters")
220
+ status: DocStatus = Field(description="Current processing status")
221
+ created_at: str = Field(description="Creation timestamp (ISO format string)")
222
+ updated_at: str = Field(description="Last update timestamp (ISO format string)")
223
+ chunks_count: Optional[int] = Field(
224
+ default=None, description="Number of chunks the document was split into"
225
+ )
226
+ error: Optional[str] = Field(
227
+ default=None, description="Error message if processing failed"
228
+ )
229
+ metadata: Optional[dict[str, Any]] = Field(
230
+ default=None, description="Additional metadata about the document"
231
+ )
232
+ file_path: str = Field(description="Path to the document file")
233
+
234
+ class Config:
235
+ json_schema_extra = {
236
+ "example": {
237
+ "id": "doc_123456",
238
+ "content_summary": "Research paper on machine learning",
239
+ "content_length": 15240,
240
+ "status": "PROCESSED",
241
+ "created_at": "2025-03-31T12:34:56",
242
+ "updated_at": "2025-03-31T12:35:30",
243
+ "chunks_count": 12,
244
+ "error": None,
245
+ "metadata": {"author": "John Doe", "year": 2025},
246
+ "file_path": "research_paper.pdf",
247
+ }
248
+ }
249
+
250
+
251
+ class DocsStatusesResponse(BaseModel):
252
+ """Response model for document statuses
253
 
254
  Attributes:
255
+ statuses: Dictionary mapping document status to lists of document status responses
 
 
 
 
 
 
 
 
256
  """
257
 
258
+ statuses: Dict[DocStatus, List[DocStatusResponse]] = Field(
259
+ default_factory=dict,
260
+ description="Dictionary mapping document status to lists of document status responses",
261
+ )
 
 
 
 
 
 
 
262
 
263
+ class Config:
264
+ json_schema_extra = {
265
+ "example": {
266
+ "statuses": {
267
+ "PENDING": [
268
+ {
269
+ "id": "doc_123",
270
+ "content_summary": "Pending document",
271
+ "content_length": 5000,
272
+ "status": "PENDING",
273
+ "created_at": "2025-03-31T10:00:00",
274
+ "updated_at": "2025-03-31T10:00:00",
275
+ "file_path": "pending_doc.pdf",
276
+ }
277
+ ],
278
+ "PROCESSED": [
279
+ {
280
+ "id": "doc_456",
281
+ "content_summary": "Processed document",
282
+ "content_length": 8000,
283
+ "status": "PROCESSED",
284
+ "created_at": "2025-03-31T09:00:00",
285
+ "updated_at": "2025-03-31T09:05:00",
286
+ "chunks_count": 8,
287
+ "file_path": "processed_doc.pdf",
288
+ }
289
+ ],
290
+ }
291
+ }
292
+ }
293
 
294
 
295
  class PipelineStatusResponse(BaseModel):
 
470
  )
471
  return False
472
  case ".pdf":
473
+ if global_args.document_loading_engine == "DOCLING":
474
  if not pm.is_installed("docling"): # type: ignore
475
  pm.install("docling")
476
  from docling.document_converter import DocumentConverter # type: ignore
 
489
  for page in reader.pages:
490
  content += page.extract_text() + "\n"
491
  case ".docx":
492
+ if global_args.document_loading_engine == "DOCLING":
493
  if not pm.is_installed("docling"): # type: ignore
494
  pm.install("docling")
495
  from docling.document_converter import DocumentConverter # type: ignore
 
509
  [paragraph.text for paragraph in doc.paragraphs]
510
  )
511
  case ".pptx":
512
+ if global_args.document_loading_engine == "DOCLING":
513
  if not pm.is_installed("docling"): # type: ignore
514
  pm.install("docling")
515
  from docling.document_converter import DocumentConverter # type: ignore
 
530
  if hasattr(shape, "text"):
531
  content += shape.text + "\n"
532
  case ".xlsx":
533
+ if global_args.document_loading_engine == "DOCLING":
534
  if not pm.is_installed("docling"): # type: ignore
535
  pm.install("docling")
536
  from docling.document_converter import DocumentConverter # type: ignore
 
637
  await rag.apipeline_process_enqueue_documents()
638
 
639
 
640
+ # TODO: deprecate after /insert_file is removed
641
  async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
642
  """Save the uploaded file to a temporary location
643
 
 
671
  if not new_files:
672
  return
673
 
674
+ # Get MAX_PARALLEL_INSERT from global_args
675
+ max_parallel = global_args.max_parallel_insert
676
  # Calculate batch size as 2 * MAX_PARALLEL_INSERT
677
  batch_size = 2 * max_parallel
678
 
 
704
  # Create combined auth dependency for document routes
705
  combined_auth = get_combined_auth_dependency(api_key)
706
 
707
+ @router.post(
708
+ "/scan", response_model=ScanResponse, dependencies=[Depends(combined_auth)]
709
+ )
710
  async def scan_for_new_documents(background_tasks: BackgroundTasks):
711
  """
712
  Trigger the scanning process for new documents.
 
716
  that fact.
717
 
718
  Returns:
719
+ ScanResponse: A response object containing the scanning status
720
  """
721
  # Start the scanning process in the background
722
  background_tasks.add_task(run_scanning_process, rag, doc_manager)
723
+ return ScanResponse(
724
+ status="scanning_started",
725
+ message="Scanning process has been initiated in the background",
726
+ )
727
 
728
+ @router.post(
729
+ "/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
730
+ )
731
  async def upload_to_input_dir(
732
  background_tasks: BackgroundTasks, file: UploadFile = File(...)
733
  ):
 
847
  logger.error(traceback.format_exc())
848
  raise HTTPException(status_code=500, detail=str(e))
849
 
850
+ # TODO: deprecated, use /upload instead
851
  @router.post(
852
  "/file", response_model=InsertResponse, dependencies=[Depends(combined_auth)]
853
  )
 
891
  logger.error(traceback.format_exc())
892
  raise HTTPException(status_code=500, detail=str(e))
893
 
894
+ # TODO: deprecated, use /upload instead
895
  @router.post(
896
  "/file_batch",
897
  response_model=InsertResponse,
 
956
  raise HTTPException(status_code=500, detail=str(e))
957
 
958
  @router.delete(
959
+ "", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)]
960
  )
961
  async def clear_documents():
962
  """
963
  Clear all documents from the RAG system.
964
 
965
+ This endpoint deletes all documents, entities, relationships, and files from the system.
966
+ It uses the storage drop methods to properly clean up all data and removes all files
967
+ from the input directory.
968
 
969
  Returns:
970
+ ClearDocumentsResponse: A response object containing the status and message.
971
+ - status="success": All documents and files were successfully cleared.
972
+ - status="partial_success": Document clear job exit with some errors.
973
+ - status="busy": Operation could not be completed because the pipeline is busy.
974
+ - status="fail": All storage drop operations failed, with message
975
+ - message: Detailed information about the operation results, including counts
976
+ of deleted files and any errors encountered.
977
 
978
  Raises:
979
+ HTTPException: Raised when a serious error occurs during the clearing process,
980
+ with status code 500 and error details in the detail field.
981
  """
982
+ from lightrag.kg.shared_storage import (
983
+ get_namespace_data,
984
+ get_pipeline_status_lock,
985
+ )
986
+
987
+ # Get pipeline status and lock
988
+ pipeline_status = await get_namespace_data("pipeline_status")
989
+ pipeline_status_lock = get_pipeline_status_lock()
990
+
991
+ # Check and set status with lock
992
+ async with pipeline_status_lock:
993
+ if pipeline_status.get("busy", False):
994
+ return ClearDocumentsResponse(
995
+ status="busy",
996
+ message="Cannot clear documents while pipeline is busy",
997
+ )
998
+ # Set busy to true
999
+ pipeline_status.update(
1000
+ {
1001
+ "busy": True,
1002
+ "job_name": "Clearing Documents",
1003
+ "job_start": datetime.now().isoformat(),
1004
+ "docs": 0,
1005
+ "batchs": 0,
1006
+ "cur_batch": 0,
1007
+ "request_pending": False, # Clear any previous request
1008
+ "latest_message": "Starting document clearing process",
1009
+ }
1010
  )
1011
+ # Cleaning history_messages without breaking it as a shared list object
1012
+ del pipeline_status["history_messages"][:]
1013
+ pipeline_status["history_messages"].append(
1014
+ "Starting document clearing process"
1015
+ )
1016
+
1017
+ try:
1018
+ # Use drop method to clear all data
1019
+ drop_tasks = []
1020
+ storages = [
1021
+ rag.text_chunks,
1022
+ rag.full_docs,
1023
+ rag.entities_vdb,
1024
+ rag.relationships_vdb,
1025
+ rag.chunks_vdb,
1026
+ rag.chunk_entity_relation_graph,
1027
+ rag.doc_status,
1028
+ ]
1029
+
1030
+ # Log storage drop start
1031
+ if "history_messages" in pipeline_status:
1032
+ pipeline_status["history_messages"].append(
1033
+ "Starting to drop storage components"
1034
+ )
1035
+
1036
+ for storage in storages:
1037
+ if storage is not None:
1038
+ drop_tasks.append(storage.drop())
1039
+
1040
+ # Wait for all drop tasks to complete
1041
+ drop_results = await asyncio.gather(*drop_tasks, return_exceptions=True)
1042
+
1043
+ # Check for errors and log results
1044
+ errors = []
1045
+ storage_success_count = 0
1046
+ storage_error_count = 0
1047
+
1048
+ for i, result in enumerate(drop_results):
1049
+ storage_name = storages[i].__class__.__name__
1050
+ if isinstance(result, Exception):
1051
+ error_msg = f"Error dropping {storage_name}: {str(result)}"
1052
+ errors.append(error_msg)
1053
+ logger.error(error_msg)
1054
+ storage_error_count += 1
1055
+ else:
1056
+ logger.info(f"Successfully dropped {storage_name}")
1057
+ storage_success_count += 1
1058
+
1059
+ # Log storage drop results
1060
+ if "history_messages" in pipeline_status:
1061
+ if storage_error_count > 0:
1062
+ pipeline_status["history_messages"].append(
1063
+ f"Dropped {storage_success_count} storage components with {storage_error_count} errors"
1064
+ )
1065
+ else:
1066
+ pipeline_status["history_messages"].append(
1067
+ f"Successfully dropped all {storage_success_count} storage components"
1068
+ )
1069
+
1070
+ # If all storage operations failed, return error status and don't proceed with file deletion
1071
+ if storage_success_count == 0 and storage_error_count > 0:
1072
+ error_message = "All storage drop operations failed. Aborting document clearing process."
1073
+ logger.error(error_message)
1074
+ if "history_messages" in pipeline_status:
1075
+ pipeline_status["history_messages"].append(error_message)
1076
+ return ClearDocumentsResponse(status="fail", message=error_message)
1077
+
1078
+ # Log file deletion start
1079
+ if "history_messages" in pipeline_status:
1080
+ pipeline_status["history_messages"].append(
1081
+ "Starting to delete files in input directory"
1082
+ )
1083
+
1084
+ # Delete all files in input_dir
1085
+ deleted_files_count = 0
1086
+ file_errors_count = 0
1087
+
1088
+ for file_path in doc_manager.input_dir.glob("**/*"):
1089
+ if file_path.is_file():
1090
+ try:
1091
+ file_path.unlink()
1092
+ deleted_files_count += 1
1093
+ except Exception as e:
1094
+ logger.error(f"Error deleting file {file_path}: {str(e)}")
1095
+ file_errors_count += 1
1096
+
1097
+ # Log file deletion results
1098
+ if "history_messages" in pipeline_status:
1099
+ if file_errors_count > 0:
1100
+ pipeline_status["history_messages"].append(
1101
+ f"Deleted {deleted_files_count} files with {file_errors_count} errors"
1102
+ )
1103
+ errors.append(f"Failed to delete {file_errors_count} files")
1104
+ else:
1105
+ pipeline_status["history_messages"].append(
1106
+ f"Successfully deleted {deleted_files_count} files"
1107
+ )
1108
+
1109
+ # Prepare final result message
1110
+ final_message = ""
1111
+ if errors:
1112
+ final_message = f"Cleared documents with some errors. Deleted {deleted_files_count} files."
1113
+ status = "partial_success"
1114
+ else:
1115
+ final_message = f"All documents cleared successfully. Deleted {deleted_files_count} files."
1116
+ status = "success"
1117
+
1118
+ # Log final result
1119
+ if "history_messages" in pipeline_status:
1120
+ pipeline_status["history_messages"].append(final_message)
1121
+
1122
+ # Return response based on results
1123
+ return ClearDocumentsResponse(status=status, message=final_message)
1124
  except Exception as e:
1125
+ error_msg = f"Error clearing documents: {str(e)}"
1126
+ logger.error(error_msg)
1127
  logger.error(traceback.format_exc())
1128
+ if "history_messages" in pipeline_status:
1129
+ pipeline_status["history_messages"].append(error_msg)
1130
  raise HTTPException(status_code=500, detail=str(e))
1131
+ finally:
1132
+ # Reset busy status after completion
1133
+ async with pipeline_status_lock:
1134
+ pipeline_status["busy"] = False
1135
+ completion_msg = "Document clearing process completed"
1136
+ pipeline_status["latest_message"] = completion_msg
1137
+ if "history_messages" in pipeline_status:
1138
+ pipeline_status["history_messages"].append(completion_msg)
1139
 
1140
  @router.get(
1141
  "/pipeline_status",
 
1208
  logger.error(traceback.format_exc())
1209
  raise HTTPException(status_code=500, detail=str(e))
1210
 
1211
+ @router.get(
1212
+ "", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)]
1213
+ )
1214
  async def documents() -> DocsStatusesResponse:
1215
  """
1216
  Get the status of all documents in the system.
 
1268
  logger.error(traceback.format_exc())
1269
  raise HTTPException(status_code=500, detail=str(e))
1270
 
1271
+ @router.post(
1272
+ "/clear_cache",
1273
+ response_model=ClearCacheResponse,
1274
+ dependencies=[Depends(combined_auth)],
1275
+ )
1276
+ async def clear_cache(request: ClearCacheRequest):
1277
+ """
1278
+ Clear cache data from the LLM response cache storage.
1279
+
1280
+ This endpoint allows clearing specific modes of cache or all cache if no modes are specified.
1281
+ Valid modes include: "default", "naive", "local", "global", "hybrid", "mix".
1282
+ - "default" represents extraction cache.
1283
+ - Other modes correspond to different query modes.
1284
+
1285
+ Args:
1286
+ request (ClearCacheRequest): The request body containing optional modes to clear.
1287
+
1288
+ Returns:
1289
+ ClearCacheResponse: A response object containing the status and message.
1290
+
1291
+ Raises:
1292
+ HTTPException: If an error occurs during cache clearing (400 for invalid modes, 500 for other errors).
1293
+ """
1294
+ try:
1295
+ # Validate modes if provided
1296
+ valid_modes = ["default", "naive", "local", "global", "hybrid", "mix"]
1297
+ if request.modes and not all(mode in valid_modes for mode in request.modes):
1298
+ invalid_modes = [
1299
+ mode for mode in request.modes if mode not in valid_modes
1300
+ ]
1301
+ raise HTTPException(
1302
+ status_code=400,
1303
+ detail=f"Invalid mode(s): {invalid_modes}. Valid modes are: {valid_modes}",
1304
+ )
1305
+
1306
+ # Call the aclear_cache method
1307
+ await rag.aclear_cache(request.modes)
1308
+
1309
+ # Prepare success message
1310
+ if request.modes:
1311
+ message = f"Successfully cleared cache for modes: {request.modes}"
1312
+ else:
1313
+ message = "Successfully cleared all cache"
1314
+
1315
+ return ClearCacheResponse(status="success", message=message)
1316
+ except HTTPException:
1317
+ # Re-raise HTTP exceptions
1318
+ raise
1319
+ except Exception as e:
1320
+ logger.error(f"Error clearing cache: {str(e)}")
1321
+ logger.error(traceback.format_exc())
1322
+ raise HTTPException(status_code=500, detail=str(e))
1323
+
1324
  return router
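The new /documents/clear_cache endpoint accepts an optional list of cache modes; sending a null modes field clears every mode. A minimal client sketch, assuming a local server without an API key:

```python
# Hedged sketch: clearing the LLM response cache through the new endpoint.
import requests

resp = requests.post(
    "http://localhost:9621/documents/clear_cache",
    json={"modes": ["default", "naive"]},  # use {"modes": None} to clear every mode
    timeout=30,
)
print(resp.json())  # e.g. {"status": "success", "message": "Successfully cleared cache for modes: ..."}
```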
lightrag/api/routers/graph_routes.py CHANGED
@@ -3,7 +3,7 @@ This module contains all graph-related routes for the LightRAG API.
3
  """
4
 
5
  from typing import Optional
6
- from fastapi import APIRouter, Depends
7
 
8
  from ..utils_api import get_combined_auth_dependency
9
 
@@ -25,23 +25,20 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
25
 
26
  @router.get("/graphs", dependencies=[Depends(combined_auth)])
27
  async def get_knowledge_graph(
28
- label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
 
 
29
  ):
30
  """
31
  Retrieve a connected subgraph of nodes where the label includes the specified label.
32
- Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
33
  When reducing the number of nodes, the prioritization criteria are as follows:
34
- 1. min_degree does not affect nodes directly connected to the matching nodes
35
- 2. Label matching nodes take precedence
36
- 3. Followed by nodes directly connected to the matching nodes
37
- 4. Finally, the degree of the nodes
38
- Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
39
 
40
  Args:
41
- label (str): Label to get knowledge graph for
42
- max_depth (int, optional): Maximum depth of graph. Defaults to 3.
43
- inclusive_search (bool, optional): If True, search for nodes that include the label. Defaults to False.
44
- min_degree (int, optional): Minimum degree of nodes. Defaults to 0.
45
 
46
  Returns:
47
  Dict[str, List[str]]: Knowledge graph for label
@@ -49,8 +46,7 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
49
  return await rag.get_knowledge_graph(
50
  node_label=label,
51
  max_depth=max_depth,
52
- inclusive=inclusive,
53
- min_degree=min_degree,
54
  )
55
 
56
  return router
 
3
  """
4
 
5
  from typing import Optional
6
+ from fastapi import APIRouter, Depends, Query
7
 
8
  from ..utils_api import get_combined_auth_dependency
9
 
 
25
 
26
  @router.get("/graphs", dependencies=[Depends(combined_auth)])
27
  async def get_knowledge_graph(
28
+ label: str = Query(..., description="Label to get knowledge graph for"),
29
+ max_depth: int = Query(3, description="Maximum depth of graph", ge=1),
30
+ max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1),
31
  ):
32
  """
33
  Retrieve a connected subgraph of nodes where the label includes the specified label.
 
34
  When reducing the number of nodes, the prioritization criteria are as follows:
35
+ 1. Hops (path length) to the starting node take precedence
36
+ 2. Followed by the degree of the nodes
 
 
 
37
 
38
  Args:
39
+ label (str): Label of the starting node
40
+ max_depth (int, optional): Maximum depth of the subgraph. Defaults to 3
41
+ max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000
 
42
 
43
  Returns:
44
  Dict[str, List[str]]: Knowledge graph for label
 
46
  return await rag.get_knowledge_graph(
47
  node_label=label,
48
  max_depth=max_depth,
49
+ max_nodes=max_nodes,
 
50
  )
51
 
52
  return router
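The /graphs endpoint now takes label, max_depth and max_nodes as explicit query parameters, replacing the old min_degree/inclusive pair. A hedged client sketch, assuming a local server without authentication:

```python
# Hedged sketch: fetching a bounded subgraph with the new query parameters.
import requests

params = {"label": "Apple", "max_depth": 3, "max_nodes": 500}
graph = requests.get("http://localhost:9621/graphs", params=params, timeout=60).json()
print(list(graph.keys()))  # typically the subgraph's nodes and edges
```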
lightrag/api/run_with_gunicorn.py CHANGED
@@ -7,14 +7,9 @@ import os
7
  import sys
8
  import signal
9
  import pipmaster as pm
10
- from lightrag.api.utils_api import parse_args, display_splash_screen, check_env_file
11
  from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
12
- from dotenv import load_dotenv
13
-
14
- # use the .env that is inside the current folder
15
- # allows to use different .env file for each lightrag instance
16
- # the OS environment variables take precedence over the .env file
17
- load_dotenv(dotenv_path=".env", override=False)
18
 
19
 
20
  def check_and_install_dependencies():
@@ -59,20 +54,17 @@ def main():
59
  signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
60
  signal.signal(signal.SIGTERM, signal_handler) # kill command
61
 
62
- # Parse all arguments using parse_args
63
- args = parse_args(is_uvicorn_mode=False)
64
-
65
  # Display startup information
66
- display_splash_screen(args)
67
 
68
  print("🚀 Starting LightRAG with Gunicorn")
69
- print(f"🔄 Worker management: Gunicorn (workers={args.workers})")
70
  print("🔍 Preloading app: Enabled")
71
  print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
72
  print("\n\n" + "=" * 80)
73
  print("MAIN PROCESS INITIALIZATION")
74
  print(f"Process ID: {os.getpid()}")
75
- print(f"Workers setting: {args.workers}")
76
  print("=" * 80 + "\n")
77
 
78
  # Import Gunicorn's StandaloneApplication
@@ -128,31 +120,43 @@ def main():
128
 
129
  # Set configuration variables in gunicorn_config, prioritizing command line arguments
130
  gunicorn_config.workers = (
131
- args.workers if args.workers else int(os.getenv("WORKERS", 1))
 
 
132
  )
133
 
134
  # Bind configuration prioritizes command line arguments
135
- host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
136
- port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621))
 
 
 
 
 
 
 
 
137
  gunicorn_config.bind = f"{host}:{port}"
138
 
139
  # Log level configuration prioritizes command line arguments
140
  gunicorn_config.loglevel = (
141
- args.log_level.lower()
142
- if args.log_level
143
  else os.getenv("LOG_LEVEL", "info")
144
  )
145
 
146
  # Timeout configuration prioritizes command line arguments
147
  gunicorn_config.timeout = (
148
- args.timeout if args.timeout * 2 else int(os.getenv("TIMEOUT", 150 * 2))
 
 
149
  )
150
 
151
  # Keepalive configuration
152
  gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
153
 
154
  # SSL configuration prioritizes command line arguments
155
- if args.ssl or os.getenv("SSL", "").lower() in (
156
  "true",
157
  "1",
158
  "yes",
@@ -160,12 +164,14 @@ def main():
160
  "on",
161
  ):
162
  gunicorn_config.certfile = (
163
- args.ssl_certfile
164
- if args.ssl_certfile
165
  else os.getenv("SSL_CERTFILE")
166
  )
167
  gunicorn_config.keyfile = (
168
- args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
 
 
169
  )
170
 
171
  # Set configuration options from the module
@@ -190,13 +196,13 @@ def main():
190
  # Import the application
191
  from lightrag.api.lightrag_server import get_application
192
 
193
- return get_application(args)
194
 
195
  # Create the application
196
  app = GunicornApp("")
197
 
198
  # Force workers to be an integer and greater than 1 for multi-process mode
199
- workers_count = int(args.workers)
200
  if workers_count > 1:
201
  # Set a flag to indicate we're in the main process
202
  os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
 
7
  import sys
8
  import signal
9
  import pipmaster as pm
10
+ from lightrag.api.utils_api import display_splash_screen, check_env_file
11
  from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
12
+ from .config import global_args
 
 
 
 
 
13
 
14
 
15
  def check_and_install_dependencies():
 
54
  signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
55
  signal.signal(signal.SIGTERM, signal_handler) # kill command
56
 
 
 
 
57
  # Display startup information
58
+ display_splash_screen(global_args)
59
 
60
  print("🚀 Starting LightRAG with Gunicorn")
61
+ print(f"🔄 Worker management: Gunicorn (workers={global_args.workers})")
62
  print("🔍 Preloading app: Enabled")
63
  print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
64
  print("\n\n" + "=" * 80)
65
  print("MAIN PROCESS INITIALIZATION")
66
  print(f"Process ID: {os.getpid()}")
67
+ print(f"Workers setting: {global_args.workers}")
68
  print("=" * 80 + "\n")
69
 
70
  # Import Gunicorn's StandaloneApplication
 
120
 
121
  # Set configuration variables in gunicorn_config, prioritizing command line arguments
122
  gunicorn_config.workers = (
123
+ global_args.workers
124
+ if global_args.workers
125
+ else int(os.getenv("WORKERS", 1))
126
  )
127
 
128
  # Bind configuration prioritizes command line arguments
129
+ host = (
130
+ global_args.host
131
+ if global_args.host != "0.0.0.0"
132
+ else os.getenv("HOST", "0.0.0.0")
133
+ )
134
+ port = (
135
+ global_args.port
136
+ if global_args.port != 9621
137
+ else int(os.getenv("PORT", 9621))
138
+ )
139
  gunicorn_config.bind = f"{host}:{port}"
140
 
141
  # Log level configuration prioritizes command line arguments
142
  gunicorn_config.loglevel = (
143
+ global_args.log_level.lower()
144
+ if global_args.log_level
145
  else os.getenv("LOG_LEVEL", "info")
146
  )
147
 
148
  # Timeout configuration prioritizes command line arguments
149
  gunicorn_config.timeout = (
150
+ global_args.timeout
151
+ if global_args.timeout * 2
152
+ else int(os.getenv("TIMEOUT", 150 * 2))
153
  )
154
 
155
  # Keepalive configuration
156
  gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
157
 
158
  # SSL configuration prioritizes command line arguments
159
+ if global_args.ssl or os.getenv("SSL", "").lower() in (
160
  "true",
161
  "1",
162
  "yes",
 
164
  "on",
165
  ):
166
  gunicorn_config.certfile = (
167
+ global_args.ssl_certfile
168
+ if global_args.ssl_certfile
169
  else os.getenv("SSL_CERTFILE")
170
  )
171
  gunicorn_config.keyfile = (
172
+ global_args.ssl_keyfile
173
+ if global_args.ssl_keyfile
174
+ else os.getenv("SSL_KEYFILE")
175
  )
176
 
177
  # Set configuration options from the module
 
196
  # Import the application
197
  from lightrag.api.lightrag_server import get_application
198
 
199
+ return get_application(global_args)
200
 
201
  # Create the application
202
  app = GunicornApp("")
203
 
204
  # Force workers to be an integer and greater than 1 for multi-process mode
205
+ workers_count = int(global_args.workers)
206
  if workers_count > 1:
207
  # Set a flag to indicate we're in the main process
208
  os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
lightrag/api/utils_api.py CHANGED
@@ -7,15 +7,13 @@ import argparse
7
  from typing import Optional, List, Tuple
8
  import sys
9
  from ascii_colors import ASCIIColors
10
- import logging
11
  from lightrag.api import __api_version__ as api_version
12
  from lightrag import __version__ as core_version
13
  from fastapi import HTTPException, Security, Request, status
14
- from dotenv import load_dotenv
15
  from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
16
  from starlette.status import HTTP_403_FORBIDDEN
17
  from .auth import auth_handler
18
- from ..prompt import PROMPTS
19
 
20
 
21
  def check_env_file():
@@ -36,16 +34,8 @@ def check_env_file():
36
  return True
37
 
38
 
39
- # use the .env that is inside the current folder
40
- # allows to use different .env file for each lightrag instance
41
- # the OS environment variables take precedence over the .env file
42
- load_dotenv(dotenv_path=".env", override=False)
43
-
44
- global_args = {"main_args": None}
45
-
46
- # Get whitelist paths from environment variable, only once during initialization
47
- default_whitelist = "/health,/api/*"
48
- whitelist_paths = os.getenv("WHITELIST_PATHS", default_whitelist).split(",")
49
 
50
  # Pre-compile path matching patterns
51
  whitelist_patterns: List[Tuple[str, bool]] = []
@@ -63,19 +53,6 @@ for path in whitelist_paths:
63
  auth_configured = bool(auth_handler.accounts)
64
 
65
 
66
- class OllamaServerInfos:
67
- # Constants for emulated Ollama model information
68
- LIGHTRAG_NAME = "lightrag"
69
- LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
70
- LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
71
- LIGHTRAG_SIZE = 7365960935 # it's a dummy value
72
- LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
73
- LIGHTRAG_DIGEST = "sha256:lightrag"
74
-
75
-
76
- ollama_server_infos = OllamaServerInfos()
77
-
78
-
79
  def get_combined_auth_dependency(api_key: Optional[str] = None):
80
  """
81
  Create a combined authentication dependency that implements authentication logic
@@ -186,299 +163,6 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
186
  return combined_dependency
187
 
188
 
189
- class DefaultRAGStorageConfig:
190
- KV_STORAGE = "JsonKVStorage"
191
- VECTOR_STORAGE = "NanoVectorDBStorage"
192
- GRAPH_STORAGE = "NetworkXStorage"
193
- DOC_STATUS_STORAGE = "JsonDocStatusStorage"
194
-
195
-
196
- def get_default_host(binding_type: str) -> str:
197
- default_hosts = {
198
- "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
199
- "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
200
- "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
201
- "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
202
- }
203
- return default_hosts.get(
204
- binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
205
- ) # fallback to ollama if unknown
206
-
207
-
208
- def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
209
- """
210
- Get value from environment variable with type conversion
211
-
212
- Args:
213
- env_key (str): Environment variable key
214
- default (any): Default value if env variable is not set
215
- value_type (type): Type to convert the value to
216
-
217
- Returns:
218
- any: Converted value from environment or default
219
- """
220
- value = os.getenv(env_key)
221
- if value is None:
222
- return default
223
-
224
- if value_type is bool:
225
- return value.lower() in ("true", "1", "yes", "t", "on")
226
- try:
227
- return value_type(value)
228
- except ValueError:
229
- return default
230
-
231
-
232
- def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
233
- """
234
- Parse command line arguments with environment variable fallback
235
-
236
- Args:
237
- is_uvicorn_mode: Whether running under uvicorn mode
238
-
239
- Returns:
240
- argparse.Namespace: Parsed arguments
241
- """
242
-
243
- parser = argparse.ArgumentParser(
244
- description="LightRAG FastAPI Server with separate working and input directories"
245
- )
246
-
247
- # Server configuration
248
- parser.add_argument(
249
- "--host",
250
- default=get_env_value("HOST", "0.0.0.0"),
251
- help="Server host (default: from env or 0.0.0.0)",
252
- )
253
- parser.add_argument(
254
- "--port",
255
- type=int,
256
- default=get_env_value("PORT", 9621, int),
257
- help="Server port (default: from env or 9621)",
258
- )
259
-
260
- # Directory configuration
261
- parser.add_argument(
262
- "--working-dir",
263
- default=get_env_value("WORKING_DIR", "./rag_storage"),
264
- help="Working directory for RAG storage (default: from env or ./rag_storage)",
265
- )
266
- parser.add_argument(
267
- "--input-dir",
268
- default=get_env_value("INPUT_DIR", "./inputs"),
269
- help="Directory containing input documents (default: from env or ./inputs)",
270
- )
271
-
272
- def timeout_type(value):
273
- if value is None:
274
- return 150
275
- if value is None or value == "None":
276
- return None
277
- return int(value)
278
-
279
- parser.add_argument(
280
- "--timeout",
281
- default=get_env_value("TIMEOUT", None, timeout_type),
282
- type=timeout_type,
283
- help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
284
- )
285
-
286
- # RAG configuration
287
- parser.add_argument(
288
- "--max-async",
289
- type=int,
290
- default=get_env_value("MAX_ASYNC", 4, int),
291
- help="Maximum async operations (default: from env or 4)",
292
- )
293
- parser.add_argument(
294
- "--max-tokens",
295
- type=int,
296
- default=get_env_value("MAX_TOKENS", 32768, int),
297
- help="Maximum token size (default: from env or 32768)",
298
- )
299
-
300
- # Logging configuration
301
- parser.add_argument(
302
- "--log-level",
303
- default=get_env_value("LOG_LEVEL", "INFO"),
304
- choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
305
- help="Logging level (default: from env or INFO)",
306
- )
307
- parser.add_argument(
308
- "--verbose",
309
- action="store_true",
310
- default=get_env_value("VERBOSE", False, bool),
311
- help="Enable verbose debug output(only valid for DEBUG log-level)",
312
- )
313
-
314
- parser.add_argument(
315
- "--key",
316
- type=str,
317
- default=get_env_value("LIGHTRAG_API_KEY", None),
318
- help="API key for authentication. This protects lightrag server against unauthorized access",
319
- )
320
-
321
- # Optional https parameters
322
- parser.add_argument(
323
- "--ssl",
324
- action="store_true",
325
- default=get_env_value("SSL", False, bool),
326
- help="Enable HTTPS (default: from env or False)",
327
- )
328
- parser.add_argument(
329
- "--ssl-certfile",
330
- default=get_env_value("SSL_CERTFILE", None),
331
- help="Path to SSL certificate file (required if --ssl is enabled)",
332
- )
333
- parser.add_argument(
334
- "--ssl-keyfile",
335
- default=get_env_value("SSL_KEYFILE", None),
336
- help="Path to SSL private key file (required if --ssl is enabled)",
337
- )
338
-
339
- parser.add_argument(
340
- "--history-turns",
341
- type=int,
342
- default=get_env_value("HISTORY_TURNS", 3, int),
343
- help="Number of conversation history turns to include (default: from env or 3)",
344
- )
345
-
346
- # Search parameters
347
- parser.add_argument(
348
- "--top-k",
349
- type=int,
350
- default=get_env_value("TOP_K", 60, int),
351
- help="Number of most similar results to return (default: from env or 60)",
352
- )
353
- parser.add_argument(
354
- "--cosine-threshold",
355
- type=float,
356
- default=get_env_value("COSINE_THRESHOLD", 0.2, float),
357
- help="Cosine similarity threshold (default: from env or 0.4)",
358
- )
359
-
360
- # Ollama model name
361
- parser.add_argument(
362
- "--simulated-model-name",
363
- type=str,
364
- default=get_env_value(
365
- "SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
366
- ),
367
- help="Number of conversation history turns to include (default: from env or 3)",
368
- )
369
-
370
- # Namespace
371
- parser.add_argument(
372
- "--namespace-prefix",
373
- type=str,
374
- default=get_env_value("NAMESPACE_PREFIX", ""),
375
- help="Prefix of the namespace",
376
- )
377
-
378
- parser.add_argument(
379
- "--auto-scan-at-startup",
380
- action="store_true",
381
- default=False,
382
- help="Enable automatic scanning when the program starts",
383
- )
384
-
385
- # Server workers configuration
386
- parser.add_argument(
387
- "--workers",
388
- type=int,
389
- default=get_env_value("WORKERS", 1, int),
390
- help="Number of worker processes (default: from env or 1)",
391
- )
392
-
393
- # LLM and embedding bindings
394
- parser.add_argument(
395
- "--llm-binding",
396
- type=str,
397
- default=get_env_value("LLM_BINDING", "ollama"),
398
- choices=["lollms", "ollama", "openai", "openai-ollama", "azure_openai"],
399
- help="LLM binding type (default: from env or ollama)",
400
- )
401
- parser.add_argument(
402
- "--embedding-binding",
403
- type=str,
404
- default=get_env_value("EMBEDDING_BINDING", "ollama"),
405
- choices=["lollms", "ollama", "openai", "azure_openai"],
406
- help="Embedding binding type (default: from env or ollama)",
407
- )
408
-
409
- args = parser.parse_args()
410
-
411
- # If in uvicorn mode and workers > 1, force it to 1 and log warning
412
- if is_uvicorn_mode and args.workers > 1:
413
- original_workers = args.workers
414
- args.workers = 1
415
- # Log warning directly here
416
- logging.warning(
417
- f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
418
- )
419
-
420
- # convert relative path to absolute path
421
- args.working_dir = os.path.abspath(args.working_dir)
422
- args.input_dir = os.path.abspath(args.input_dir)
423
-
424
- # Inject storage configuration from environment variables
425
- args.kv_storage = get_env_value(
426
- "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
427
- )
428
- args.doc_status_storage = get_env_value(
429
- "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
430
- )
431
- args.graph_storage = get_env_value(
432
- "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
433
- )
434
- args.vector_storage = get_env_value(
435
- "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
436
- )
437
-
438
- # Get MAX_PARALLEL_INSERT from environment
439
- args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
440
-
441
- # Handle openai-ollama special case
442
- if args.llm_binding == "openai-ollama":
443
- args.llm_binding = "openai"
444
- args.embedding_binding = "ollama"
445
-
446
- args.llm_binding_host = get_env_value(
447
- "LLM_BINDING_HOST", get_default_host(args.llm_binding)
448
- )
449
- args.embedding_binding_host = get_env_value(
450
- "EMBEDDING_BINDING_HOST", get_default_host(args.embedding_binding)
451
- )
452
- args.llm_binding_api_key = get_env_value("LLM_BINDING_API_KEY", None)
453
- args.embedding_binding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
454
-
455
- # Inject model configuration
456
- args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
457
- args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
458
- args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
459
- args.max_embed_tokens = get_env_value("MAX_EMBED_TOKENS", 8192, int)
460
-
461
- # Inject chunk configuration
462
- args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
463
- args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
464
-
465
- # Inject LLM cache configuration
466
- args.enable_llm_cache_for_extract = get_env_value(
467
- "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
468
- )
469
-
470
- # Inject LLM temperature configuration
471
- args.temperature = get_env_value("TEMPERATURE", 0.5, float)
472
-
473
- # Select Document loading tool (DOCLING, DEFAULT)
474
- args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
475
-
476
- ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
477
-
478
- global_args["main_args"] = args
479
- return args
480
-
481
-
482
  def display_splash_screen(args: argparse.Namespace) -> None:
483
  """
484
  Display a colorful splash screen showing LightRAG server configuration
@@ -489,7 +173,7 @@ def display_splash_screen(args: argparse.Namespace) -> None:
489
  # Banner
490
  ASCIIColors.cyan(f"""
491
  ╔══════════════════════════════════════════════════════════════╗
492
- 🚀 LightRAG Server v{core_version}/{api_version}
493
  ║ Fast, Lightweight RAG Server Implementation ║
494
  ╚══════════════════════════════════════════════════════════════╝
495
  """)
@@ -503,7 +187,7 @@ def display_splash_screen(args: argparse.Namespace) -> None:
503
  ASCIIColors.white(" ├─ Workers: ", end="")
504
  ASCIIColors.yellow(f"{args.workers}")
505
  ASCIIColors.white(" ├─ CORS Origins: ", end="")
506
- ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}")
507
  ASCIIColors.white(" ├─ SSL Enabled: ", end="")
508
  ASCIIColors.yellow(f"{args.ssl}")
509
  if args.ssl:
@@ -519,8 +203,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
519
  ASCIIColors.yellow(f"{args.verbose}")
520
  ASCIIColors.white(" ├─ History Turns: ", end="")
521
  ASCIIColors.yellow(f"{args.history_turns}")
522
- ASCIIColors.white(" └─ API Key: ", end="")
523
  ASCIIColors.yellow("Set" if args.key else "Not Set")
 
 
524
 
525
  # Directory Configuration
526
  ASCIIColors.magenta("\n📂 Directory Configuration:")
@@ -558,10 +244,9 @@ def display_splash_screen(args: argparse.Namespace) -> None:
558
  ASCIIColors.yellow(f"{args.embedding_dim}")
559
 
560
  # RAG Configuration
561
- summary_language = os.getenv("SUMMARY_LANGUAGE", PROMPTS["DEFAULT_LANGUAGE"])
562
  ASCIIColors.magenta("\n⚙️ RAG Configuration:")
563
  ASCIIColors.white(" ├─ Summary Language: ", end="")
564
- ASCIIColors.yellow(f"{summary_language}")
565
  ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
566
  ASCIIColors.yellow(f"{args.max_parallel_insert}")
567
  ASCIIColors.white(" ├─ Max Embed Tokens: ", end="")
@@ -595,19 +280,17 @@ def display_splash_screen(args: argparse.Namespace) -> None:
595
  protocol = "https" if args.ssl else "http"
596
  if args.host == "0.0.0.0":
597
  ASCIIColors.magenta("\n🌐 Server Access Information:")
598
- ASCIIColors.white(" ├─ Local Access: ", end="")
599
  ASCIIColors.yellow(f"{protocol}://localhost:{args.port}")
600
  ASCIIColors.white(" ├─ Remote Access: ", end="")
601
  ASCIIColors.yellow(f"{protocol}://<your-ip-address>:{args.port}")
602
  ASCIIColors.white(" ├─ API Documentation (local): ", end="")
603
  ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs")
604
- ASCIIColors.white(" ├─ Alternative Documentation (local): ", end="")
605
  ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc")
606
- ASCIIColors.white(" └─ WebUI (local): ", end="")
607
- ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/webui")
608
 
609
- ASCIIColors.yellow("\n📝 Note:")
610
- ASCIIColors.white(""" Since the server is running on 0.0.0.0:
611
  - Use 'localhost' or '127.0.0.1' for local access
612
  - Use your machine's IP address for remote access
613
  - To find your IP address:
@@ -617,42 +300,24 @@ def display_splash_screen(args: argparse.Namespace) -> None:
617
  else:
618
  base_url = f"{protocol}://{args.host}:{args.port}"
619
  ASCIIColors.magenta("\n🌐 Server Access Information:")
620
- ASCIIColors.white(" ├─ Base URL: ", end="")
621
  ASCIIColors.yellow(f"{base_url}")
622
  ASCIIColors.white(" ├─ API Documentation: ", end="")
623
  ASCIIColors.yellow(f"{base_url}/docs")
624
  ASCIIColors.white(" └─ Alternative Documentation: ", end="")
625
  ASCIIColors.yellow(f"{base_url}/redoc")
626
 
627
- # Usage Examples
628
- ASCIIColors.magenta("\n📚 Quick Start Guide:")
629
- ASCIIColors.cyan("""
630
- 1. Access the Swagger UI:
631
- Open your browser and navigate to the API documentation URL above
632
-
633
- 2. API Authentication:""")
634
- if args.key:
635
- ASCIIColors.cyan(""" Add the following header to your requests:
636
- X-API-Key: <your-api-key>
637
- """)
638
- else:
639
- ASCIIColors.cyan(" No authentication required\n")
640
-
641
- ASCIIColors.cyan(""" 3. Basic Operations:
642
- - POST /upload_document: Upload new documents to RAG
643
- - POST /query: Query your document collection
644
-
645
- 4. Monitor the server:
646
- - Check server logs for detailed operation information
647
- - Use healthcheck endpoint: GET /health
648
- """)
649
-
650
  # Security Notice
651
  if args.key:
652
  ASCIIColors.yellow("\n⚠️ Security Notice:")
653
  ASCIIColors.white(""" API Key authentication is enabled.
654
  Make sure to include the X-API-Key header in all your requests.
655
  """)
 
 
 
 
 
656
 
657
  # Ensure splash output flush to system log
658
  sys.stdout.flush()
 
7
  from typing import Optional, List, Tuple
8
  import sys
9
  from ascii_colors import ASCIIColors
 
10
  from lightrag.api import __api_version__ as api_version
11
  from lightrag import __version__ as core_version
12
  from fastapi import HTTPException, Security, Request, status
 
13
  from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
14
  from starlette.status import HTTP_403_FORBIDDEN
15
  from .auth import auth_handler
16
+ from .config import ollama_server_infos, global_args
17
 
18
 
19
  def check_env_file():
 
34
  return True
35
 
36
 
37
+ # Get whitelist paths from global_args, only once during initialization
38
+ whitelist_paths = global_args.whitelist_paths.split(",")
 
 
 
 
 
 
 
 
39
 
40
  # Pre-compile path matching patterns
41
  whitelist_patterns: List[Tuple[str, bool]] = []
 
53
  auth_configured = bool(auth_handler.accounts)
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def get_combined_auth_dependency(api_key: Optional[str] = None):
57
  """
58
  Create a combined authentication dependency that implements authentication logic
 
163
  return combined_dependency
164
 
165
 
 
 
 
 
166
  def display_splash_screen(args: argparse.Namespace) -> None:
167
  """
168
  Display a colorful splash screen showing LightRAG server configuration
 
173
  # Banner
174
  ASCIIColors.cyan(f"""
175
  ╔══════════════════════════════════════════════════════════════╗
176
+ 🚀 LightRAG Server v{core_version}/{api_version}
177
  ║ Fast, Lightweight RAG Server Implementation ║
178
  ╚══════════════════════════════════════════════════════════════╝
179
  """)
 
187
  ASCIIColors.white(" ├─ Workers: ", end="")
188
  ASCIIColors.yellow(f"{args.workers}")
189
  ASCIIColors.white(" ├─ CORS Origins: ", end="")
190
+ ASCIIColors.yellow(f"{args.cors_origins}")
191
  ASCIIColors.white(" ├─ SSL Enabled: ", end="")
192
  ASCIIColors.yellow(f"{args.ssl}")
193
  if args.ssl:
 
203
  ASCIIColors.yellow(f"{args.verbose}")
204
  ASCIIColors.white(" ├─ History Turns: ", end="")
205
  ASCIIColors.yellow(f"{args.history_turns}")
206
+ ASCIIColors.white(" ├─ API Key: ", end="")
207
  ASCIIColors.yellow("Set" if args.key else "Not Set")
208
+ ASCIIColors.white(" └─ JWT Auth: ", end="")
209
+ ASCIIColors.yellow("Enabled" if args.auth_accounts else "Disabled")
210
 
211
  # Directory Configuration
212
  ASCIIColors.magenta("\n📂 Directory Configuration:")
 
244
  ASCIIColors.yellow(f"{args.embedding_dim}")
245
 
246
  # RAG Configuration
 
247
  ASCIIColors.magenta("\n⚙️ RAG Configuration:")
248
  ASCIIColors.white(" ├─ Summary Language: ", end="")
249
+ ASCIIColors.yellow(f"{args.summary_language}")
250
  ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
251
  ASCIIColors.yellow(f"{args.max_parallel_insert}")
252
  ASCIIColors.white(" ├─ Max Embed Tokens: ", end="")
 
280
  protocol = "https" if args.ssl else "http"
281
  if args.host == "0.0.0.0":
282
  ASCIIColors.magenta("\n🌐 Server Access Information:")
283
+ ASCIIColors.white(" ├─ WebUI (local): ", end="")
284
  ASCIIColors.yellow(f"{protocol}://localhost:{args.port}")
285
  ASCIIColors.white(" ├─ Remote Access: ", end="")
286
  ASCIIColors.yellow(f"{protocol}://<your-ip-address>:{args.port}")
287
  ASCIIColors.white(" ├─ API Documentation (local): ", end="")
288
  ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs")
289
+ ASCIIColors.white(" └─ Alternative Documentation (local): ", end="")
290
  ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc")
 
 
291
 
292
+ ASCIIColors.magenta("\n📝 Note:")
293
+ ASCIIColors.cyan(""" Since the server is running on 0.0.0.0:
294
  - Use 'localhost' or '127.0.0.1' for local access
295
  - Use your machine's IP address for remote access
296
  - To find your IP address:
 
300
  else:
301
  base_url = f"{protocol}://{args.host}:{args.port}"
302
  ASCIIColors.magenta("\n🌐 Server Access Information:")
303
+ ASCIIColors.white(" ├─ WebUI (local): ", end="")
304
  ASCIIColors.yellow(f"{base_url}")
305
  ASCIIColors.white(" ├─ API Documentation: ", end="")
306
  ASCIIColors.yellow(f"{base_url}/docs")
307
  ASCIIColors.white(" └─ Alternative Documentation: ", end="")
308
  ASCIIColors.yellow(f"{base_url}/redoc")
309
 
 
 
 
 
310
  # Security Notice
311
  if args.key:
312
  ASCIIColors.yellow("\n⚠️ Security Notice:")
313
  ASCIIColors.white(""" API Key authentication is enabled.
314
  Make sure to include the X-API-Key header in all your requests.
315
  """)
316
+ if args.auth_accounts:
317
+ ASCIIColors.yellow("\n⚠️ Security Notice:")
318
+ ASCIIColors.white(""" JWT authentication is enabled.
319
+ Make sure to login before making the request, and include the 'Authorization' in the header.
320
+ """)
321
 
322
  # Ensure splash output flush to system log
323
  sys.stdout.flush()
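The whitelist handling above splits `global_args.whitelist_paths` and pre-compiles each entry into a `(prefix, is_wildcard)` tuple. A self-contained sketch of that matching scheme (function names are illustrative, not the module's own):

```python
# Illustrative sketch: pre-compiling "/health,/api/*"-style whitelist entries and matching a path.
from typing import List, Tuple

def compile_whitelist(raw: str) -> List[Tuple[str, bool]]:
    patterns: List[Tuple[str, bool]] = []
    for entry in raw.split(","):
        entry = entry.strip()
        if not entry:
            continue
        if entry.endswith("/*"):
            patterns.append((entry[:-2], True))   # wildcard: prefix match
        else:
            patterns.append((entry, False))       # exact match
    return patterns

def is_whitelisted(path: str, patterns: List[Tuple[str, bool]]) -> bool:
    return any(path.startswith(p) if wild else path == p for p, wild in patterns)

patterns = compile_whitelist("/health,/api/*")
assert is_whitelisted("/api/version", patterns) and not is_whitelisted("/query", patterns)
```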
lightrag/api/webui/assets/index-CD5HxTy1.css DELETED
Binary file (55.1 kB)
 
lightrag/api/webui/assets/{index-raheqJeu.js → index-Cma7xY0-.js} RENAMED
Binary files a/lightrag/api/webui/assets/index-raheqJeu.js and b/lightrag/api/webui/assets/index-Cma7xY0-.js differ
 
lightrag/api/webui/assets/index-QU59h9JG.css ADDED
Binary file (57.1 kB). View file
 
lightrag/api/webui/index.html CHANGED
Binary files a/lightrag/api/webui/index.html and b/lightrag/api/webui/index.html differ
 
lightrag/base.py CHANGED
@@ -112,6 +112,32 @@ class StorageNameSpace(ABC):
112
  async def index_done_callback(self) -> None:
113
  """Commit the storage operations after indexing"""
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  @dataclass
117
  class BaseVectorStorage(StorageNameSpace, ABC):
@@ -127,15 +153,33 @@ class BaseVectorStorage(StorageNameSpace, ABC):
127
 
128
  @abstractmethod
129
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
130
- """Insert or update vectors in the storage."""
 
 
 
 
 
 
131
 
132
  @abstractmethod
133
  async def delete_entity(self, entity_name: str) -> None:
134
- """Delete a single entity by its name."""
 
 
 
 
 
 
135
 
136
  @abstractmethod
137
  async def delete_entity_relation(self, entity_name: str) -> None:
138
- """Delete relations for a given entity."""
 
 
 
 
 
 
139
 
140
  @abstractmethod
141
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -161,6 +205,19 @@ class BaseVectorStorage(StorageNameSpace, ABC):
161
  """
162
  pass
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  @dataclass
166
  class BaseKVStorage(StorageNameSpace, ABC):
@@ -180,7 +237,42 @@ class BaseKVStorage(StorageNameSpace, ABC):
180
 
181
  @abstractmethod
182
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
183
- """Upsert data"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
 
186
  @dataclass
@@ -205,13 +297,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
205
 
206
  @abstractmethod
207
  async def get_node(self, node_id: str) -> dict[str, str] | None:
208
- """Get an edge by its source and target node ids."""
209
 
210
  @abstractmethod
211
  async def get_edge(
212
  self, source_node_id: str, target_node_id: str
213
  ) -> dict[str, str] | None:
214
- """Get all edges connected to a node."""
215
 
216
  @abstractmethod
217
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
@@ -225,7 +317,13 @@ class BaseGraphStorage(StorageNameSpace, ABC):
225
  async def upsert_edge(
226
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
227
  ) -> None:
228
- """Delete a node from the graph."""
 
 
 
 
 
 
229
 
230
  @abstractmethod
231
  async def delete_node(self, node_id: str) -> None:
@@ -243,9 +341,20 @@ class BaseGraphStorage(StorageNameSpace, ABC):
243
 
244
  @abstractmethod
245
  async def get_knowledge_graph(
246
- self, node_label: str, max_depth: int = 3
247
  ) -> KnowledgeGraph:
248
- """Retrieve a subgraph of the knowledge graph starting from a given node."""
 
 
 
 
 
 
 
 
 
 
 
249
 
250
 
251
  class DocStatus(str, Enum):
@@ -297,6 +406,10 @@ class DocStatusStorage(BaseKVStorage, ABC):
297
  ) -> dict[str, DocProcessingStatus]:
298
  """Get all documents with a specific status"""
299
 
 
 
 
 
300
 
301
  class StoragesStatus(str, Enum):
302
  """Storages status"""
 
112
  async def index_done_callback(self) -> None:
113
  """Commit the storage operations after indexing"""
114
 
115
+ @abstractmethod
116
+ async def drop(self) -> dict[str, str]:
117
+ """Drop all data from storage and clean up resources
118
+
119
+ This abstract method defines the contract for dropping all data from a storage implementation.
120
+ Each storage type must implement this method to:
121
+ 1. Clear all data from memory and/or external storage
122
+ 2. Remove any associated storage files if applicable
123
+ 3. Reset the storage to its initial state
124
+ 4. Handle cleanup of any resources
125
+ 5. Notify other processes if necessary
126
+ 6. This action should persist the data to disk immediately.
127
+
128
+ Returns:
129
+ dict[str, str]: Operation status and message with the following format:
130
+ {
131
+ "status": str, # "success" or "error"
132
+ "message": str # "data dropped" on success, error details on failure
133
+ }
134
+
135
+ Implementation specific:
136
+ - On success: return {"status": "success", "message": "data dropped"}
137
+ - On failure: return {"status": "error", "message": "<error details>"}
138
+ - If not supported: return {"status": "error", "message": "unsupported"}
139
+ """
140
+
141
 
142
  @dataclass
143
  class BaseVectorStorage(StorageNameSpace, ABC):
 
153
 
154
  @abstractmethod
155
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
156
+ """Insert or update vectors in the storage.
157
+
158
+ Important notes for in-memory storage:
159
+ 1. Changes will be persisted to disk during the next index_done_callback
160
+ 2. Only one process should update the storage at a time before index_done_callback,
161
+ KG-storage-log should be used to avoid data corruption
162
+ """
163
 
164
  @abstractmethod
165
  async def delete_entity(self, entity_name: str) -> None:
166
+ """Delete a single entity by its name.
167
+
168
+ Important notes for in-memory storage:
169
+ 1. Changes will be persisted to disk during the next index_done_callback
170
+ 2. Only one process should update the storage at a time before index_done_callback,
171
+ KG-storage-log should be used to avoid data corruption
172
+ """
173
 
174
  @abstractmethod
175
  async def delete_entity_relation(self, entity_name: str) -> None:
176
+ """Delete relations for a given entity.
177
+
178
+ Important notes for in-memory storage:
179
+ 1. Changes will be persisted to disk during the next index_done_callback
180
+ 2. Only one process should update the storage at a time before index_done_callback,
181
+ KG-storage-log should be used to avoid data corruption
182
+ """
183
 
184
  @abstractmethod
185
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
 
205
  """
206
  pass
207
 
208
+ @abstractmethod
209
+ async def delete(self, ids: list[str]):
210
+ """Delete vectors with specified IDs
211
+
212
+ Important notes for in-memory storage:
213
+ 1. Changes will be persisted to disk during the next index_done_callback
214
+ 2. Only one process should update the storage at a time before index_done_callback,
215
+ KG-storage-log should be used to avoid data corruption
216
+
217
+ Args:
218
+ ids: List of vector IDs to be deleted
219
+ """
220
+
221
 
222
  @dataclass
223
  class BaseKVStorage(StorageNameSpace, ABC):
 
237
 
238
  @abstractmethod
239
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
240
+ """Upsert data
241
+
242
+ Important notes for in-memory storage:
243
+ 1. Changes will be persisted to disk during the next index_done_callback
244
+ 2. update flags to notify other processes that data persistence is needed
245
+ """
246
+
247
+ @abstractmethod
248
+ async def delete(self, ids: list[str]) -> None:
249
+ """Delete specific records from storage by their IDs
250
+
251
+ Important notes for in-memory storage:
252
+ 1. Changes will be persisted to disk during the next index_done_callback
253
+ 2. update flags to notify other processes that data persistence is needed
254
+
255
+ Args:
256
+ ids (list[str]): List of document IDs to be deleted from storage
257
+
258
+ Returns:
259
+ None
260
+ """
261
+
262
+ async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
263
+ """Delete specific records from storage by cache mode
264
+
265
+ Important notes for in-memory storage:
266
+ 1. Changes will be persisted to disk during the next index_done_callback
267
+ 2. update flags to notify other processes that data persistence is needed
268
+
269
+ Args:
270
+ modes (list[str]): List of cache modes to be dropped from storage
271
+
272
+ Returns:
273
+ True: if the cache drop succeeded
274
+ False: if the cache drop failed, or the cache mode is not supported
275
+ """
276
 
277
 
278
  @dataclass
 
297
 
298
  @abstractmethod
299
  async def get_node(self, node_id: str) -> dict[str, str] | None:
300
+ """Get node by its label identifier, return only node properties"""
301
 
302
  @abstractmethod
303
  async def get_edge(
304
  self, source_node_id: str, target_node_id: str
305
  ) -> dict[str, str] | None:
306
+ """Get edge properties between two nodes"""
307
 
308
  @abstractmethod
309
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
 
317
  async def upsert_edge(
318
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
319
  ) -> None:
320
+ """Delete a node from the graph.
321
+
322
+ Important notes for in-memory storage:
323
+ 1. Changes will be persisted to disk during the next index_done_callback
324
+ 2. Only one process should update the storage at a time before index_done_callback,
325
+ KG-storage-log should be used to avoid data corruption
326
+ """
327
 
328
  @abstractmethod
329
  async def delete_node(self, node_id: str) -> None:
 
341
 
342
  @abstractmethod
343
  async def get_knowledge_graph(
344
+ self, node_label: str, max_depth: int = 3, max_nodes: int = 1000
345
  ) -> KnowledgeGraph:
346
+ """
347
+ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
348
+
349
+ Args:
350
+ node_label: Label of the starting node; "*" means all nodes
351
+ max_depth: Maximum depth of the subgraph. Defaults to 3
352
+ max_nodes: Maximum number of nodes to return. Defaults to 1000 (BFS if possible)
353
+
354
+ Returns:
355
+ KnowledgeGraph object containing nodes and edges, with an is_truncated flag
356
+ indicating whether the graph was truncated due to max_nodes limit
357
+ """
358
 
359
 
360
  class DocStatus(str, Enum):
 
406
  ) -> dict[str, DocProcessingStatus]:
407
  """Get all documents with a specific status"""
408
 
409
+ async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
410
+ """Drop cache is not supported for Doc Status storage"""
411
+ return False
412
+
413
 
414
  class StoragesStatus(str, Enum):
415
  """Storages status"""
lightrag/kg/__init__.py CHANGED
@@ -2,11 +2,10 @@ STORAGE_IMPLEMENTATIONS = {
2
  "KV_STORAGE": {
3
  "implementations": [
4
  "JsonKVStorage",
5
- "MongoKVStorage",
6
  "RedisKVStorage",
7
- "TiDBKVStorage",
8
  "PGKVStorage",
9
- "OracleKVStorage",
 
10
  ],
11
  "required_methods": ["get_by_id", "upsert"],
12
  },
@@ -14,12 +13,11 @@ STORAGE_IMPLEMENTATIONS = {
14
  "implementations": [
15
  "NetworkXStorage",
16
  "Neo4JStorage",
17
- "MongoGraphStorage",
18
- "TiDBGraphStorage",
19
- "AGEStorage",
20
- "GremlinStorage",
21
  "PGGraphStorage",
22
- "OracleGraphStorage",
 
 
 
23
  ],
24
  "required_methods": ["upsert_node", "upsert_edge"],
25
  },
@@ -28,12 +26,11 @@ STORAGE_IMPLEMENTATIONS = {
28
  "NanoVectorDBStorage",
29
  "MilvusVectorDBStorage",
30
  "ChromaVectorDBStorage",
31
- "TiDBVectorDBStorage",
32
  "PGVectorStorage",
33
  "FaissVectorDBStorage",
34
  "QdrantVectorDBStorage",
35
- "OracleVectorDBStorage",
36
  "MongoVectorDBStorage",
 
37
  ],
38
  "required_methods": ["query", "upsert"],
39
  },
@@ -41,7 +38,6 @@ STORAGE_IMPLEMENTATIONS = {
41
  "implementations": [
42
  "JsonDocStatusStorage",
43
  "PGDocStatusStorage",
44
- "PGDocStatusStorage",
45
  "MongoDocStatusStorage",
46
  ],
47
  "required_methods": ["get_docs_by_status"],
@@ -54,50 +50,32 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
54
  "JsonKVStorage": [],
55
  "MongoKVStorage": [],
56
  "RedisKVStorage": ["REDIS_URI"],
57
- "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
58
  "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
59
- "OracleKVStorage": [
60
- "ORACLE_DSN",
61
- "ORACLE_USER",
62
- "ORACLE_PASSWORD",
63
- "ORACLE_CONFIG_DIR",
64
- ],
65
  # Graph Storage Implementations
66
  "NetworkXStorage": [],
67
  "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
68
  "MongoGraphStorage": [],
69
- "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
70
  "AGEStorage": [
71
  "AGE_POSTGRES_DB",
72
  "AGE_POSTGRES_USER",
73
  "AGE_POSTGRES_PASSWORD",
74
  ],
75
- "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
76
  "PGGraphStorage": [
77
  "POSTGRES_USER",
78
  "POSTGRES_PASSWORD",
79
  "POSTGRES_DATABASE",
80
  ],
81
- "OracleGraphStorage": [
82
- "ORACLE_DSN",
83
- "ORACLE_USER",
84
- "ORACLE_PASSWORD",
85
- "ORACLE_CONFIG_DIR",
86
- ],
87
  # Vector Storage Implementations
88
  "NanoVectorDBStorage": [],
89
  "MilvusVectorDBStorage": [],
90
  "ChromaVectorDBStorage": [],
91
- "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
92
  "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
93
  "FaissVectorDBStorage": [],
94
  "QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
95
- "OracleVectorDBStorage": [
96
- "ORACLE_DSN",
97
- "ORACLE_USER",
98
- "ORACLE_PASSWORD",
99
- "ORACLE_CONFIG_DIR",
100
- ],
101
  "MongoVectorDBStorage": [],
102
  # Document Status Storage Implementations
103
  "JsonDocStatusStorage": [],
@@ -112,9 +90,6 @@ STORAGES = {
112
  "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
113
  "JsonDocStatusStorage": ".kg.json_doc_status_impl",
114
  "Neo4JStorage": ".kg.neo4j_impl",
115
- "OracleKVStorage": ".kg.oracle_impl",
116
- "OracleGraphStorage": ".kg.oracle_impl",
117
- "OracleVectorDBStorage": ".kg.oracle_impl",
118
  "MilvusVectorDBStorage": ".kg.milvus_impl",
119
  "MongoKVStorage": ".kg.mongo_impl",
120
  "MongoDocStatusStorage": ".kg.mongo_impl",
@@ -122,14 +97,14 @@ STORAGES = {
122
  "MongoVectorDBStorage": ".kg.mongo_impl",
123
  "RedisKVStorage": ".kg.redis_impl",
124
  "ChromaVectorDBStorage": ".kg.chroma_impl",
125
- "TiDBKVStorage": ".kg.tidb_impl",
126
- "TiDBVectorDBStorage": ".kg.tidb_impl",
127
- "TiDBGraphStorage": ".kg.tidb_impl",
128
  "PGKVStorage": ".kg.postgres_impl",
129
  "PGVectorStorage": ".kg.postgres_impl",
130
  "AGEStorage": ".kg.age_impl",
131
  "PGGraphStorage": ".kg.postgres_impl",
132
- "GremlinStorage": ".kg.gremlin_impl",
133
  "PGDocStatusStorage": ".kg.postgres_impl",
134
  "FaissVectorDBStorage": ".kg.faiss_impl",
135
  "QdrantVectorDBStorage": ".kg.qdrant_impl",
 
2
  "KV_STORAGE": {
3
  "implementations": [
4
  "JsonKVStorage",
 
5
  "RedisKVStorage",
 
6
  "PGKVStorage",
7
+ "MongoKVStorage",
8
+ # "TiDBKVStorage",
9
  ],
10
  "required_methods": ["get_by_id", "upsert"],
11
  },
 
13
  "implementations": [
14
  "NetworkXStorage",
15
  "Neo4JStorage",
 
 
 
 
16
  "PGGraphStorage",
17
+ # "AGEStorage",
18
+ # "MongoGraphStorage",
19
+ # "TiDBGraphStorage",
20
+ # "GremlinStorage",
21
  ],
22
  "required_methods": ["upsert_node", "upsert_edge"],
23
  },
 
26
  "NanoVectorDBStorage",
27
  "MilvusVectorDBStorage",
28
  "ChromaVectorDBStorage",
 
29
  "PGVectorStorage",
30
  "FaissVectorDBStorage",
31
  "QdrantVectorDBStorage",
 
32
  "MongoVectorDBStorage",
33
+ # "TiDBVectorDBStorage",
34
  ],
35
  "required_methods": ["query", "upsert"],
36
  },
 
38
  "implementations": [
39
  "JsonDocStatusStorage",
40
  "PGDocStatusStorage",
 
41
  "MongoDocStatusStorage",
42
  ],
43
  "required_methods": ["get_docs_by_status"],
 
50
  "JsonKVStorage": [],
51
  "MongoKVStorage": [],
52
  "RedisKVStorage": ["REDIS_URI"],
53
+ # "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
54
  "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
 
 
 
 
 
 
55
  # Graph Storage Implementations
56
  "NetworkXStorage": [],
57
  "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
58
  "MongoGraphStorage": [],
59
+ # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
60
  "AGEStorage": [
61
  "AGE_POSTGRES_DB",
62
  "AGE_POSTGRES_USER",
63
  "AGE_POSTGRES_PASSWORD",
64
  ],
65
+ # "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
66
  "PGGraphStorage": [
67
  "POSTGRES_USER",
68
  "POSTGRES_PASSWORD",
69
  "POSTGRES_DATABASE",
70
  ],
 
 
 
 
 
 
71
  # Vector Storage Implementations
72
  "NanoVectorDBStorage": [],
73
  "MilvusVectorDBStorage": [],
74
  "ChromaVectorDBStorage": [],
75
+ # "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
76
  "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
77
  "FaissVectorDBStorage": [],
78
  "QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
 
 
 
 
 
 
79
  "MongoVectorDBStorage": [],
80
  # Document Status Storage Implementations
81
  "JsonDocStatusStorage": [],
 
90
  "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
91
  "JsonDocStatusStorage": ".kg.json_doc_status_impl",
92
  "Neo4JStorage": ".kg.neo4j_impl",
 
 
 
93
  "MilvusVectorDBStorage": ".kg.milvus_impl",
94
  "MongoKVStorage": ".kg.mongo_impl",
95
  "MongoDocStatusStorage": ".kg.mongo_impl",
 
97
  "MongoVectorDBStorage": ".kg.mongo_impl",
98
  "RedisKVStorage": ".kg.redis_impl",
99
  "ChromaVectorDBStorage": ".kg.chroma_impl",
100
+ # "TiDBKVStorage": ".kg.tidb_impl",
101
+ # "TiDBVectorDBStorage": ".kg.tidb_impl",
102
+ # "TiDBGraphStorage": ".kg.tidb_impl",
103
  "PGKVStorage": ".kg.postgres_impl",
104
  "PGVectorStorage": ".kg.postgres_impl",
105
  "AGEStorage": ".kg.age_impl",
106
  "PGGraphStorage": ".kg.postgres_impl",
107
+ # "GremlinStorage": ".kg.gremlin_impl",
108
  "PGDocStatusStorage": ".kg.postgres_impl",
109
  "FaissVectorDBStorage": ".kg.faiss_impl",
110
  "QdrantVectorDBStorage": ".kg.qdrant_impl",
lightrag/kg/age_impl.py CHANGED
@@ -34,9 +34,9 @@ if not pm.is_installed("psycopg-pool"):
34
  if not pm.is_installed("asyncpg"):
35
  pm.install("asyncpg")
36
 
37
- import psycopg
38
- from psycopg.rows import namedtuple_row
39
- from psycopg_pool import AsyncConnectionPool, PoolTimeout
40
 
41
 
42
  class AGEQueryException(Exception):
@@ -871,3 +871,21 @@ class AGEStorage(BaseGraphStorage):
871
  async def index_done_callback(self) -> None:
872
  # AGES handles persistence automatically
873
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  if not pm.is_installed("asyncpg"):
35
  pm.install("asyncpg")
36
 
37
+ import psycopg # type: ignore
38
+ from psycopg.rows import namedtuple_row # type: ignore
39
+ from psycopg_pool import AsyncConnectionPool, PoolTimeout # type: ignore
40
 
41
 
42
  class AGEQueryException(Exception):
 
871
  async def index_done_callback(self) -> None:
872
  # AGES handles persistence automatically
873
  pass
874
+
875
+ async def drop(self) -> dict[str, str]:
876
+ """Drop the storage by removing all nodes and relationships in the graph.
877
+
878
+ Returns:
879
+ dict[str, str]: Status of the operation with keys 'status' and 'message'
880
+ """
881
+ try:
882
+ query = """
883
+ MATCH (n)
884
+ DETACH DELETE n
885
+ """
886
+ await self._query(query)
887
+ logger.info(f"Successfully dropped all data from graph {self.graph_name}")
888
+ return {"status": "success", "message": "graph data dropped"}
889
+ except Exception as e:
890
+ logger.error(f"Error dropping graph {self.graph_name}: {e}")
891
+ return {"status": "error", "message": str(e)}
lightrag/kg/chroma_impl.py CHANGED
@@ -1,4 +1,5 @@
1
  import asyncio
 
2
  from dataclasses import dataclass
3
  from typing import Any, final
4
  import numpy as np
@@ -10,8 +11,8 @@ import pipmaster as pm
10
  if not pm.is_installed("chromadb"):
11
  pm.install("chromadb")
12
 
13
- from chromadb import HttpClient, PersistentClient
14
- from chromadb.config import Settings
15
 
16
 
17
  @final
@@ -335,3 +336,28 @@ class ChromaVectorDBStorage(BaseVectorStorage):
335
  except Exception as e:
336
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
337
  return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
+ import os
3
  from dataclasses import dataclass
4
  from typing import Any, final
5
  import numpy as np
 
11
  if not pm.is_installed("chromadb"):
12
  pm.install("chromadb")
13
 
14
+ from chromadb import HttpClient, PersistentClient # type: ignore
15
+ from chromadb.config import Settings # type: ignore
16
 
17
 
18
  @final
 
336
  except Exception as e:
337
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
338
  return []
339
+
340
+ async def drop(self) -> dict[str, str]:
341
+ """Drop all vector data from storage and clean up resources
342
+
343
+ This method will delete all documents from the ChromaDB collection.
344
+
345
+ Returns:
346
+ dict[str, str]: Operation status and message
347
+ - On success: {"status": "success", "message": "data dropped"}
348
+ - On failure: {"status": "error", "message": "<error details>"}
349
+ """
350
+ try:
351
+ # Get all IDs in the collection
352
+ result = self._collection.get(include=[])
353
+ if result and result["ids"] and len(result["ids"]) > 0:
354
+ # Delete all documents
355
+ self._collection.delete(ids=result["ids"])
356
+
357
+ logger.info(
358
+ f"Process {os.getpid()} drop ChromaDB collection {self.namespace}"
359
+ )
360
+ return {"status": "success", "message": "data dropped"}
361
+ except Exception as e:
362
+ logger.error(f"Error dropping ChromaDB collection {self.namespace}: {e}")
363
+ return {"status": "error", "message": str(e)}
lightrag/kg/faiss_impl.py CHANGED
@@ -11,16 +11,20 @@ import pipmaster as pm
11
  from lightrag.utils import logger, compute_mdhash_id
12
  from lightrag.base import BaseVectorStorage
13
 
14
- if not pm.is_installed("faiss"):
15
- pm.install("faiss")
16
-
17
- import faiss # type: ignore
18
  from .shared_storage import (
19
  get_storage_lock,
20
  get_update_flag,
21
  set_all_update_flags,
22
  )
23
 
 
 
 
 
 
 
 
 
24
 
25
  @final
26
  @dataclass
@@ -217,6 +221,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
217
  async def delete(self, ids: list[str]):
218
  """
219
  Delete vectors for the provided custom IDs.
 
 
 
 
 
220
  """
221
  logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
222
  to_remove = []
@@ -232,13 +241,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
232
  )
233
 
234
  async def delete_entity(self, entity_name: str) -> None:
 
 
 
 
 
 
235
  entity_id = compute_mdhash_id(entity_name, prefix="ent-")
236
  logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
237
  await self.delete([entity_id])
238
 
239
  async def delete_entity_relation(self, entity_name: str) -> None:
240
  """
241
- Delete relations for a given entity by scanning metadata.
 
 
 
242
  """
243
  logger.debug(f"Searching relations for entity {entity_name}")
244
  relations = []
@@ -429,3 +447,44 @@ class FaissVectorDBStorage(BaseVectorStorage):
429
  results.append({**metadata, "id": metadata.get("__id__")})
430
 
431
  return results
 
 
 
11
  from lightrag.utils import logger, compute_mdhash_id
12
  from lightrag.base import BaseVectorStorage
13
 
 
 
 
 
14
  from .shared_storage import (
15
  get_storage_lock,
16
  get_update_flag,
17
  set_all_update_flags,
18
  )
19
 
20
+ import faiss # type: ignore
21
+
22
+ USE_GPU = os.getenv("FAISS_USE_GPU", "0") == "1"
23
+ FAISS_PACKAGE = "faiss-gpu" if USE_GPU else "faiss-cpu"
24
+
25
+ if not pm.is_installed(FAISS_PACKAGE):
26
+ pm.install(FAISS_PACKAGE)
27
+
28
 
29
  @final
30
  @dataclass
 
221
  async def delete(self, ids: list[str]):
222
  """
223
  Delete vectors for the provided custom IDs.
224
+
225
+ Important notes:
226
+ 1. Changes will be persisted to disk during the next index_done_callback
227
+ 2. Only one process should update the storage at a time before index_done_callback,
228
+ KG-storage-log should be used to avoid data corruption
229
  """
230
  logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
231
  to_remove = []
 
241
  )
242
 
243
  async def delete_entity(self, entity_name: str) -> None:
244
+ """
245
+ Important notes:
246
+ 1. Changes will be persisted to disk during the next index_done_callback
247
+ 2. Only one process should update the storage at a time before index_done_callback,
248
+ KG-storage-log should be used to avoid data corruption
249
+ """
250
  entity_id = compute_mdhash_id(entity_name, prefix="ent-")
251
  logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
252
  await self.delete([entity_id])
253
 
254
  async def delete_entity_relation(self, entity_name: str) -> None:
255
  """
256
+ Important notes:
257
+ 1. Changes will be persisted to disk during the next index_done_callback
258
+ 2. Only one process should update the storage at a time before index_done_callback,
259
+ KG-storage-log should be used to avoid data corruption
260
  """
261
  logger.debug(f"Searching relations for entity {entity_name}")
262
  relations = []
 
447
  results.append({**metadata, "id": metadata.get("__id__")})
448
 
449
  return results
450
+
451
+ async def drop(self) -> dict[str, str]:
452
+ """Drop all vector data from storage and clean up resources
453
+
454
+ This method will:
455
+ 1. Remove the vector database storage file if it exists
456
+ 2. Reinitialize the vector database client
457
+ 3. Update flags to notify other processes
458
+ 4. Changes are persisted to disk immediately
459
+
460
+ This method will remove all vectors from the Faiss index and delete the storage files.
461
+
462
+ Returns:
463
+ dict[str, str]: Operation status and message
464
+ - On success: {"status": "success", "message": "data dropped"}
465
+ - On failure: {"status": "error", "message": "<error details>"}
466
+ """
467
+ try:
468
+ async with self._storage_lock:
469
+ # Reset the index
470
+ self._index = faiss.IndexFlatIP(self._dim)
471
+ self._id_to_meta = {}
472
+
473
+ # Remove storage files if they exist
474
+ if os.path.exists(self._faiss_index_file):
475
+ os.remove(self._faiss_index_file)
476
+ if os.path.exists(self._meta_file):
477
+ os.remove(self._meta_file)
478
+
479
+ self._id_to_meta = {}
480
+ self._load_faiss_index()
481
+
482
+ # Notify other processes
483
+ await set_all_update_flags(self.namespace)
484
+ self.storage_updated.value = False
485
+
486
+ logger.info(f"Process {os.getpid()} drop FAISS index {self.namespace}")
487
+ return {"status": "success", "message": "data dropped"}
488
+ except Exception as e:
489
+ logger.error(f"Error dropping FAISS index {self.namespace}: {e}")
490
+ return {"status": "error", "message": str(e)}
lightrag/kg/gremlin_impl.py CHANGED
@@ -24,9 +24,9 @@ from ..base import BaseGraphStorage
24
  if not pm.is_installed("gremlinpython"):
25
  pm.install("gremlinpython")
26
 
27
- from gremlin_python.driver import client, serializer
28
- from gremlin_python.driver.aiohttp.transport import AiohttpTransport
29
- from gremlin_python.driver.protocol import GremlinServerError
30
 
31
 
32
  @final
@@ -695,3 +695,24 @@ class GremlinStorage(BaseGraphStorage):
695
  except Exception as e:
696
  logger.error(f"Error during edge deletion: {str(e)}")
697
  raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  if not pm.is_installed("gremlinpython"):
25
  pm.install("gremlinpython")
26
 
27
+ from gremlin_python.driver import client, serializer # type: ignore
28
+ from gremlin_python.driver.aiohttp.transport import AiohttpTransport # type: ignore
29
+ from gremlin_python.driver.protocol import GremlinServerError # type: ignore
30
 
31
 
32
  @final
 
695
  except Exception as e:
696
  logger.error(f"Error during edge deletion: {str(e)}")
697
  raise
698
+
699
+ async def drop(self) -> dict[str, str]:
700
+ """Drop the storage by removing all nodes and relationships in the graph.
701
+
702
+ This function deletes all nodes with the specified graph name property,
703
+ which automatically removes all associated edges.
704
+
705
+ Returns:
706
+ dict[str, str]: Status of the operation with keys 'status' and 'message'
707
+ """
708
+ try:
709
+ query = f"""g
710
+ .V().has('graph', {self.graph_name})
711
+ .drop()
712
+ """
713
+ await self._query(query)
714
+ logger.info(f"Successfully dropped all data from graph {self.graph_name}")
715
+ return {"status": "success", "message": "graph data dropped"}
716
+ except Exception as e:
717
+ logger.error(f"Error dropping graph {self.graph_name}: {e}")
718
+ return {"status": "error", "message": str(e)}
lightrag/kg/json_doc_status_impl.py CHANGED
@@ -109,6 +109,11 @@ class JsonDocStatusStorage(DocStatusStorage):
109
  await clear_all_update_flags(self.namespace)
110
 
111
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
 
 
 
 
 
112
  if not data:
113
  return
114
  logger.info(f"Inserting {len(data)} records to {self.namespace}")
@@ -122,16 +127,50 @@ class JsonDocStatusStorage(DocStatusStorage):
122
  async with self._storage_lock:
123
  return self._data.get(id)
124
 
125
- async def delete(self, doc_ids: list[str]):
 
 
 
 
 
 
 
 
 
 
 
 
126
  async with self._storage_lock:
 
127
  for doc_id in doc_ids:
128
- self._data.pop(doc_id, None)
129
- await set_all_update_flags(self.namespace)
130
- await self.index_done_callback()
131
 
132
- async def drop(self) -> None:
133
- """Drop the storage"""
134
- async with self._storage_lock:
135
- self._data.clear()
136
- await set_all_update_flags(self.namespace)
137
- await self.index_done_callback()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  await clear_all_update_flags(self.namespace)
110
 
111
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
112
+ """
113
+ Importance notes for in-memory storage:
114
+ 1. Changes will be persisted to disk during the next index_done_callback
115
+ 2. update flags to notify other processes that data persistence is needed
116
+ """
117
  if not data:
118
  return
119
  logger.info(f"Inserting {len(data)} records to {self.namespace}")
 
127
  async with self._storage_lock:
128
  return self._data.get(id)
129
 
130
+ async def delete(self, doc_ids: list[str]) -> None:
131
+ """Delete specific records from storage by their IDs
132
+
133
+ Important notes for in-memory storage:
134
+ 1. Changes will be persisted to disk during the next index_done_callback
135
+ 2. update flags to notify other processes that data persistence is needed
136
+
137
+ Args:
138
+ doc_ids (list[str]): List of document IDs to be deleted from storage
139
+
140
+ Returns:
141
+ None
142
+ """
143
  async with self._storage_lock:
144
+ any_deleted = False
145
  for doc_id in doc_ids:
146
+ result = self._data.pop(doc_id, None)
147
+ if result is not None:
148
+ any_deleted = True
149
 
150
+ if any_deleted:
151
+ await set_all_update_flags(self.namespace)
152
+
153
+ async def drop(self) -> dict[str, str]:
154
+ """Drop all document status data from storage and clean up resources
155
+
156
+ This method will:
157
+ 1. Clear all document status data from memory
158
+ 2. Update flags to notify other processes
159
+ 3. Trigger index_done_callback to save the empty state
160
+
161
+ Returns:
162
+ dict[str, str]: Operation status and message
163
+ - On success: {"status": "success", "message": "data dropped"}
164
+ - On failure: {"status": "error", "message": "<error details>"}
165
+ """
166
+ try:
167
+ async with self._storage_lock:
168
+ self._data.clear()
169
+ await set_all_update_flags(self.namespace)
170
+
171
+ await self.index_done_callback()
172
+ logger.info(f"Process {os.getpid()} drop {self.namespace}")
173
+ return {"status": "success", "message": "data dropped"}
174
+ except Exception as e:
175
+ logger.error(f"Error dropping {self.namespace}: {e}")
176
+ return {"status": "error", "message": str(e)}
lightrag/kg/json_kv_impl.py CHANGED
@@ -114,6 +114,11 @@ class JsonKVStorage(BaseKVStorage):
114
  return set(keys) - set(self._data.keys())
115
 
116
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
 
 
 
 
 
117
  if not data:
118
  return
119
  logger.info(f"Inserting {len(data)} records to {self.namespace}")
@@ -122,8 +127,73 @@ class JsonKVStorage(BaseKVStorage):
122
  await set_all_update_flags(self.namespace)
123
 
124
  async def delete(self, ids: list[str]) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
125
  async with self._storage_lock:
 
126
  for doc_id in ids:
127
- self._data.pop(doc_id, None)
128
- await set_all_update_flags(self.namespace)
129
- await self.index_done_callback()
 
 
 
 
 
114
  return set(keys) - set(self._data.keys())
115
 
116
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
117
+ """
118
+ Important notes for in-memory storage:
119
+ 1. Changes will be persisted to disk during the next index_done_callback
120
+ 2. update flags to notify other processes that data persistence is needed
121
+ """
122
  if not data:
123
  return
124
  logger.info(f"Inserting {len(data)} records to {self.namespace}")
 
127
  await set_all_update_flags(self.namespace)
128
 
129
  async def delete(self, ids: list[str]) -> None:
130
+ """Delete specific records from storage by their IDs
131
+
132
+ Important notes for in-memory storage:
133
+ 1. Changes will be persisted to disk during the next index_done_callback
134
+ 2. update flags to notify other processes that data persistence is needed
135
+
136
+ Args:
137
+ ids (list[str]): List of document IDs to be deleted from storage
138
+
139
+ Returns:
140
+ None
141
+ """
142
  async with self._storage_lock:
143
+ any_deleted = False
144
  for doc_id in ids:
145
+ result = self._data.pop(doc_id, None)
146
+ if result is not None:
147
+ any_deleted = True
148
+
149
+ if any_deleted:
150
+ await set_all_update_flags(self.namespace)
151
+
152
+ async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
153
+ """Delete specific records from storage by by cache mode
154
+
155
+ Important notes for in-memory storage:
156
+ 1. Changes will be persisted to disk during the next index_done_callback
157
+ 2. update flags to notify other processes that data persistence is needed
158
+
159
+ Args:
160
+ modes (list[str]): List of cache modes to be dropped from storage
161
+
162
+ Returns:
163
+ True: if the cache was dropped successfully
164
+ False: if the cache drop failed
165
+ """
166
+ if not modes:
167
+ return False
168
+
169
+ try:
170
+ await self.delete(modes)
171
+ return True
172
+ except Exception:
173
+ return False
174
+
175
+ async def drop(self) -> dict[str, str]:
176
+ """Drop all data from storage and clean up resources
177
+ This action will persist the data to disk immediately.
178
+
179
+ This method will:
180
+ 1. Clear all data from memory
181
+ 2. Update flags to notify other processes
182
+ 3. Trigger index_done_callback to save the empty state
183
+
184
+ Returns:
185
+ dict[str, str]: Operation status and message
186
+ - On success: {"status": "success", "message": "data dropped"}
187
+ - On failure: {"status": "error", "message": "<error details>"}
188
+ """
189
+ try:
190
+ async with self._storage_lock:
191
+ self._data.clear()
192
+ await set_all_update_flags(self.namespace)
193
+
194
+ await self.index_done_callback()
195
+ logger.info(f"Process {os.getpid()} drop {self.namespace}")
196
+ return {"status": "success", "message": "data dropped"}
197
+ except Exception as e:
198
+ logger.error(f"Error dropping {self.namespace}: {e}")
199
+ return {"status": "error", "message": str(e)}
lightrag/kg/milvus_impl.py CHANGED
@@ -15,7 +15,7 @@ if not pm.is_installed("pymilvus"):
15
  pm.install("pymilvus")
16
 
17
  import configparser
18
- from pymilvus import MilvusClient
19
 
20
  config = configparser.ConfigParser()
21
  config.read("config.ini", "utf-8")
@@ -287,3 +287,33 @@ class MilvusVectorDBStorage(BaseVectorStorage):
287
  except Exception as e:
288
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
289
  return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  pm.install("pymilvus")
16
 
17
  import configparser
18
+ from pymilvus import MilvusClient # type: ignore
19
 
20
  config = configparser.ConfigParser()
21
  config.read("config.ini", "utf-8")
 
287
  except Exception as e:
288
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
289
  return []
290
+
291
+ async def drop(self) -> dict[str, str]:
292
+ """Drop all vector data from storage and clean up resources
293
+
294
+ This method will delete all data from the Milvus collection.
295
+
296
+ Returns:
297
+ dict[str, str]: Operation status and message
298
+ - On success: {"status": "success", "message": "data dropped"}
299
+ - On failure: {"status": "error", "message": "<error details>"}
300
+ """
301
+ try:
302
+ # Drop the collection and recreate it
303
+ if self._client.has_collection(self.namespace):
304
+ self._client.drop_collection(self.namespace)
305
+
306
+ # Recreate the collection
307
+ MilvusVectorDBStorage.create_collection_if_not_exist(
308
+ self._client,
309
+ self.namespace,
310
+ dimension=self.embedding_func.embedding_dim,
311
+ )
312
+
313
+ logger.info(
314
+ f"Process {os.getpid()} drop Milvus collection {self.namespace}"
315
+ )
316
+ return {"status": "success", "message": "data dropped"}
317
+ except Exception as e:
318
+ logger.error(f"Error dropping Milvus collection {self.namespace}: {e}")
319
+ return {"status": "error", "message": str(e)}
lightrag/kg/mongo_impl.py CHANGED
@@ -25,13 +25,13 @@ if not pm.is_installed("pymongo"):
25
  if not pm.is_installed("motor"):
26
  pm.install("motor")
27
 
28
- from motor.motor_asyncio import (
29
  AsyncIOMotorClient,
30
  AsyncIOMotorDatabase,
31
  AsyncIOMotorCollection,
32
  )
33
- from pymongo.operations import SearchIndexModel
34
- from pymongo.errors import PyMongoError
35
 
36
  config = configparser.ConfigParser()
37
  config.read("config.ini", "utf-8")
@@ -150,6 +150,66 @@ class MongoKVStorage(BaseKVStorage):
150
  # Mongo handles persistence automatically
151
  pass
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  @final
155
  @dataclass
@@ -230,6 +290,27 @@ class MongoDocStatusStorage(DocStatusStorage):
230
  # Mongo handles persistence automatically
231
  pass
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
  @final
235
  @dataclass
@@ -840,6 +921,27 @@ class MongoGraphStorage(BaseGraphStorage):
840
 
841
  logger.debug(f"Successfully deleted edges: {edges}")
842
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
843
 
844
  @final
845
  @dataclass
@@ -1127,6 +1229,31 @@ class MongoVectorDBStorage(BaseVectorStorage):
1127
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
1128
  return []
1129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1130
 
1131
  async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
1132
  collection_names = await db.list_collection_names()
 
25
  if not pm.is_installed("motor"):
26
  pm.install("motor")
27
 
28
+ from motor.motor_asyncio import ( # type: ignore
29
  AsyncIOMotorClient,
30
  AsyncIOMotorDatabase,
31
  AsyncIOMotorCollection,
32
  )
33
+ from pymongo.operations import SearchIndexModel # type: ignore
34
+ from pymongo.errors import PyMongoError # type: ignore
35
 
36
  config = configparser.ConfigParser()
37
  config.read("config.ini", "utf-8")
 
150
  # Mongo handles persistence automatically
151
  pass
152
 
153
+ async def delete(self, ids: list[str]) -> None:
154
+ """Delete documents with specified IDs
155
+
156
+ Args:
157
+ ids: List of document IDs to be deleted
158
+ """
159
+ if not ids:
160
+ return
161
+
162
+ try:
163
+ result = await self._data.delete_many({"_id": {"$in": ids}})
164
+ logger.info(
165
+ f"Deleted {result.deleted_count} documents from {self.namespace}"
166
+ )
167
+ except PyMongoError as e:
168
+ logger.error(f"Error deleting documents from {self.namespace}: {e}")
169
+
170
+ async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
171
+ """Delete specific records from storage by cache mode
172
+
173
+ Args:
174
+ modes (list[str]): List of cache modes to be dropped from storage
175
+
176
+ Returns:
177
+ bool: True if successful, False otherwise
178
+ """
179
+ if not modes:
180
+ return False
181
+
182
+ try:
183
+ # Build regex pattern to match documents with the specified modes
184
+ pattern = f"^({'|'.join(modes)})_"
185
+ result = await self._data.delete_many({"_id": {"$regex": pattern}})
186
+ logger.info(f"Deleted {result.deleted_count} documents by modes: {modes}")
187
+ return True
188
+ except Exception as e:
189
+ logger.error(f"Error deleting cache by modes {modes}: {e}")
190
+ return False
191
+
192
+ async def drop(self) -> dict[str, str]:
193
+ """Drop the storage by removing all documents in the collection.
194
+
195
+ Returns:
196
+ dict[str, str]: Status of the operation with keys 'status' and 'message'
197
+ """
198
+ try:
199
+ result = await self._data.delete_many({})
200
+ deleted_count = result.deleted_count
201
+
202
+ logger.info(
203
+ f"Dropped {deleted_count} documents from doc status {self._collection_name}"
204
+ )
205
+ return {
206
+ "status": "success",
207
+ "message": f"{deleted_count} documents dropped",
208
+ }
209
+ except PyMongoError as e:
210
+ logger.error(f"Error dropping doc status {self._collection_name}: {e}")
211
+ return {"status": "error", "message": str(e)}
212
+
213
 
214
  @final
215
  @dataclass
 
290
  # Mongo handles persistence automatically
291
  pass
292
 
293
+ async def drop(self) -> dict[str, str]:
294
+ """Drop the storage by removing all documents in the collection.
295
+
296
+ Returns:
297
+ dict[str, str]: Status of the operation with keys 'status' and 'message'
298
+ """
299
+ try:
300
+ result = await self._data.delete_many({})
301
+ deleted_count = result.deleted_count
302
+
303
+ logger.info(
304
+ f"Dropped {deleted_count} documents from doc status {self._collection_name}"
305
+ )
306
+ return {
307
+ "status": "success",
308
+ "message": f"{deleted_count} documents dropped",
309
+ }
310
+ except PyMongoError as e:
311
+ logger.error(f"Error dropping doc status {self._collection_name}: {e}")
312
+ return {"status": "error", "message": str(e)}
313
+
314
 
315
  @final
316
  @dataclass
 
921
 
922
  logger.debug(f"Successfully deleted edges: {edges}")
923
 
924
+ async def drop(self) -> dict[str, str]:
925
+ """Drop the storage by removing all documents in the collection.
926
+
927
+ Returns:
928
+ dict[str, str]: Status of the operation with keys 'status' and 'message'
929
+ """
930
+ try:
931
+ result = await self.collection.delete_many({})
932
+ deleted_count = result.deleted_count
933
+
934
+ logger.info(
935
+ f"Dropped {deleted_count} documents from graph {self._collection_name}"
936
+ )
937
+ return {
938
+ "status": "success",
939
+ "message": f"{deleted_count} documents dropped",
940
+ }
941
+ except PyMongoError as e:
942
+ logger.error(f"Error dropping graph {self._collection_name}: {e}")
943
+ return {"status": "error", "message": str(e)}
944
+
945
 
946
  @final
947
  @dataclass
 
1229
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
1230
  return []
1231
 
1232
+ async def drop(self) -> dict[str, str]:
1233
+ """Drop the storage by removing all documents in the collection and recreating vector index.
1234
+
1235
+ Returns:
1236
+ dict[str, str]: Status of the operation with keys 'status' and 'message'
1237
+ """
1238
+ try:
1239
+ # Delete all documents
1240
+ result = await self._data.delete_many({})
1241
+ deleted_count = result.deleted_count
1242
+
1243
+ # Recreate vector index
1244
+ await self.create_vector_index_if_not_exists()
1245
+
1246
+ logger.info(
1247
+ f"Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
1248
+ )
1249
+ return {
1250
+ "status": "success",
1251
+ "message": f"{deleted_count} documents dropped and vector index recreated",
1252
+ }
1253
+ except PyMongoError as e:
1254
+ logger.error(f"Error dropping vector storage {self._collection_name}: {e}")
1255
+ return {"status": "error", "message": str(e)}
1256
+
1257
 
1258
  async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
1259
  collection_names = await db.list_collection_names()
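For reference, the `$regex` filter built by `MongoKVStorage.drop_cache_by_modes` above expands into a plain query on `_id`; the sketch below only constructs the filter document, no database connection required:

```python
# How the cache-mode filter expands, for illustration:
modes = ["local", "global"]
pattern = f"^({'|'.join(modes)})_"           # -> "^(local|global)_"
mongo_filter = {"_id": {"$regex": pattern}}
# delete_many(mongo_filter) then removes every cache document whose _id
# starts with "local_" or "global_".
print(mongo_filter)
```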
lightrag/kg/nano_vector_db_impl.py CHANGED
@@ -78,6 +78,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
78
  return self._client
79
 
80
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
 
 
 
 
 
 
 
81
  logger.info(f"Inserting {len(data)} to {self.namespace}")
82
  if not data:
83
  return
@@ -146,6 +153,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
146
  async def delete(self, ids: list[str]):
147
  """Delete vectors with specified IDs
148
 
 
 
 
 
 
149
  Args:
150
  ids: List of vector IDs to be deleted
151
  """
@@ -159,6 +171,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
159
  logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
160
 
161
  async def delete_entity(self, entity_name: str) -> None:
 
 
 
 
 
 
 
162
  try:
163
  entity_id = compute_mdhash_id(entity_name, prefix="ent-")
164
  logger.debug(
@@ -176,6 +195,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
176
  logger.error(f"Error deleting entity {entity_name}: {e}")
177
 
178
  async def delete_entity_relation(self, entity_name: str) -> None:
 
 
 
 
 
 
 
179
  try:
180
  client = await self._get_client()
181
  storage = getattr(client, "_NanoVectorDB__storage")
@@ -280,3 +306,43 @@ class NanoVectorDBStorage(BaseVectorStorage):
280
 
281
  client = await self._get_client()
282
  return client.get(ids)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  return self._client
79
 
80
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
81
+ """
82
+ Important notes:
83
+ 1. Changes will be persisted to disk during the next index_done_callback
84
+ 2. Only one process should update the storage at a time before index_done_callback,
85
+ KG-storage-log should be used to avoid data corruption
86
+ """
87
+
88
  logger.info(f"Inserting {len(data)} to {self.namespace}")
89
  if not data:
90
  return
 
153
  async def delete(self, ids: list[str]):
154
  """Delete vectors with specified IDs
155
 
156
+ Important notes:
157
+ 1. Changes will be persisted to disk during the next index_done_callback
158
+ 2. Only one process should update the storage at a time before index_done_callback,
159
+ KG-storage-log should be used to avoid data corruption
160
+
161
  Args:
162
  ids: List of vector IDs to be deleted
163
  """
 
171
  logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
172
 
173
  async def delete_entity(self, entity_name: str) -> None:
174
+ """
175
+ Important notes:
176
+ 1. Changes will be persisted to disk during the next index_done_callback
177
+ 2. Only one process should update the storage at a time before index_done_callback,
178
+ KG-storage-log should be used to avoid data corruption
179
+ """
180
+
181
  try:
182
  entity_id = compute_mdhash_id(entity_name, prefix="ent-")
183
  logger.debug(
 
195
  logger.error(f"Error deleting entity {entity_name}: {e}")
196
 
197
  async def delete_entity_relation(self, entity_name: str) -> None:
198
+ """
199
+ Important notes:
200
+ 1. Changes will be persisted to disk during the next index_done_callback
201
+ 2. Only one process should update the storage at a time before index_done_callback,
202
+ KG-storage-log should be used to avoid data corruption
203
+ """
204
+
205
  try:
206
  client = await self._get_client()
207
  storage = getattr(client, "_NanoVectorDB__storage")
 
306
 
307
  client = await self._get_client()
308
  return client.get(ids)
309
+
310
+ async def drop(self) -> dict[str, str]:
311
+ """Drop all vector data from storage and clean up resources
312
+
313
+ This method will:
314
+ 1. Remove the vector database storage file if it exists
315
+ 2. Reinitialize the vector database client
316
+ 3. Update flags to notify other processes
317
+ 4. Changes are persisted to disk immediately
318
+
319
+ This method is intended for scenarios where all data needs to be removed.
320
+
321
+ Returns:
322
+ dict[str, str]: Operation status and message
323
+ - On success: {"status": "success", "message": "data dropped"}
324
+ - On failure: {"status": "error", "message": "<error details>"}
325
+ """
326
+ try:
327
+ async with self._storage_lock:
328
+ # delete _client_file_name
329
+ if os.path.exists(self._client_file_name):
330
+ os.remove(self._client_file_name)
331
+
332
+ self._client = NanoVectorDB(
333
+ self.embedding_func.embedding_dim,
334
+ storage_file=self._client_file_name,
335
+ )
336
+
337
+ # Notify other processes that data has been updated
338
+ await set_all_update_flags(self.namespace)
339
+ # Reset own update flag to avoid self-reloading
340
+ self.storage_updated.value = False
341
+
342
+ logger.info(
343
+ f"Process {os.getpid()} drop {self.namespace}(file:{self._client_file_name})"
344
+ )
345
+ return {"status": "success", "message": "data dropped"}
346
+ except Exception as e:
347
+ logger.error(f"Error dropping {self.namespace}: {e}")
348
+ return {"status": "error", "message": str(e)}
lightrag/kg/neo4j_impl.py CHANGED
@@ -1,9 +1,8 @@
1
- import asyncio
2
  import inspect
3
  import os
4
  import re
5
  from dataclasses import dataclass
6
- from typing import Any, final, Optional
7
  import numpy as np
8
  import configparser
9
 
@@ -29,7 +28,6 @@ from neo4j import ( # type: ignore
29
  exceptions as neo4jExceptions,
30
  AsyncDriver,
31
  AsyncManagedTransaction,
32
- GraphDatabase,
33
  )
34
 
35
  config = configparser.ConfigParser()
@@ -52,8 +50,13 @@ class Neo4JStorage(BaseGraphStorage):
52
  embedding_func=embedding_func,
53
  )
54
  self._driver = None
55
- self._driver_lock = asyncio.Lock()
56
 
 
 
 
 
 
 
57
  URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
58
  USERNAME = os.environ.get(
59
  "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
@@ -86,7 +89,7 @@ class Neo4JStorage(BaseGraphStorage):
86
  ),
87
  )
88
  DATABASE = os.environ.get(
89
- "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
90
  )
91
 
92
  self._driver: AsyncDriver = AsyncGraphDatabase.driver(
@@ -98,71 +101,92 @@ class Neo4JStorage(BaseGraphStorage):
98
  max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
99
  )
100
 
101
- # Try to connect to the database
102
- with GraphDatabase.driver(
103
- URI,
104
- auth=(USERNAME, PASSWORD),
105
- max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
106
- connection_timeout=CONNECTION_TIMEOUT,
107
- connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
108
- ) as _sync_driver:
109
- for database in (DATABASE, None):
110
- self._DATABASE = database
111
- connected = False
112
 
113
- try:
114
- with _sync_driver.session(database=database) as session:
115
- try:
116
- session.run("MATCH (n) RETURN n LIMIT 0")
117
- logger.info(f"Connected to {database} at {URI}")
118
- connected = True
119
- except neo4jExceptions.ServiceUnavailable as e:
120
- logger.error(
121
- f"{database} at {URI} is not available".capitalize()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  raise e
124
- except neo4jExceptions.AuthError as e:
125
- logger.error(f"Authentication failed for {database} at {URI}")
126
- raise e
127
- except neo4jExceptions.ClientError as e:
128
- if e.code == "Neo.ClientError.Database.DatabaseNotFound":
129
- logger.info(
130
- f"{database} at {URI} not found. Try to create specified database.".capitalize()
131
- )
 
 
 
132
  try:
133
- with _sync_driver.session() as session:
134
- session.run(
135
- f"CREATE DATABASE `{database}` IF NOT EXISTS"
136
- )
137
- logger.info(f"{database} at {URI} created".capitalize())
138
- connected = True
139
- except (
140
- neo4jExceptions.ClientError,
141
- neo4jExceptions.DatabaseError,
142
- ) as e:
143
- if (
144
- e.code
145
- == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
146
- ) or (
147
- e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
148
- ):
149
- if database is not None:
150
- logger.warning(
151
- "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
152
- )
153
- if database is None:
154
- logger.error(f"Failed to create {database} at {URI}")
155
- raise e
156
 
157
- if connected:
158
- break
159
 
160
- def __post_init__(self):
161
- self._node_embed_algorithms = {
162
- "node2vec": self._node2vec_embed,
163
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
- async def close(self):
166
  """Close the Neo4j driver and release all resources"""
167
  if self._driver:
168
  await self._driver.close()
@@ -170,7 +194,7 @@ class Neo4JStorage(BaseGraphStorage):
170
 
171
  async def __aexit__(self, exc_type, exc, tb):
172
  """Ensure driver is closed when context manager exits"""
173
- await self.close()
174
 
175
  async def index_done_callback(self) -> None:
176
  # Neo4j handles persistence automatically
@@ -243,7 +267,7 @@ class Neo4JStorage(BaseGraphStorage):
243
  raise
244
 
245
  async def get_node(self, node_id: str) -> dict[str, str] | None:
246
- """Get node by its label identifier.
247
 
248
  Args:
249
  node_id: The node label to look up
@@ -428,13 +452,8 @@ class Neo4JStorage(BaseGraphStorage):
428
  logger.debug(
429
  f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
430
  )
431
- # Return default edge properties when no edge found
432
- return {
433
- "weight": 0.0,
434
- "source_id": None,
435
- "description": None,
436
- "keywords": None,
437
- }
438
  finally:
439
  await result.consume() # Ensure result is fully consumed
440
 
@@ -526,7 +545,6 @@ class Neo4JStorage(BaseGraphStorage):
526
  """
527
  properties = node_data
528
  entity_type = properties["entity_type"]
529
- entity_id = properties["entity_id"]
530
  if "entity_id" not in properties:
531
  raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
532
 
@@ -536,15 +554,17 @@ class Neo4JStorage(BaseGraphStorage):
536
  async def execute_upsert(tx: AsyncManagedTransaction):
537
  query = (
538
  """
539
- MERGE (n:base {entity_id: $properties.entity_id})
540
  SET n += $properties
541
  SET n:`%s`
542
  """
543
  % entity_type
544
  )
545
- result = await tx.run(query, properties=properties)
 
 
546
  logger.debug(
547
- f"Upserted node with entity_id '{entity_id}' and properties: {properties}"
548
  )
549
  await result.consume() # Ensure result is fully consumed
550
 
@@ -622,25 +642,19 @@ class Neo4JStorage(BaseGraphStorage):
622
  self,
623
  node_label: str,
624
  max_depth: int = 3,
625
- min_degree: int = 0,
626
- inclusive: bool = False,
627
  ) -> KnowledgeGraph:
628
  """
629
  Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
630
- Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
631
- When reducing the number of nodes, the prioritization criteria are as follows:
632
- 1. min_degree does not affect nodes directly connected to the matching nodes
633
- 2. Label matching nodes take precedence
634
- 3. Followed by nodes directly connected to the matching nodes
635
- 4. Finally, the degree of the nodes
636
 
637
  Args:
638
- node_label: Label of the starting node
639
- max_depth: Maximum depth of the subgraph
640
- min_degree: Minimum degree of nodes to include. Defaults to 0
641
- inclusive: Do an inclusive search if true
642
  Returns:
643
- KnowledgeGraph: Complete connected subgraph for specified node
 
644
  """
645
  result = KnowledgeGraph()
646
  seen_nodes = set()
@@ -651,11 +665,27 @@ class Neo4JStorage(BaseGraphStorage):
651
  ) as session:
652
  try:
653
  if node_label == "*":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
654
  main_query = """
655
  MATCH (n)
656
  OPTIONAL MATCH (n)-[r]-()
657
  WITH n, COALESCE(count(r), 0) AS degree
658
- WHERE degree >= $min_degree
659
  ORDER BY degree DESC
660
  LIMIT $max_nodes
661
  WITH collect({node: n}) AS filtered_nodes
@@ -666,20 +696,23 @@ class Neo4JStorage(BaseGraphStorage):
666
  RETURN filtered_nodes AS node_info,
667
  collect(DISTINCT r) AS relationships
668
  """
669
- result_set = await session.run(
670
- main_query,
671
- {"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree},
672
- )
 
 
 
 
 
 
673
 
674
  else:
675
- # Main query uses partial matching
676
- main_query = """
 
677
  MATCH (start)
678
- WHERE
679
- CASE
680
- WHEN $inclusive THEN start.entity_id CONTAINS $entity_id
681
- ELSE start.entity_id = $entity_id
682
- END
683
  WITH start
684
  CALL apoc.path.subgraphAll(start, {
685
  relationshipFilter: '',
@@ -688,78 +721,115 @@ class Neo4JStorage(BaseGraphStorage):
688
  bfs: true
689
  })
690
  YIELD nodes, relationships
691
- WITH start, nodes, relationships
692
  UNWIND nodes AS node
693
- OPTIONAL MATCH (node)-[r]-()
694
- WITH node, COALESCE(count(r), 0) AS degree, start, nodes, relationships
695
- WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
696
- ORDER BY
697
- CASE
698
- WHEN node = start THEN 3
699
- WHEN EXISTS((start)--(node)) THEN 2
700
- ELSE 1
701
- END DESC,
702
- degree DESC
703
- LIMIT $max_nodes
704
- WITH collect({node: node}) AS filtered_nodes
705
- UNWIND filtered_nodes AS node_info
706
- WITH collect(node_info.node) AS kept_nodes, filtered_nodes
707
- OPTIONAL MATCH (a)-[r]-(b)
708
- WHERE a IN kept_nodes AND b IN kept_nodes
709
- RETURN filtered_nodes AS node_info,
710
- collect(DISTINCT r) AS relationships
711
  """
712
- result_set = await session.run(
713
- main_query,
714
- {
715
- "max_nodes": MAX_GRAPH_NODES,
716
- "entity_id": node_label,
717
- "inclusive": inclusive,
718
- "max_depth": max_depth,
719
- "min_degree": min_degree,
720
- },
721
- )
722
 
723
- try:
724
- record = await result_set.single()
725
-
726
- if record:
727
- # Handle nodes (compatible with multi-label cases)
728
- for node_info in record["node_info"]:
729
- node = node_info["node"]
730
- node_id = node.id
731
- if node_id not in seen_nodes:
732
- result.nodes.append(
733
- KnowledgeGraphNode(
734
- id=f"{node_id}",
735
- labels=[node.get("entity_id")],
736
- properties=dict(node),
737
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
738
  )
739
- seen_nodes.add(node_id)
740
-
741
- # Handle relationships (including direction information)
742
- for rel in record["relationships"]:
743
- edge_id = rel.id
744
- if edge_id not in seen_edges:
745
- start = rel.start_node
746
- end = rel.end_node
747
- result.edges.append(
748
- KnowledgeGraphEdge(
749
- id=f"{edge_id}",
750
- type=rel.type,
751
- source=f"{start.id}",
752
- target=f"{end.id}",
753
- properties=dict(rel),
754
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755
  )
756
- seen_edges.add(edge_id)
 
757
 
758
- logger.info(
759
- f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
760
- )
761
- finally:
762
- await result_set.consume() # Ensure result set is consumed
763
 
764
  except neo4jExceptions.ClientError as e:
765
  logger.warning(f"APOC plugin error: {str(e)}")
@@ -767,46 +837,89 @@ class Neo4JStorage(BaseGraphStorage):
767
  logger.warning(
768
  "Neo4j: falling back to basic Cypher recursive search..."
769
  )
770
- if inclusive:
771
- logger.warning(
772
- "Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
773
- )
774
- return await self._robust_fallback(
775
- node_label, max_depth, min_degree
776
  )
777
 
778
  return result
779
 
780
  async def _robust_fallback(
781
- self, node_label: str, max_depth: int, min_degree: int = 0
782
  ) -> KnowledgeGraph:
783
  """
784
  Fallback implementation when APOC plugin is not available or incompatible.
785
  This method implements the same functionality as get_knowledge_graph but uses
786
- only basic Cypher queries and recursive traversal instead of APOC procedures.
787
  """
 
 
788
  result = KnowledgeGraph()
789
  visited_nodes = set()
790
  visited_edges = set()
 
791
 
792
- async def traverse(
793
- node: KnowledgeGraphNode,
794
- edge: Optional[KnowledgeGraphEdge],
795
- current_depth: int,
796
- ):
797
- # Check traversal limits
798
- if current_depth > max_depth:
799
- logger.debug(f"Reached max depth: {max_depth}")
800
- return
801
- if len(visited_nodes) >= MAX_GRAPH_NODES:
802
- logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
803
- return
 
 
 
 
 
 
 
 
 
 
804
 
805
- # Check if node already visited
806
- if node.id in visited_nodes:
807
- return
808
 
809
- # Get all edges and target nodes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
810
  async with self._driver.session(
811
  database=self._DATABASE, default_access_mode="READ"
812
  ) as session:
@@ -815,32 +928,17 @@ class Neo4JStorage(BaseGraphStorage):
815
  WITH r, b, id(r) as edge_id, id(b) as target_id
816
  RETURN r, b, edge_id, target_id
817
  """
818
- results = await session.run(query, entity_id=node.id)
819
 
820
  # Get all records and release database connection
821
- records = await results.fetch(
822
- 1000
823
- ) # Max neighbour nodes we can handled
824
  await results.consume() # Ensure results are consumed
825
 
826
- # Nodes not connected to start node need to check degree
827
- if current_depth > 1 and len(records) < min_degree:
828
- return
829
-
830
- # Add current node to result
831
- result.nodes.append(node)
832
- visited_nodes.add(node.id)
833
-
834
- # Add edge to result if it exists and not already added
835
- if edge and edge.id not in visited_edges:
836
- result.edges.append(edge)
837
- visited_edges.add(edge.id)
838
-
839
- # Prepare nodes and edges for recursive processing
840
- nodes_to_process = []
841
  for record in records:
842
  rel = record["r"]
843
  edge_id = str(record["edge_id"])
 
844
  if edge_id not in visited_edges:
845
  b_node = record["b"]
846
  target_id = b_node.get("entity_id")
@@ -849,55 +947,59 @@ class Neo4JStorage(BaseGraphStorage):
849
  # Create KnowledgeGraphNode for target
850
  target_node = KnowledgeGraphNode(
851
  id=f"{target_id}",
852
- labels=list(f"{target_id}"),
853
- properties=dict(b_node.properties),
854
  )
855
 
856
  # Create KnowledgeGraphEdge
857
  target_edge = KnowledgeGraphEdge(
858
  id=f"{edge_id}",
859
  type=rel.type,
860
- source=f"{node.id}",
861
  target=f"{target_id}",
862
  properties=dict(rel),
863
  )
864
 
865
- nodes_to_process.append((target_node, target_edge))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
866
  else:
867
  logger.warning(
868
- f"Skipping edge {edge_id} due to missing labels on target node"
869
  )
870
 
871
- # Process nodes after releasing database connection
872
- for target_node, target_edge in nodes_to_process:
873
- await traverse(target_node, target_edge, current_depth + 1)
874
-
875
- # Get the starting node's data
876
- async with self._driver.session(
877
- database=self._DATABASE, default_access_mode="READ"
878
- ) as session:
879
- query = """
880
- MATCH (n:base {entity_id: $entity_id})
881
- RETURN id(n) as node_id, n
882
- """
883
- node_result = await session.run(query, entity_id=node_label)
884
- try:
885
- node_record = await node_result.single()
886
- if not node_record:
887
- return result
888
-
889
- # Create initial KnowledgeGraphNode
890
- start_node = KnowledgeGraphNode(
891
- id=f"{node_record['n'].get('entity_id')}",
892
- labels=list(f"{node_record['n'].get('entity_id')}"),
893
- properties=dict(node_record["n"].properties),
894
- )
895
- finally:
896
- await node_result.consume() # Ensure results are consumed
897
-
898
- # Start traversal with the initial node
899
- await traverse(start_node, None, 0)
900
-
901
  return result
902
 
903
  async def get_all_labels(self) -> list[str]:
@@ -914,7 +1016,7 @@ class Neo4JStorage(BaseGraphStorage):
914
 
915
  # Method 2: Query compatible with older versions
916
  query = """
917
- MATCH (n)
918
  WHERE n.entity_id IS NOT NULL
919
  RETURN DISTINCT n.entity_id AS label
920
  ORDER BY label
@@ -1028,3 +1130,28 @@ class Neo4JStorage(BaseGraphStorage):
1028
  self, algorithm: str
1029
  ) -> tuple[np.ndarray[Any, Any], list[str]]:
1030
  raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import inspect
2
  import os
3
  import re
4
  from dataclasses import dataclass
5
+ from typing import Any, final
6
  import numpy as np
7
  import configparser
8
 
 
28
  exceptions as neo4jExceptions,
29
  AsyncDriver,
30
  AsyncManagedTransaction,
 
31
  )
32
 
33
  config = configparser.ConfigParser()
 
50
  embedding_func=embedding_func,
51
  )
52
  self._driver = None
 
53
 
54
+ def __post_init__(self):
55
+ self._node_embed_algorithms = {
56
+ "node2vec": self._node2vec_embed,
57
+ }
58
+
59
+ async def initialize(self):
60
  URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
61
  USERNAME = os.environ.get(
62
  "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
 
89
  ),
90
  )
91
  DATABASE = os.environ.get(
92
+ "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
93
  )
94
 
95
  self._driver: AsyncDriver = AsyncGraphDatabase.driver(
 
101
  max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
102
  )
103
 
104
+ # Try to connect to the database and create it if it doesn't exist
105
+ for database in (DATABASE, None):
106
+ self._DATABASE = database
107
+ connected = False
 
 
 
 
 
 
 
108
 
109
+ try:
110
+ async with self._driver.session(database=database) as session:
111
+ try:
112
+ result = await session.run("MATCH (n) RETURN n LIMIT 0")
113
+ await result.consume() # Ensure result is consumed
114
+ logger.info(f"Connected to {database} at {URI}")
115
+ connected = True
116
+ except neo4jExceptions.ServiceUnavailable as e:
117
+ logger.error(
118
+ f"{database} at {URI} is not available".capitalize()
119
+ )
120
+ raise e
121
+ except neo4jExceptions.AuthError as e:
122
+ logger.error(f"Authentication failed for {database} at {URI}")
123
+ raise e
124
+ except neo4jExceptions.ClientError as e:
125
+ if e.code == "Neo.ClientError.Database.DatabaseNotFound":
126
+ logger.info(
127
+ f"{database} at {URI} not found. Try to create specified database.".capitalize()
128
+ )
129
+ try:
130
+ async with self._driver.session() as session:
131
+ result = await session.run(
132
+ f"CREATE DATABASE `{database}` IF NOT EXISTS"
133
  )
134
+ await result.consume() # Ensure result is consumed
135
+ logger.info(f"{database} at {URI} created".capitalize())
136
+ connected = True
137
+ except (
138
+ neo4jExceptions.ClientError,
139
+ neo4jExceptions.DatabaseError,
140
+ ) as e:
141
+ if (
142
+ e.code
143
+ == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
144
+ ) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
145
+ if database is not None:
146
+ logger.warning(
147
+ "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
148
+ )
149
+ if database is None:
150
+ logger.error(f"Failed to create {database} at {URI}")
151
  raise e
152
+
153
+ if connected:
154
+ # Create index for base nodes on entity_id if it doesn't exist
155
+ try:
156
+ async with self._driver.session(database=database) as session:
157
+ # Check if index exists first
158
+ check_query = """
159
+ CALL db.indexes() YIELD name, labelsOrTypes, properties
160
+ WHERE labelsOrTypes = ['base'] AND properties = ['entity_id']
161
+ RETURN count(*) > 0 AS exists
162
+ """
163
  try:
164
+ check_result = await session.run(check_query)
165
+ record = await check_result.single()
166
+ await check_result.consume()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
+ index_exists = record and record.get("exists", False)
 
169
 
170
+ if not index_exists:
171
+ # Create index only if it doesn't exist
172
+ result = await session.run(
173
+ "CREATE INDEX FOR (n:base) ON (n.entity_id)"
174
+ )
175
+ await result.consume()
176
+ logger.info(
177
+ f"Created index for base nodes on entity_id in {database}"
178
+ )
179
+ except Exception:
180
+ # Fallback if db.indexes() is not supported in this Neo4j version
181
+ result = await session.run(
182
+ "CREATE INDEX IF NOT EXISTS FOR (n:base) ON (n.entity_id)"
183
+ )
184
+ await result.consume()
185
+ except Exception as e:
186
+ logger.warning(f"Failed to create index: {str(e)}")
187
+ break
188
 
189
+ async def finalize(self):
190
  """Close the Neo4j driver and release all resources"""
191
  if self._driver:
192
  await self._driver.close()
 
194
 
195
  async def __aexit__(self, exc_type, exc, tb):
196
  """Ensure driver is closed when context manager exits"""
197
+ await self.finalize()
198
 
199
  async def index_done_callback(self) -> None:
200
  # Neo4j handles persistence automatically
 
267
  raise
268
 
269
  async def get_node(self, node_id: str) -> dict[str, str] | None:
270
+ """Get node by its label identifier, return only node properties
271
 
272
  Args:
273
  node_id: The node label to look up
 
452
  logger.debug(
453
  f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
454
  )
455
+ # Return None when no edge found
456
+ return None
 
 
 
 
 
457
  finally:
458
  await result.consume() # Ensure result is fully consumed
459
 
 
545
  """
546
  properties = node_data
547
  entity_type = properties["entity_type"]
 
548
  if "entity_id" not in properties:
549
  raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
550
 
 
554
  async def execute_upsert(tx: AsyncManagedTransaction):
555
  query = (
556
  """
557
+ MERGE (n:base {entity_id: $entity_id})
558
  SET n += $properties
559
  SET n:`%s`
560
  """
561
  % entity_type
562
  )
563
+ result = await tx.run(
564
+ query, entity_id=node_id, properties=properties
565
+ )
566
  logger.debug(
567
+ f"Upserted node with entity_id '{node_id}' and properties: {properties}"
568
  )
569
  await result.consume() # Ensure result is fully consumed
570
 
 
642
  self,
643
  node_label: str,
644
  max_depth: int = 3,
645
+ max_nodes: int = MAX_GRAPH_NODES,
 
646
  ) -> KnowledgeGraph:
647
  """
648
  Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
 
 
 
 
 
 
649
 
650
  Args:
651
+ node_label: Label of the starting node, * means all nodes
652
+ max_depth: Maximum depth of the subgraph. Defaults to 3
653
+ max_nodes: Maximum number of nodes to return by BFS. Defaults to 1000
654
+
655
  Returns:
656
+ KnowledgeGraph object containing nodes and edges, with an is_truncated flag
657
+ indicating whether the graph was truncated due to max_nodes limit
658
  """
659
  result = KnowledgeGraph()
660
  seen_nodes = set()
 
665
  ) as session:
666
  try:
667
  if node_label == "*":
668
+ # First check total node count to determine if graph is truncated
669
+ count_query = "MATCH (n) RETURN count(n) as total"
670
+ count_result = None
671
+ try:
672
+ count_result = await session.run(count_query)
673
+ count_record = await count_result.single()
674
+
675
+ if count_record and count_record["total"] > max_nodes:
676
+ result.is_truncated = True
677
+ logger.info(
678
+ f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
679
+ )
680
+ finally:
681
+ if count_result:
682
+ await count_result.consume()
683
+
684
+ # Run main query to get nodes with highest degree
685
  main_query = """
686
  MATCH (n)
687
  OPTIONAL MATCH (n)-[r]-()
688
  WITH n, COALESCE(count(r), 0) AS degree
 
689
  ORDER BY degree DESC
690
  LIMIT $max_nodes
691
  WITH collect({node: n}) AS filtered_nodes
 
696
  RETURN filtered_nodes AS node_info,
697
  collect(DISTINCT r) AS relationships
698
  """
699
+ result_set = None
700
+ try:
701
+ result_set = await session.run(
702
+ main_query,
703
+ {"max_nodes": max_nodes},
704
+ )
705
+ record = await result_set.single()
706
+ finally:
707
+ if result_set:
708
+ await result_set.consume()
709
 
710
  else:
711
+ # return await self._robust_fallback(node_label, max_depth, max_nodes)
712
+ # First try without limit to check if we need to truncate
713
+ full_query = """
714
  MATCH (start)
715
+ WHERE start.entity_id = $entity_id
 
 
 
 
716
  WITH start
717
  CALL apoc.path.subgraphAll(start, {
718
  relationshipFilter: '',
 
721
  bfs: true
722
  })
723
  YIELD nodes, relationships
724
+ WITH nodes, relationships, size(nodes) AS total_nodes
725
  UNWIND nodes AS node
726
+ WITH collect({node: node}) AS node_info, relationships, total_nodes
727
+ RETURN node_info, relationships, total_nodes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
728
  """
 
 
 
 
 
 
 
 
 
 
729
 
730
+ # Try to get full result
731
+ full_result = None
732
+ try:
733
+ full_result = await session.run(
734
+ full_query,
735
+ {
736
+ "entity_id": node_label,
737
+ "max_depth": max_depth,
738
+ },
739
+ )
740
+ full_record = await full_result.single()
741
+
742
+ # If no record found, return empty KnowledgeGraph
743
+ if not full_record:
744
+ logger.debug(f"No nodes found for entity_id: {node_label}")
745
+ return result
746
+
747
+ # If record found, check node count
748
+ total_nodes = full_record["total_nodes"]
749
+
750
+ if total_nodes <= max_nodes:
751
+ # If node count is within limit, use full result directly
752
+ logger.debug(
753
+ f"Using full result with {total_nodes} nodes (no truncation needed)"
754
+ )
755
+ record = full_record
756
+ else:
757
+ # If node count exceeds limit, set truncated flag and run limited query
758
+ result.is_truncated = True
759
+ logger.info(
760
+ f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
761
+ )
762
+
763
+ # Run limited query
764
+ limited_query = """
765
+ MATCH (start)
766
+ WHERE start.entity_id = $entity_id
767
+ WITH start
768
+ CALL apoc.path.subgraphAll(start, {
769
+ relationshipFilter: '',
770
+ minLevel: 0,
771
+ maxLevel: $max_depth,
772
+ limit: $max_nodes,
773
+ bfs: true
774
+ })
775
+ YIELD nodes, relationships
776
+ UNWIND nodes AS node
777
+ WITH collect({node: node}) AS node_info, relationships
778
+ RETURN node_info, relationships
779
+ """
780
+ result_set = None
781
+ try:
782
+ result_set = await session.run(
783
+ limited_query,
784
+ {
785
+ "entity_id": node_label,
786
+ "max_depth": max_depth,
787
+ "max_nodes": max_nodes,
788
+ },
789
  )
790
+ record = await result_set.single()
791
+ finally:
792
+ if result_set:
793
+ await result_set.consume()
794
+ finally:
795
+ if full_result:
796
+ await full_result.consume()
797
+
798
+ if record:
799
+ # Handle nodes (compatible with multi-label cases)
800
+ for node_info in record["node_info"]:
801
+ node = node_info["node"]
802
+ node_id = node.id
803
+ if node_id not in seen_nodes:
804
+ result.nodes.append(
805
+ KnowledgeGraphNode(
806
+ id=f"{node_id}",
807
+ labels=[node.get("entity_id")],
808
+ properties=dict(node),
809
+ )
810
+ )
811
+ seen_nodes.add(node_id)
812
+
813
+ # Handle relationships (including direction information)
814
+ for rel in record["relationships"]:
815
+ edge_id = rel.id
816
+ if edge_id not in seen_edges:
817
+ start = rel.start_node
818
+ end = rel.end_node
819
+ result.edges.append(
820
+ KnowledgeGraphEdge(
821
+ id=f"{edge_id}",
822
+ type=rel.type,
823
+ source=f"{start.id}",
824
+ target=f"{end.id}",
825
+ properties=dict(rel),
826
  )
827
+ )
828
+ seen_edges.add(edge_id)
829
 
830
+ logger.info(
831
+ f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
832
+ )
 
 
833
 
834
  except neo4jExceptions.ClientError as e:
835
  logger.warning(f"APOC plugin error: {str(e)}")
 
837
  logger.warning(
838
  "Neo4j: falling back to basic Cypher recursive search..."
839
  )
840
+ return await self._robust_fallback(node_label, max_depth, max_nodes)
841
+ else:
842
+ logger.warning(
843
+ "Neo4j: APOC plugin error with wildcard query, returning empty result"
 
 
844
  )
845
 
846
  return result
847
 
848
  async def _robust_fallback(
849
+ self, node_label: str, max_depth: int, max_nodes: int
850
  ) -> KnowledgeGraph:
851
  """
852
  Fallback implementation when APOC plugin is not available or incompatible.
853
  This method implements the same functionality as get_knowledge_graph but uses
854
+ only basic Cypher queries and true breadth-first traversal instead of APOC procedures.
855
  """
856
+ from collections import deque
857
+
858
  result = KnowledgeGraph()
859
  visited_nodes = set()
860
  visited_edges = set()
861
+ visited_edge_pairs = set()  # Track processed edge pairs as sorted (source_id, target_id) tuples
862
 
863
+ # Get the starting node's data
864
+ async with self._driver.session(
865
+ database=self._DATABASE, default_access_mode="READ"
866
+ ) as session:
867
+ query = """
868
+ MATCH (n:base {entity_id: $entity_id})
869
+ RETURN id(n) as node_id, n
870
+ """
871
+ node_result = await session.run(query, entity_id=node_label)
872
+ try:
873
+ node_record = await node_result.single()
874
+ if not node_record:
875
+ return result
876
+
877
+ # Create initial KnowledgeGraphNode
878
+ start_node = KnowledgeGraphNode(
879
+ id=f"{node_record['n'].get('entity_id')}",
880
+ labels=[node_record["n"].get("entity_id")],
881
+ properties=dict(node_record["n"]._properties),
882
+ )
883
+ finally:
884
+ await node_result.consume() # Ensure results are consumed
885
 
886
+ # Initialize queue for BFS with (node, edge, depth) tuples
887
+ # edge is None for the starting node
888
+ queue = deque([(start_node, None, 0)])
889
 
890
+ # True BFS implementation using a queue
891
+ while queue and len(visited_nodes) < max_nodes:
892
+ # Dequeue the next node to process
893
+ current_node, current_edge, current_depth = queue.popleft()
894
+
895
+ # Skip if already visited or exceeds max depth
896
+ if current_node.id in visited_nodes:
897
+ continue
898
+
899
+ if current_depth > max_depth:
900
+ logger.debug(
901
+ f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
902
+ )
903
+ continue
904
+
905
+ # Add current node to result
906
+ result.nodes.append(current_node)
907
+ visited_nodes.add(current_node.id)
908
+
909
+ # Add edge to result if it exists and not already added
910
+ if current_edge and current_edge.id not in visited_edges:
911
+ result.edges.append(current_edge)
912
+ visited_edges.add(current_edge.id)
913
+
914
+ # Stop if we've reached the node limit
915
+ if len(visited_nodes) >= max_nodes:
916
+ result.is_truncated = True
917
+ logger.info(
918
+ f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
919
+ )
920
+ break
921
+
922
+ # Get all edges and target nodes for the current node (even at max_depth)
923
  async with self._driver.session(
924
  database=self._DATABASE, default_access_mode="READ"
925
  ) as session:
 
928
  WITH r, b, id(r) as edge_id, id(b) as target_id
929
  RETURN r, b, edge_id, target_id
930
  """
931
+ results = await session.run(query, entity_id=current_node.id)
932
 
933
  # Get all records and release database connection
934
+ records = await results.fetch(1000) # Max neighbor nodes we can handle
 
 
935
  await results.consume() # Ensure results are consumed
936
 
937
+ # Process all neighbors - capture all edges but only queue unvisited nodes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
938
  for record in records:
939
  rel = record["r"]
940
  edge_id = str(record["edge_id"])
941
+
942
  if edge_id not in visited_edges:
943
  b_node = record["b"]
944
  target_id = b_node.get("entity_id")
 
947
  # Create KnowledgeGraphNode for target
948
  target_node = KnowledgeGraphNode(
949
  id=f"{target_id}",
950
+ labels=[target_id],
951
+ properties=dict(b_node._properties),
952
  )
953
 
954
  # Create KnowledgeGraphEdge
955
  target_edge = KnowledgeGraphEdge(
956
  id=f"{edge_id}",
957
  type=rel.type,
958
+ source=f"{current_node.id}",
959
  target=f"{target_id}",
960
  properties=dict(rel),
961
  )
962
 
963
+ # Sort source_id and target_id so that (A,B) and (B,A) are treated as the same edge
964
+ sorted_pair = tuple(sorted([current_node.id, target_id]))
965
+
966
+ # Check whether the same edge already exists (treating edges as undirected)
967
+ if sorted_pair not in visited_edge_pairs:
968
+ # Only add the edge if the target node is already in the result or will be added to it
969
+ if target_id in visited_nodes or (
970
+ target_id not in visited_nodes
971
+ and current_depth < max_depth
972
+ ):
973
+ result.edges.append(target_edge)
974
+ visited_edges.add(edge_id)
975
+ visited_edge_pairs.add(sorted_pair)
976
+
977
+ # Only add unvisited nodes to the queue for further expansion
978
+ if target_id not in visited_nodes:
979
+ # Only add to queue if we're not at max depth yet
980
+ if current_depth < max_depth:
981
+ # Add node to queue with incremented depth
982
+ # Edge is already added to result, so we pass None as edge
983
+ queue.append((target_node, None, current_depth + 1))
984
+ else:
985
+ # At max depth, we've already added the edge but we don't add the node
986
+ # This prevents adding nodes beyond max_depth to the result
987
+ logger.debug(
988
+ f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
989
+ )
990
+ else:
991
+ # If target node already exists in result, we don't need to add it again
992
+ logger.debug(
993
+ f"Node {target_id} already visited, edge added but node not queued"
994
+ )
995
  else:
996
  logger.warning(
997
+ f"Skipping edge {edge_id} due to missing entity_id on target node"
998
  )
999
 
1000
+ logger.info(
1001
+ f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
1002
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1003
  return result
1004
 
1005
  async def get_all_labels(self) -> list[str]:
 
1016
 
1017
  # Method 2: Query compatible with older versions
1018
  query = """
1019
+ MATCH (n:base)
1020
  WHERE n.entity_id IS NOT NULL
1021
  RETURN DISTINCT n.entity_id AS label
1022
  ORDER BY label
 
1130
  self, algorithm: str
1131
  ) -> tuple[np.ndarray[Any, Any], list[str]]:
1132
  raise NotImplementedError
1133
+
1134
+ async def drop(self) -> dict[str, str]:
1135
+ """Drop all data from storage and clean up resources
1136
+
1137
+ This method will delete all nodes and relationships in the Neo4j database.
1138
+
1139
+ Returns:
1140
+ dict[str, str]: Operation status and message
1141
+ - On success: {"status": "success", "message": "data dropped"}
1142
+ - On failure: {"status": "error", "message": "<error details>"}
1143
+ """
1144
+ try:
1145
+ async with self._driver.session(database=self._DATABASE) as session:
1146
+ # Delete all nodes and relationships
1147
+ query = "MATCH (n) DETACH DELETE n"
1148
+ result = await session.run(query)
1149
+ await result.consume() # Ensure result is fully consumed
1150
+
1151
+ logger.info(
1152
+ f"Process {os.getpid()} drop Neo4j database {self._DATABASE}"
1153
+ )
1154
+ return {"status": "success", "message": "data dropped"}
1155
+ except Exception as e:
1156
+ logger.error(f"Error dropping Neo4j database {self._DATABASE}: {e}")
1157
+ return {"status": "error", "message": str(e)}
lightrag/kg/networkx_impl.py CHANGED
@@ -42,6 +42,7 @@ class NetworkXStorage(BaseGraphStorage):
42
  )
43
  nx.write_graphml(graph, file_name)
44
 
 
45
  @staticmethod
46
  def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
47
  """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
@@ -155,16 +156,34 @@ class NetworkXStorage(BaseGraphStorage):
155
  return None
156
 
157
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
 
 
 
 
 
 
158
  graph = await self._get_graph()
159
  graph.add_node(node_id, **node_data)
160
 
161
  async def upsert_edge(
162
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
163
  ) -> None:
 
 
 
 
 
 
164
  graph = await self._get_graph()
165
  graph.add_edge(source_node_id, target_node_id, **edge_data)
166
 
167
  async def delete_node(self, node_id: str) -> None:
 
 
 
 
 
 
168
  graph = await self._get_graph()
169
  if graph.has_node(node_id):
170
  graph.remove_node(node_id)
@@ -172,6 +191,7 @@ class NetworkXStorage(BaseGraphStorage):
172
  else:
173
  logger.warning(f"Node {node_id} not found in the graph for deletion.")
174
 
 
175
  async def embed_nodes(
176
  self, algorithm: str
177
  ) -> tuple[np.ndarray[Any, Any], list[str]]:
@@ -192,6 +212,11 @@ class NetworkXStorage(BaseGraphStorage):
192
  async def remove_nodes(self, nodes: list[str]):
193
  """Delete multiple nodes
194
 
 
 
 
 
 
195
  Args:
196
  nodes: List of node IDs to be deleted
197
  """
@@ -203,6 +228,11 @@ class NetworkXStorage(BaseGraphStorage):
203
  async def remove_edges(self, edges: list[tuple[str, str]]):
204
  """Delete multiple edges
205
 
 
 
 
 
 
206
  Args:
207
  edges: List of edges to be deleted, each edge is a (source, target) tuple
208
  """
@@ -229,118 +259,81 @@ class NetworkXStorage(BaseGraphStorage):
229
  self,
230
  node_label: str,
231
  max_depth: int = 3,
232
- min_degree: int = 0,
233
- inclusive: bool = False,
234
  ) -> KnowledgeGraph:
235
  """
236
  Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
237
- Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
238
- When reducing the number of nodes, the prioritization criteria are as follows:
239
- 1. min_degree does not affect nodes directly connected to the matching nodes
240
- 2. Label matching nodes take precedence
241
- 3. Followed by nodes directly connected to the matching nodes
242
- 4. Finally, the degree of the nodes
243
 
244
  Args:
245
- node_label: Label of the starting node
246
- max_depth: Maximum depth of the subgraph
247
- min_degree: Minimum degree of nodes to include. Defaults to 0
248
- inclusive: Do an inclusive search if true
249
 
250
  Returns:
251
- KnowledgeGraph object containing nodes and edges
 
252
  """
253
- result = KnowledgeGraph()
254
- seen_nodes = set()
255
- seen_edges = set()
256
-
257
  graph = await self._get_graph()
258
 
259
- # Initialize sets for start nodes and direct connected nodes
260
- start_nodes = set()
261
- direct_connected_nodes = set()
262
 
263
  # Handle special case for "*" label
264
  if node_label == "*":
265
- # For "*", return the entire graph including all nodes and edges
266
- subgraph = (
267
- graph.copy()
268
- ) # Create a copy to avoid modifying the original graph
 
 
 
 
 
 
 
 
 
 
 
269
  else:
270
- # Find nodes with matching node id based on search_mode
271
- nodes_to_explore = []
272
- for n, attr in graph.nodes(data=True):
273
- node_str = str(n)
274
- if not inclusive:
275
- if node_label == node_str: # Use exact matching
276
- nodes_to_explore.append(n)
277
- else: # inclusive mode
278
- if node_label in node_str: # Use partial matching
279
- nodes_to_explore.append(n)
280
-
281
- if not nodes_to_explore:
282
- logger.warning(f"No nodes found with label {node_label}")
283
- return result
284
-
285
- # Get subgraph using ego_graph from all matching nodes
286
- combined_subgraph = nx.Graph()
287
- for start_node in nodes_to_explore:
288
- node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
289
- combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
290
-
291
- # Get start nodes and direct connected nodes
292
- if nodes_to_explore:
293
- start_nodes = set(nodes_to_explore)
294
- # Get nodes directly connected to all start nodes
295
- for start_node in start_nodes:
296
- direct_connected_nodes.update(
297
- combined_subgraph.neighbors(start_node)
298
- )
299
-
300
- # Remove start nodes from directly connected nodes (avoid duplicates)
301
- direct_connected_nodes -= start_nodes
302
-
303
- subgraph = combined_subgraph
304
-
305
- # Filter nodes based on min_degree, but keep start nodes and direct connected nodes
306
- if min_degree > 0:
307
- nodes_to_keep = [
308
- node
309
- for node, degree in subgraph.degree()
310
- if node in start_nodes
311
- or node in direct_connected_nodes
312
- or degree >= min_degree
313
- ]
314
- subgraph = subgraph.subgraph(nodes_to_keep)
315
-
316
- # Check if number of nodes exceeds max_graph_nodes
317
- if len(subgraph.nodes()) > MAX_GRAPH_NODES:
318
- origin_nodes = len(subgraph.nodes())
319
- node_degrees = dict(subgraph.degree())
320
-
321
- def priority_key(node_item):
322
- node, degree = node_item
323
- # Priority order: start(2) > directly connected(1) > other nodes(0)
324
- if node in start_nodes:
325
- priority = 2
326
- elif node in direct_connected_nodes:
327
- priority = 1
328
- else:
329
- priority = 0
330
- return (priority, degree)
331
-
332
- # Sort by priority and degree and select top MAX_GRAPH_NODES nodes
333
- top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
334
- :MAX_GRAPH_NODES
335
- ]
336
- top_node_ids = [node[0] for node in top_nodes]
337
- # Create new subgraph and keep nodes only with most degree
338
- subgraph = subgraph.subgraph(top_node_ids)
339
- logger.info(
340
- f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
341
- )
342
 
343
  # Add nodes to result
 
 
344
  for node in subgraph.nodes():
345
  if str(node) in seen_nodes:
346
  continue
@@ -368,7 +361,7 @@ class NetworkXStorage(BaseGraphStorage):
368
  for edge in subgraph.edges():
369
  source, target = edge
370
  # Esure unique edge_id for undirect graph
371
- if source > target:
372
  source, target = target, source
373
  edge_id = f"{source}-{target}"
374
  if edge_id in seen_edges:
@@ -424,3 +417,35 @@ class NetworkXStorage(BaseGraphStorage):
424
  return False # Return error
425
 
426
  return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  )
43
  nx.write_graphml(graph, file_name)
44
 
45
+ # TODO: deprecated, remove later
46
  @staticmethod
47
  def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
48
  """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
 
156
  return None
157
 
158
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
159
+ """
160
+ Important notes:
161
+ 1. Changes will be persisted to disk during the next index_done_callback
162
+ 2. Only one process should update the storage at a time before index_done_callback,
163
+ KG-storage-log should be used to avoid data corruption
164
+ """
165
  graph = await self._get_graph()
166
  graph.add_node(node_id, **node_data)
167
 
168
  async def upsert_edge(
169
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
170
  ) -> None:
171
+ """
172
+ Important notes:
173
+ 1. Changes will be persisted to disk during the next index_done_callback
174
+ 2. Only one process should update the storage at a time before index_done_callback,
175
+ KG-storage-log should be used to avoid data corruption
176
+ """
177
  graph = await self._get_graph()
178
  graph.add_edge(source_node_id, target_node_id, **edge_data)
179
 
180
  async def delete_node(self, node_id: str) -> None:
181
+ """
182
+ Important notes:
183
+ 1. Changes will be persisted to disk during the next index_done_callback
184
+ 2. Only one process should update the storage at a time before index_done_callback;
185
+ KG-storage-log should be used to avoid data corruption
186
+ """
187
  graph = await self._get_graph()
188
  if graph.has_node(node_id):
189
  graph.remove_node(node_id)
 
191
  else:
192
  logger.warning(f"Node {node_id} not found in the graph for deletion.")
193
 
194
+ # TODO: NOT USED
195
  async def embed_nodes(
196
  self, algorithm: str
197
  ) -> tuple[np.ndarray[Any, Any], list[str]]:
 
212
  async def remove_nodes(self, nodes: list[str]):
213
  """Delete multiple nodes
214
 
215
+ Important notes:
216
+ 1. Changes will be persisted to disk during the next index_done_callback
217
+ 2. Only one process should update the storage at a time before index_done_callback;
218
+ KG-storage-log should be used to avoid data corruption
219
+
220
  Args:
221
  nodes: List of node IDs to be deleted
222
  """
 
228
  async def remove_edges(self, edges: list[tuple[str, str]]):
229
  """Delete multiple edges
230
 
231
+ Important notes:
232
+ 1. Changes will be persisted to disk during the next index_done_callback
233
+ 2. Only one process should update the storage at a time before index_done_callback;
234
+ KG-storage-log should be used to avoid data corruption
235
+
236
  Args:
237
  edges: List of edges to be deleted, each edge is a (source, target) tuple
238
  """
 
259
  self,
260
  node_label: str,
261
  max_depth: int = 3,
262
+ max_nodes: int = MAX_GRAPH_NODES,
 
263
  ) -> KnowledgeGraph:
264
  """
265
  Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
 
266
 
267
  Args:
268
+ node_label: Label of the starting node; "*" means all nodes
269
+ max_depth: Maximum depth of the subgraph. Defaults to 3
270
+ max_nodes: Maximum number of nodes to return by BFS. Defaults to 1000
 
271
 
272
  Returns:
273
+ KnowledgeGraph object containing nodes and edges, with an is_truncated flag
274
+ indicating whether the graph was truncated due to max_nodes limit
275
  """
276
  graph = await self._get_graph()
277
 
278
+ result = KnowledgeGraph()
 
 
279
 
280
  # Handle special case for "*" label
281
  if node_label == "*":
282
+ # Get degrees of all nodes
283
+ degrees = dict(graph.degree())
284
+ # Sort nodes by degree in descending order and take top max_nodes
285
+ sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
286
+
287
+ # Check if graph is truncated
288
+ if len(sorted_nodes) > max_nodes:
289
+ result.is_truncated = True
290
+ logger.info(
291
+ f"Graph truncated: {len(sorted_nodes)} nodes found, limited to {max_nodes}"
292
+ )
293
+
294
+ limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]]
295
+ # Create subgraph with the highest degree nodes
296
+ subgraph = graph.subgraph(limited_nodes)
297
  else:
298
+ # Check if node exists
299
+ if node_label not in graph:
300
+ logger.warning(f"Node {node_label} not found in the graph")
301
+ return KnowledgeGraph() # Return empty graph
302
+
303
+ # Use BFS to get nodes
304
+ bfs_nodes = []
305
+ visited = set()
306
+ queue = [(node_label, 0)] # (node, depth) tuple
307
+
308
+ # Breadth-first search
309
+ while queue and len(bfs_nodes) < max_nodes:
310
+ current, depth = queue.pop(0)
311
+ if current not in visited:
312
+ visited.add(current)
313
+ bfs_nodes.append(current)
314
+
315
+ # Only explore neighbors if we haven't reached max_depth
316
+ if depth < max_depth:
317
+ # Add neighbor nodes to queue with incremented depth
318
+ neighbors = list(graph.neighbors(current))
319
+ queue.extend(
320
+ [(n, depth + 1) for n in neighbors if n not in visited]
321
+ )
322
+
323
+ # Check if graph is truncated - if we still have nodes in the queue
324
+ # and we've reached max_nodes, then the graph is truncated
325
+ if queue and len(bfs_nodes) >= max_nodes:
326
+ result.is_truncated = True
327
+ logger.info(
328
+ f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
329
+ )
330
+
331
+ # Create subgraph with BFS discovered nodes
332
+ subgraph = graph.subgraph(bfs_nodes)
 
333
 
334
  # Add nodes to result
335
+ seen_nodes = set()
336
+ seen_edges = set()
337
  for node in subgraph.nodes():
338
  if str(node) in seen_nodes:
339
  continue
 
361
  for edge in subgraph.edges():
362
  source, target = edge
363
  # Ensure a unique edge_id for the undirected graph
364
+ if str(source) > str(target):
365
  source, target = target, source
366
  edge_id = f"{source}-{target}"
367
  if edge_id in seen_edges:
 
417
  return False # Return error
418
 
419
  return True
420
+
421
+ async def drop(self) -> dict[str, str]:
422
+ """Drop all graph data from storage and clean up resources
423
+
424
+ This method will:
425
+ 1. Remove the graph storage file if it exists
426
+ 2. Reset the graph to an empty state
427
+ 3. Update flags to notify other processes
428
+ 4. Changes are persisted to disk immediately
429
+
430
+ Returns:
431
+ dict[str, str]: Operation status and message
432
+ - On success: {"status": "success", "message": "data dropped"}
433
+ - On failure: {"status": "error", "message": "<error details>"}
434
+ """
435
+ try:
436
+ async with self._storage_lock:
437
+ # Delete the on-disk GraphML file
438
+ if os.path.exists(self._graphml_xml_file):
439
+ os.remove(self._graphml_xml_file)
440
+ self._graph = nx.Graph()
441
+ # Notify other processes that data has been updated
442
+ await set_all_update_flags(self.namespace)
443
+ # Reset own update flag to avoid self-reloading
444
+ self.storage_updated.value = False
445
+ logger.info(
446
+ f"Process {os.getpid()} drop graph {self.namespace} (file:{self._graphml_xml_file})"
447
+ )
448
+ return {"status": "success", "message": "data dropped"}
449
+ except Exception as e:
450
+ logger.error(f"Error dropping graph {self.namespace}: {e}")
451
+ return {"status": "error", "message": str(e)}
lightrag/kg/oracle_impl.py DELETED
@@ -1,1346 +0,0 @@
1
- import array
2
- import asyncio
3
-
4
- # import html
5
- import os
6
- from dataclasses import dataclass, field
7
- from typing import Any, Union, final
8
- import numpy as np
9
- import configparser
10
-
11
- from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
12
-
13
- from ..base import (
14
- BaseGraphStorage,
15
- BaseKVStorage,
16
- BaseVectorStorage,
17
- )
18
- from ..namespace import NameSpace, is_namespace
19
- from ..utils import logger
20
-
21
- import pipmaster as pm
22
-
23
- if not pm.is_installed("graspologic"):
24
- pm.install("graspologic")
25
-
26
- if not pm.is_installed("oracledb"):
27
- pm.install("oracledb")
28
-
29
- from graspologic import embed
30
- import oracledb
31
-
32
-
33
- class OracleDB:
34
- def __init__(self, config, **kwargs):
35
- self.host = config.get("host", None)
36
- self.port = config.get("port", None)
37
- self.user = config.get("user", None)
38
- self.password = config.get("password", None)
39
- self.dsn = config.get("dsn", None)
40
- self.config_dir = config.get("config_dir", None)
41
- self.wallet_location = config.get("wallet_location", None)
42
- self.wallet_password = config.get("wallet_password", None)
43
- self.workspace = config.get("workspace", None)
44
- self.max = 12
45
- self.increment = 1
46
- logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
47
- if self.user is None or self.password is None:
48
- raise ValueError("Missing database user or password")
49
-
50
- try:
51
- oracledb.defaults.fetch_lobs = False
52
-
53
- self.pool = oracledb.create_pool_async(
54
- user=self.user,
55
- password=self.password,
56
- dsn=self.dsn,
57
- config_dir=self.config_dir,
58
- wallet_location=self.wallet_location,
59
- wallet_password=self.wallet_password,
60
- min=1,
61
- max=self.max,
62
- increment=self.increment,
63
- )
64
- logger.info(f"Connected to Oracle database at {self.dsn}")
65
- except Exception as e:
66
- logger.error(f"Failed to connect to Oracle database at {self.dsn}")
67
- logger.error(f"Oracle database error: {e}")
68
- raise
69
-
70
- def numpy_converter_in(self, value):
71
- """Convert numpy array to array.array"""
72
- if value.dtype == np.float64:
73
- dtype = "d"
74
- elif value.dtype == np.float32:
75
- dtype = "f"
76
- else:
77
- dtype = "b"
78
- return array.array(dtype, value)
79
-
80
- def input_type_handler(self, cursor, value, arraysize):
81
- """Set the type handler for the input data"""
82
- if isinstance(value, np.ndarray):
83
- return cursor.var(
84
- oracledb.DB_TYPE_VECTOR,
85
- arraysize=arraysize,
86
- inconverter=self.numpy_converter_in,
87
- )
88
-
89
- def numpy_converter_out(self, value):
90
- """Convert array.array to numpy array"""
91
- if value.typecode == "b":
92
- dtype = np.int8
93
- elif value.typecode == "f":
94
- dtype = np.float32
95
- else:
96
- dtype = np.float64
97
- return np.array(value, copy=False, dtype=dtype)
98
-
99
- def output_type_handler(self, cursor, metadata):
100
- """Set the type handler for the output data"""
101
- if metadata.type_code is oracledb.DB_TYPE_VECTOR:
102
- return cursor.var(
103
- metadata.type_code,
104
- arraysize=cursor.arraysize,
105
- outconverter=self.numpy_converter_out,
106
- )
107
-
108
- async def check_tables(self):
109
- for k, v in TABLES.items():
110
- try:
111
- if k.lower() == "lightrag_graph":
112
- await self.query(
113
- "SELECT id FROM GRAPH_TABLE (lightrag_graph MATCH (a) COLUMNS (a.id)) fetch first row only"
114
- )
115
- else:
116
- await self.query(f"SELECT 1 FROM {k}")
117
- except Exception as e:
118
- logger.error(f"Failed to check table {k} in Oracle database")
119
- logger.error(f"Oracle database error: {e}")
120
- try:
121
- # print(v["ddl"])
122
- await self.execute(v["ddl"])
123
- logger.info(f"Created table {k} in Oracle database")
124
- except Exception as e:
125
- logger.error(f"Failed to create table {k} in Oracle database")
126
- logger.error(f"Oracle database error: {e}")
127
-
128
- logger.info("Finished check all tables in Oracle database")
129
-
130
- async def query(
131
- self, sql: str, params: dict = None, multirows: bool = False
132
- ) -> Union[dict, None]:
133
- async with self.pool.acquire() as connection:
134
- connection.inputtypehandler = self.input_type_handler
135
- connection.outputtypehandler = self.output_type_handler
136
- with connection.cursor() as cursor:
137
- try:
138
- await cursor.execute(sql, params)
139
- except Exception as e:
140
- logger.error(f"Oracle database error: {e}")
141
- raise
142
- columns = [column[0].lower() for column in cursor.description]
143
- if multirows:
144
- rows = await cursor.fetchall()
145
- if rows:
146
- data = [dict(zip(columns, row)) for row in rows]
147
- else:
148
- data = []
149
- else:
150
- row = await cursor.fetchone()
151
- if row:
152
- data = dict(zip(columns, row))
153
- else:
154
- data = None
155
- return data
156
-
157
- async def execute(self, sql: str, data: Union[list, dict] = None):
158
- # logger.info("go into OracleDB execute method")
159
- try:
160
- async with self.pool.acquire() as connection:
161
- connection.inputtypehandler = self.input_type_handler
162
- connection.outputtypehandler = self.output_type_handler
163
- with connection.cursor() as cursor:
164
- if data is None:
165
- await cursor.execute(sql)
166
- else:
167
- await cursor.execute(sql, data)
168
- await connection.commit()
169
- except Exception as e:
170
- logger.error(f"Oracle database error: {e}")
171
- raise
172
-
173
-
174
- class ClientManager:
175
- _instances: dict[str, Any] = {"db": None, "ref_count": 0}
176
- _lock = asyncio.Lock()
177
-
178
- @staticmethod
179
- def get_config() -> dict[str, Any]:
180
- config = configparser.ConfigParser()
181
- config.read("config.ini", "utf-8")
182
-
183
- return {
184
- "user": os.environ.get(
185
- "ORACLE_USER",
186
- config.get("oracle", "user", fallback=None),
187
- ),
188
- "password": os.environ.get(
189
- "ORACLE_PASSWORD",
190
- config.get("oracle", "password", fallback=None),
191
- ),
192
- "dsn": os.environ.get(
193
- "ORACLE_DSN",
194
- config.get("oracle", "dsn", fallback=None),
195
- ),
196
- "config_dir": os.environ.get(
197
- "ORACLE_CONFIG_DIR",
198
- config.get("oracle", "config_dir", fallback=None),
199
- ),
200
- "wallet_location": os.environ.get(
201
- "ORACLE_WALLET_LOCATION",
202
- config.get("oracle", "wallet_location", fallback=None),
203
- ),
204
- "wallet_password": os.environ.get(
205
- "ORACLE_WALLET_PASSWORD",
206
- config.get("oracle", "wallet_password", fallback=None),
207
- ),
208
- "workspace": os.environ.get(
209
- "ORACLE_WORKSPACE",
210
- config.get("oracle", "workspace", fallback="default"),
211
- ),
212
- }
213
-
214
- @classmethod
215
- async def get_client(cls) -> OracleDB:
216
- async with cls._lock:
217
- if cls._instances["db"] is None:
218
- config = ClientManager.get_config()
219
- db = OracleDB(config)
220
- await db.check_tables()
221
- cls._instances["db"] = db
222
- cls._instances["ref_count"] = 0
223
- cls._instances["ref_count"] += 1
224
- return cls._instances["db"]
225
-
226
- @classmethod
227
- async def release_client(cls, db: OracleDB):
228
- async with cls._lock:
229
- if db is not None:
230
- if db is cls._instances["db"]:
231
- cls._instances["ref_count"] -= 1
232
- if cls._instances["ref_count"] == 0:
233
- await db.pool.close()
234
- logger.info("Closed OracleDB database connection pool")
235
- cls._instances["db"] = None
236
- else:
237
- await db.pool.close()
238
-
239
-
240
- @final
241
- @dataclass
242
- class OracleKVStorage(BaseKVStorage):
243
- db: OracleDB = field(default=None)
244
- meta_fields = None
245
-
246
- def __post_init__(self):
247
- self._data = {}
248
- self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
249
-
250
- async def initialize(self):
251
- if self.db is None:
252
- self.db = await ClientManager.get_client()
253
-
254
- async def finalize(self):
255
- if self.db is not None:
256
- await ClientManager.release_client(self.db)
257
- self.db = None
258
-
259
- ################ QUERY METHODS ################
260
-
261
- async def get_by_id(self, id: str) -> dict[str, Any] | None:
262
- """Get doc_full data based on id."""
263
- SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
264
- params = {"workspace": self.db.workspace, "id": id}
265
- # print("get_by_id:"+SQL)
266
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
267
- array_res = await self.db.query(SQL, params, multirows=True)
268
- res = {}
269
- for row in array_res:
270
- res[row["id"]] = row
271
- if res:
272
- return res
273
- else:
274
- return None
275
- else:
276
- return await self.db.query(SQL, params)
277
-
278
- async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
279
- """Specifically for llm_response_cache."""
280
- SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
281
- params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id}
282
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
283
- array_res = await self.db.query(SQL, params, multirows=True)
284
- res = {}
285
- for row in array_res:
286
- res[row["id"]] = row
287
- return res
288
- else:
289
- return None
290
-
291
- async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
292
- """Get doc_chunks data based on id"""
293
- SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
294
- ids=",".join([f"'{id}'" for id in ids])
295
- )
296
- params = {"workspace": self.db.workspace}
297
- # print("get_by_ids:"+SQL)
298
- res = await self.db.query(SQL, params, multirows=True)
299
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
300
- modes = set()
301
- dict_res: dict[str, dict] = {}
302
- for row in res:
303
- modes.add(row["mode"])
304
- for mode in modes:
305
- if mode not in dict_res:
306
- dict_res[mode] = {}
307
- for row in res:
308
- dict_res[row["mode"]][row["id"]] = row
309
- res = [{k: v} for k, v in dict_res.items()]
310
- return res
311
-
312
- async def filter_keys(self, keys: set[str]) -> set[str]:
313
- """Return keys that don't exist in storage"""
314
- SQL = SQL_TEMPLATES["filter_keys"].format(
315
- table_name=namespace_to_table_name(self.namespace),
316
- ids=",".join([f"'{id}'" for id in keys]),
317
- )
318
- params = {"workspace": self.db.workspace}
319
- res = await self.db.query(SQL, params, multirows=True)
320
- if res:
321
- exist_keys = [key["id"] for key in res]
322
- data = set([s for s in keys if s not in exist_keys])
323
- return data
324
- else:
325
- return set(keys)
326
-
327
- ################ INSERT METHODS ################
328
- async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
329
- logger.info(f"Inserting {len(data)} to {self.namespace}")
330
- if not data:
331
- return
332
-
333
- if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
334
- list_data = [
335
- {
336
- "id": k,
337
- **{k1: v1 for k1, v1 in v.items()},
338
- }
339
- for k, v in data.items()
340
- ]
341
- contents = [v["content"] for v in data.values()]
342
- batches = [
343
- contents[i : i + self._max_batch_size]
344
- for i in range(0, len(contents), self._max_batch_size)
345
- ]
346
- embeddings_list = await asyncio.gather(
347
- *[self.embedding_func(batch) for batch in batches]
348
- )
349
- embeddings = np.concatenate(embeddings_list)
350
- for i, d in enumerate(list_data):
351
- d["__vector__"] = embeddings[i]
352
-
353
- merge_sql = SQL_TEMPLATES["merge_chunk"]
354
- for item in list_data:
355
- _data = {
356
- "id": item["id"],
357
- "content": item["content"],
358
- "workspace": self.db.workspace,
359
- "tokens": item["tokens"],
360
- "chunk_order_index": item["chunk_order_index"],
361
- "full_doc_id": item["full_doc_id"],
362
- "content_vector": item["__vector__"],
363
- "status": item["status"],
364
- }
365
- await self.db.execute(merge_sql, _data)
366
- if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
367
- for k, v in data.items():
368
- # values.clear()
369
- merge_sql = SQL_TEMPLATES["merge_doc_full"]
370
- _data = {
371
- "id": k,
372
- "content": v["content"],
373
- "workspace": self.db.workspace,
374
- }
375
- await self.db.execute(merge_sql, _data)
376
-
377
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
378
- for mode, items in data.items():
379
- for k, v in items.items():
380
- upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
381
- _data = {
382
- "workspace": self.db.workspace,
383
- "id": k,
384
- "original_prompt": v["original_prompt"],
385
- "return_value": v["return"],
386
- "cache_mode": mode,
387
- }
388
-
389
- await self.db.execute(upsert_sql, _data)
390
-
391
- async def index_done_callback(self) -> None:
392
- # Oracle handles persistence automatically
393
- pass
394
-
395
-
396
- @final
397
- @dataclass
398
- class OracleVectorDBStorage(BaseVectorStorage):
399
- db: OracleDB | None = field(default=None)
400
-
401
- def __post_init__(self):
402
- config = self.global_config.get("vector_db_storage_cls_kwargs", {})
403
- cosine_threshold = config.get("cosine_better_than_threshold")
404
- if cosine_threshold is None:
405
- raise ValueError(
406
- "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
407
- )
408
- self.cosine_better_than_threshold = cosine_threshold
409
-
410
- async def initialize(self):
411
- if self.db is None:
412
- self.db = await ClientManager.get_client()
413
-
414
- async def finalize(self):
415
- if self.db is not None:
416
- await ClientManager.release_client(self.db)
417
- self.db = None
418
-
419
- #################### query method ###############
420
- async def query(
421
- self, query: str, top_k: int, ids: list[str] | None = None
422
- ) -> list[dict[str, Any]]:
423
- embeddings = await self.embedding_func([query])
424
- embedding = embeddings[0]
425
- # 转换精度
426
- dtype = str(embedding.dtype).upper()
427
- dimension = embedding.shape[0]
428
- embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
429
-
430
- SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
431
- params = {
432
- "embedding_string": embedding_string,
433
- "workspace": self.db.workspace,
434
- "top_k": top_k,
435
- "better_than_threshold": self.cosine_better_than_threshold,
436
- }
437
- results = await self.db.query(SQL, params=params, multirows=True)
438
- return results
439
-
440
- async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
441
- raise NotImplementedError
442
-
443
- async def index_done_callback(self) -> None:
444
- # Oracles handles persistence automatically
445
- pass
446
-
447
- async def delete(self, ids: list[str]) -> None:
448
- """Delete vectors with specified IDs
449
-
450
- Args:
451
- ids: List of vector IDs to be deleted
452
- """
453
- if not ids:
454
- return
455
-
456
- try:
457
- SQL = SQL_TEMPLATES["delete_vectors"].format(
458
- ids=",".join([f"'{id}'" for id in ids])
459
- )
460
- params = {"workspace": self.db.workspace}
461
- await self.db.execute(SQL, params)
462
- logger.info(
463
- f"Successfully deleted {len(ids)} vectors from {self.namespace}"
464
- )
465
- except Exception as e:
466
- logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
467
- raise
468
-
469
- async def delete_entity(self, entity_name: str) -> None:
470
- """Delete entity by name
471
-
472
- Args:
473
- entity_name: Name of the entity to delete
474
- """
475
- try:
476
- SQL = SQL_TEMPLATES["delete_entity"]
477
- params = {"workspace": self.db.workspace, "entity_name": entity_name}
478
- await self.db.execute(SQL, params)
479
- logger.info(f"Successfully deleted entity {entity_name}")
480
- except Exception as e:
481
- logger.error(f"Error deleting entity {entity_name}: {e}")
482
- raise
483
-
484
- async def delete_entity_relation(self, entity_name: str) -> None:
485
- """Delete all relations connected to an entity
486
-
487
- Args:
488
- entity_name: Name of the entity whose relations should be deleted
489
- """
490
- try:
491
- SQL = SQL_TEMPLATES["delete_entity_relations"]
492
- params = {"workspace": self.db.workspace, "entity_name": entity_name}
493
- await self.db.execute(SQL, params)
494
- logger.info(f"Successfully deleted relations for entity {entity_name}")
495
- except Exception as e:
496
- logger.error(f"Error deleting relations for entity {entity_name}: {e}")
497
- raise
498
-
499
- async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
500
- """Search for records with IDs starting with a specific prefix.
501
-
502
- Args:
503
- prefix: The prefix to search for in record IDs
504
-
505
- Returns:
506
- List of records with matching ID prefixes
507
- """
508
- try:
509
- # Determine the appropriate table based on namespace
510
- table_name = namespace_to_table_name(self.namespace)
511
-
512
- # Create SQL query to find records with IDs starting with prefix
513
- search_sql = f"""
514
- SELECT * FROM {table_name}
515
- WHERE workspace = :workspace
516
- AND id LIKE :prefix_pattern
517
- ORDER BY id
518
- """
519
-
520
- params = {"workspace": self.db.workspace, "prefix_pattern": f"{prefix}%"}
521
-
522
- # Execute query and get results
523
- results = await self.db.query(search_sql, params, multirows=True)
524
-
525
- logger.debug(
526
- f"Found {len(results) if results else 0} records with prefix '{prefix}'"
527
- )
528
- return results or []
529
-
530
- except Exception as e:
531
- logger.error(f"Error searching records with prefix '{prefix}': {e}")
532
- return []
533
-
534
- async def get_by_id(self, id: str) -> dict[str, Any] | None:
535
- """Get vector data by its ID
536
-
537
- Args:
538
- id: The unique identifier of the vector
539
-
540
- Returns:
541
- The vector data if found, or None if not found
542
- """
543
- try:
544
- # Determine the table name based on namespace
545
- table_name = namespace_to_table_name(self.namespace)
546
- if not table_name:
547
- logger.error(f"Unknown namespace for ID lookup: {self.namespace}")
548
- return None
549
-
550
- # Create the appropriate ID field name based on namespace
551
- id_field = "entity_id" if "NODES" in table_name else "relation_id"
552
- if "CHUNKS" in table_name:
553
- id_field = "chunk_id"
554
-
555
- # Prepare and execute the query
556
- query = f"""
557
- SELECT * FROM {table_name}
558
- WHERE {id_field} = :id AND workspace = :workspace
559
- """
560
- params = {"id": id, "workspace": self.db.workspace}
561
-
562
- result = await self.db.query(query, params)
563
- return result
564
- except Exception as e:
565
- logger.error(f"Error retrieving vector data for ID {id}: {e}")
566
- return None
567
-
568
- async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
569
- """Get multiple vector data by their IDs
570
-
571
- Args:
572
- ids: List of unique identifiers
573
-
574
- Returns:
575
- List of vector data objects that were found
576
- """
577
- if not ids:
578
- return []
579
-
580
- try:
581
- # Determine the table name based on namespace
582
- table_name = namespace_to_table_name(self.namespace)
583
- if not table_name:
584
- logger.error(f"Unknown namespace for IDs lookup: {self.namespace}")
585
- return []
586
-
587
- # Create the appropriate ID field name based on namespace
588
- id_field = "entity_id" if "NODES" in table_name else "relation_id"
589
- if "CHUNKS" in table_name:
590
- id_field = "chunk_id"
591
-
592
- # Format the list of IDs for SQL IN clause
593
- ids_list = ", ".join([f"'{id}'" for id in ids])
594
-
595
- # Prepare and execute the query
596
- query = f"""
597
- SELECT * FROM {table_name}
598
- WHERE {id_field} IN ({ids_list}) AND workspace = :workspace
599
- """
600
- params = {"workspace": self.db.workspace}
601
-
602
- results = await self.db.query(query, params, multirows=True)
603
- return results or []
604
- except Exception as e:
605
- logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
606
- return []
607
-
608
-
609
- @final
610
- @dataclass
611
- class OracleGraphStorage(BaseGraphStorage):
612
- db: OracleDB = field(default=None)
613
-
614
- def __post_init__(self):
615
- self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
616
-
617
- async def initialize(self):
618
- if self.db is None:
619
- self.db = await ClientManager.get_client()
620
-
621
- async def finalize(self):
622
- if self.db is not None:
623
- await ClientManager.release_client(self.db)
624
- self.db = None
625
-
626
- #################### insert method ################
627
-
628
- async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
629
- entity_name = node_id
630
- entity_type = node_data["entity_type"]
631
- description = node_data["description"]
632
- source_id = node_data["source_id"]
633
- logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}")
634
-
635
- content = entity_name + description
636
- contents = [content]
637
- batches = [
638
- contents[i : i + self._max_batch_size]
639
- for i in range(0, len(contents), self._max_batch_size)
640
- ]
641
- embeddings_list = await asyncio.gather(
642
- *[self.embedding_func(batch) for batch in batches]
643
- )
644
- embeddings = np.concatenate(embeddings_list)
645
- content_vector = embeddings[0]
646
- merge_sql = SQL_TEMPLATES["merge_node"]
647
- data = {
648
- "workspace": self.db.workspace,
649
- "name": entity_name,
650
- "entity_type": entity_type,
651
- "description": description,
652
- "source_chunk_id": source_id,
653
- "content": content,
654
- "content_vector": content_vector,
655
- }
656
- await self.db.execute(merge_sql, data)
657
- # self._graph.add_node(node_id, **node_data)
658
-
659
- async def upsert_edge(
660
- self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
661
- ) -> None:
662
- """插入或更新边"""
663
- # print("go into upsert edge method")
664
- source_name = source_node_id
665
- target_name = target_node_id
666
- weight = edge_data["weight"]
667
- keywords = edge_data["keywords"]
668
- description = edge_data["description"]
669
- source_chunk_id = edge_data["source_id"]
670
- logger.debug(
671
- f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
672
- )
673
-
674
- content = keywords + source_name + target_name + description
675
- contents = [content]
676
- batches = [
677
- contents[i : i + self._max_batch_size]
678
- for i in range(0, len(contents), self._max_batch_size)
679
- ]
680
- embeddings_list = await asyncio.gather(
681
- *[self.embedding_func(batch) for batch in batches]
682
- )
683
- embeddings = np.concatenate(embeddings_list)
684
- content_vector = embeddings[0]
685
- merge_sql = SQL_TEMPLATES["merge_edge"]
686
- data = {
687
- "workspace": self.db.workspace,
688
- "source_name": source_name,
689
- "target_name": target_name,
690
- "weight": weight,
691
- "keywords": keywords,
692
- "description": description,
693
- "source_chunk_id": source_chunk_id,
694
- "content": content,
695
- "content_vector": content_vector,
696
- }
697
- # print(merge_sql)
698
- await self.db.execute(merge_sql, data)
699
- # self._graph.add_edge(source_node_id, target_node_id, **edge_data)
700
-
701
- async def embed_nodes(
702
- self, algorithm: str
703
- ) -> tuple[np.ndarray[Any, Any], list[str]]:
704
- if algorithm not in self._node_embed_algorithms:
705
- raise ValueError(f"Node embedding algorithm {algorithm} not supported")
706
- return await self._node_embed_algorithms[algorithm]()
707
-
708
- async def _node2vec_embed(self):
709
- """为节点生成向量"""
710
- embeddings, nodes = embed.node2vec_embed(
711
- self._graph,
712
- **self.config["node2vec_params"],
713
- )
714
-
715
- nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
716
- return embeddings, nodes_ids
717
-
718
- async def index_done_callback(self) -> None:
719
- # Oracles handles persistence automatically
720
- pass
721
-
722
- #################### query method #################
723
- async def has_node(self, node_id: str) -> bool:
724
- """根据节点id检查节点是否存在"""
725
- SQL = SQL_TEMPLATES["has_node"]
726
- params = {"workspace": self.db.workspace, "node_id": node_id}
727
- res = await self.db.query(SQL, params)
728
- if res:
729
- # print("Node exist!",res)
730
- return True
731
- else:
732
- # print("Node not exist!")
733
- return False
734
-
735
- async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
736
- SQL = SQL_TEMPLATES["has_edge"]
737
- params = {
738
- "workspace": self.db.workspace,
739
- "source_node_id": source_node_id,
740
- "target_node_id": target_node_id,
741
- }
742
- res = await self.db.query(SQL, params)
743
- if res:
744
- # print("Edge exist!",res)
745
- return True
746
- else:
747
- # print("Edge not exist!")
748
- return False
749
-
750
- async def node_degree(self, node_id: str) -> int:
751
- SQL = SQL_TEMPLATES["node_degree"]
752
- params = {"workspace": self.db.workspace, "node_id": node_id}
753
- res = await self.db.query(SQL, params)
754
- if res:
755
- return res["degree"]
756
- else:
757
- return 0
758
-
759
- async def edge_degree(self, src_id: str, tgt_id: str) -> int:
760
- """根据源和目标节点id获取边的度"""
761
- degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
762
- return degree
763
-
764
- async def get_node(self, node_id: str) -> dict[str, str] | None:
765
- """根据节点id获取节点数据"""
766
- SQL = SQL_TEMPLATES["get_node"]
767
- params = {"workspace": self.db.workspace, "node_id": node_id}
768
- res = await self.db.query(SQL, params)
769
- if res:
770
- return res
771
- else:
772
- return None
773
-
774
- async def get_edge(
775
- self, source_node_id: str, target_node_id: str
776
- ) -> dict[str, str] | None:
777
- SQL = SQL_TEMPLATES["get_edge"]
778
- params = {
779
- "workspace": self.db.workspace,
780
- "source_node_id": source_node_id,
781
- "target_node_id": target_node_id,
782
- }
783
- res = await self.db.query(SQL, params)
784
- if res:
785
- # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
786
- return res
787
- else:
788
- # print("Edge not exist!",self.db.workspace, source_node_id, target_node_id)
789
- return None
790
-
791
- async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
792
- if await self.has_node(source_node_id):
793
- SQL = SQL_TEMPLATES["get_node_edges"]
794
- params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
795
- res = await self.db.query(sql=SQL, params=params, multirows=True)
796
- if res:
797
- data = [(i["source_name"], i["target_name"]) for i in res]
798
- # print("Get node edge!",self.db.workspace, source_node_id,data)
799
- return data
800
- else:
801
- # print("Node Edge not exist!",self.db.workspace, source_node_id)
802
- return []
803
-
804
- async def get_all_nodes(self, limit: int):
805
- """查询所有节点"""
806
- SQL = SQL_TEMPLATES["get_all_nodes"]
807
- params = {"workspace": self.db.workspace, "limit": str(limit)}
808
- res = await self.db.query(sql=SQL, params=params, multirows=True)
809
- if res:
810
- return res
811
-
812
- async def get_all_edges(self, limit: int):
813
- """查询所有边"""
814
- SQL = SQL_TEMPLATES["get_all_edges"]
815
- params = {"workspace": self.db.workspace, "limit": str(limit)}
816
- res = await self.db.query(sql=SQL, params=params, multirows=True)
817
- if res:
818
- return res
819
-
820
- async def get_statistics(self):
821
- SQL = SQL_TEMPLATES["get_statistics"]
822
- params = {"workspace": self.db.workspace}
823
- res = await self.db.query(sql=SQL, params=params, multirows=True)
824
- if res:
825
- return res
826
-
827
- async def delete_node(self, node_id: str) -> None:
828
- """Delete a node from the graph
829
-
830
- Args:
831
- node_id: ID of the node to delete
832
- """
833
- try:
834
- # First delete all relations connected to this node
835
- delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"]
836
- params_relations = {"workspace": self.db.workspace, "entity_name": node_id}
837
- await self.db.execute(delete_relations_sql, params_relations)
838
-
839
- # Then delete the node itself
840
- delete_node_sql = SQL_TEMPLATES["delete_entity"]
841
- params_node = {"workspace": self.db.workspace, "entity_name": node_id}
842
- await self.db.execute(delete_node_sql, params_node)
843
-
844
- logger.info(
845
- f"Successfully deleted node {node_id} and all its relationships"
846
- )
847
- except Exception as e:
848
- logger.error(f"Error deleting node {node_id}: {e}")
849
- raise
850
-
851
- async def remove_nodes(self, nodes: list[str]) -> None:
852
- """Delete multiple nodes from the graph
853
-
854
- Args:
855
- nodes: List of node IDs to be deleted
856
- """
857
- if not nodes:
858
- return
859
-
860
- try:
861
- for node in nodes:
862
- # For each node, first delete all its relationships
863
- delete_relations_sql = SQL_TEMPLATES["delete_entity_relations"]
864
- params_relations = {"workspace": self.db.workspace, "entity_name": node}
865
- await self.db.execute(delete_relations_sql, params_relations)
866
-
867
- # Then delete the node itself
868
- delete_node_sql = SQL_TEMPLATES["delete_entity"]
869
- params_node = {"workspace": self.db.workspace, "entity_name": node}
870
- await self.db.execute(delete_node_sql, params_node)
871
-
872
- logger.info(
873
- f"Successfully deleted {len(nodes)} nodes and their relationships"
874
- )
875
- except Exception as e:
876
- logger.error(f"Error during batch node deletion: {e}")
877
- raise
878
-
879
- async def remove_edges(self, edges: list[tuple[str, str]]) -> None:
880
- """Delete multiple edges from the graph
881
-
882
- Args:
883
- edges: List of edges to be deleted, each edge is a (source, target) tuple
884
- """
885
- if not edges:
886
- return
887
-
888
- try:
889
- for source, target in edges:
890
- # Check if the edge exists before attempting to delete
891
- if await self.has_edge(source, target):
892
- # Delete the edge using a SQL query that matches both source and target
893
- delete_edge_sql = """
894
- DELETE FROM LIGHTRAG_GRAPH_EDGES
895
- WHERE workspace = :workspace
896
- AND source_name = :source_name
897
- AND target_name = :target_name
898
- """
899
- params = {
900
- "workspace": self.db.workspace,
901
- "source_name": source,
902
- "target_name": target,
903
- }
904
- await self.db.execute(delete_edge_sql, params)
905
-
906
- logger.info(f"Successfully deleted {len(edges)} edges from the graph")
907
- except Exception as e:
908
- logger.error(f"Error during batch edge deletion: {e}")
909
- raise
910
-
911
- async def get_all_labels(self) -> list[str]:
912
- """Get all unique entity types (labels) in the graph
913
-
914
- Returns:
915
- List of unique entity types/labels
916
- """
917
- try:
918
- SQL = """
919
- SELECT DISTINCT entity_type
920
- FROM LIGHTRAG_GRAPH_NODES
921
- WHERE workspace = :workspace
922
- ORDER BY entity_type
923
- """
924
- params = {"workspace": self.db.workspace}
925
- results = await self.db.query(SQL, params, multirows=True)
926
-
927
- if results:
928
- labels = [row["entity_type"] for row in results]
929
- return labels
930
- else:
931
- return []
932
- except Exception as e:
933
- logger.error(f"Error retrieving entity types: {e}")
934
- return []
935
-
936
- async def get_knowledge_graph(
937
- self, node_label: str, max_depth: int = 5
938
- ) -> KnowledgeGraph:
939
- """Retrieve a connected subgraph starting from nodes matching the given label
940
-
941
- Maximum number of nodes is constrained by MAX_GRAPH_NODES environment variable.
942
- Prioritizes nodes by:
943
- 1. Nodes matching the specified label
944
- 2. Nodes directly connected to matching nodes
945
- 3. Node degree (number of connections)
946
-
947
- Args:
948
- node_label: Label to match for starting nodes (use "*" for all nodes)
949
- max_depth: Maximum depth of traversal from starting nodes
950
-
951
- Returns:
952
- KnowledgeGraph object containing nodes and edges
953
- """
954
- result = KnowledgeGraph()
955
-
956
- try:
957
- # Define maximum number of nodes to return
958
- max_graph_nodes = int(os.environ.get("MAX_GRAPH_NODES", 1000))
959
-
960
- if node_label == "*":
961
- # For "*" label, get all nodes up to the limit
962
- nodes_sql = """
963
- SELECT name, entity_type, description, source_chunk_id
964
- FROM LIGHTRAG_GRAPH_NODES
965
- WHERE workspace = :workspace
966
- ORDER BY id
967
- FETCH FIRST :limit ROWS ONLY
968
- """
969
- nodes_params = {
970
- "workspace": self.db.workspace,
971
- "limit": max_graph_nodes,
972
- }
973
- nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
974
- else:
975
- # For specific label, find matching nodes and related nodes
976
- nodes_sql = """
977
- WITH matching_nodes AS (
978
- SELECT name
979
- FROM LIGHTRAG_GRAPH_NODES
980
- WHERE workspace = :workspace
981
- AND (name LIKE '%' || :node_label || '%' OR entity_type LIKE '%' || :node_label || '%')
982
- )
983
- SELECT n.name, n.entity_type, n.description, n.source_chunk_id,
984
- CASE
985
- WHEN n.name IN (SELECT name FROM matching_nodes) THEN 2
986
- WHEN EXISTS (
987
- SELECT 1 FROM LIGHTRAG_GRAPH_EDGES e
988
- WHERE workspace = :workspace
989
- AND ((e.source_name = n.name AND e.target_name IN (SELECT name FROM matching_nodes))
990
- OR (e.target_name = n.name AND e.source_name IN (SELECT name FROM matching_nodes)))
991
- ) THEN 1
992
- ELSE 0
993
- END AS priority,
994
- (SELECT COUNT(*) FROM LIGHTRAG_GRAPH_EDGES e
995
- WHERE workspace = :workspace
996
- AND (e.source_name = n.name OR e.target_name = n.name)) AS degree
997
- FROM LIGHTRAG_GRAPH_NODES n
998
- WHERE workspace = :workspace
999
- ORDER BY priority DESC, degree DESC
1000
- FETCH FIRST :limit ROWS ONLY
1001
- """
1002
- nodes_params = {
1003
- "workspace": self.db.workspace,
1004
- "node_label": node_label,
1005
- "limit": max_graph_nodes,
1006
- }
1007
- nodes = await self.db.query(nodes_sql, nodes_params, multirows=True)
1008
-
1009
- if not nodes:
1010
- logger.warning(f"No nodes found matching '{node_label}'")
1011
- return result
1012
-
1013
- # Create mapping of node IDs to be used to filter edges
1014
- node_names = [node["name"] for node in nodes]
1015
-
1016
- # Add nodes to result
1017
- seen_nodes = set()
1018
- for node in nodes:
1019
- node_id = node["name"]
1020
- if node_id in seen_nodes:
1021
- continue
1022
-
1023
- # Create node properties dictionary
1024
- properties = {
1025
- "entity_type": node["entity_type"],
1026
- "description": node["description"] or "",
1027
- "source_id": node["source_chunk_id"] or "",
1028
- }
1029
-
1030
- # Add node to result
1031
- result.nodes.append(
1032
- KnowledgeGraphNode(
1033
- id=node_id, labels=[node["entity_type"]], properties=properties
1034
- )
1035
- )
1036
- seen_nodes.add(node_id)
1037
-
1038
- # Get edges between these nodes
1039
- edges_sql = """
1040
- SELECT source_name, target_name, weight, keywords, description, source_chunk_id
1041
- FROM LIGHTRAG_GRAPH_EDGES
1042
- WHERE workspace = :workspace
1043
- AND source_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST)))
1044
- AND target_name IN (SELECT COLUMN_VALUE FROM TABLE(CAST(:node_names AS SYS.ODCIVARCHAR2LIST)))
1045
- ORDER BY id
1046
- """
1047
- edges_params = {"workspace": self.db.workspace, "node_names": node_names}
1048
- edges = await self.db.query(edges_sql, edges_params, multirows=True)
1049
-
1050
- # Add edges to result
1051
- seen_edges = set()
1052
- for edge in edges:
1053
- source = edge["source_name"]
1054
- target = edge["target_name"]
1055
- edge_id = f"{source}-{target}"
1056
-
1057
- if edge_id in seen_edges:
1058
- continue
1059
-
1060
- # Create edge properties dictionary
1061
- properties = {
1062
- "weight": edge["weight"] or 0.0,
1063
- "keywords": edge["keywords"] or "",
1064
- "description": edge["description"] or "",
1065
- "source_id": edge["source_chunk_id"] or "",
1066
- }
1067
-
1068
- # Add edge to result
1069
- result.edges.append(
1070
- KnowledgeGraphEdge(
1071
- id=edge_id,
1072
- type="RELATED",
1073
- source=source,
1074
- target=target,
1075
- properties=properties,
1076
- )
1077
- )
1078
- seen_edges.add(edge_id)
1079
-
1080
- logger.info(
1081
- f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
1082
- )
1083
-
1084
- except Exception as e:
1085
- logger.error(f"Error retrieving knowledge graph: {e}")
1086
-
1087
- return result
1088
-
1089
-
1090
- N_T = {
1091
- NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
1092
- NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
1093
- NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
1094
- NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES",
1095
- NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
1096
- }
1097
-
1098
-
1099
- def namespace_to_table_name(namespace: str) -> str:
1100
- for k, v in N_T.items():
1101
- if is_namespace(namespace, k):
1102
- return v
1103
-
1104
-
1105
- TABLES = {
1106
- "LIGHTRAG_DOC_FULL": {
1107
- "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
1108
- id varchar(256),
1109
- workspace varchar(1024),
1110
- doc_name varchar(1024),
1111
- content CLOB,
1112
- meta JSON,
1113
- content_summary varchar(1024),
1114
- content_length NUMBER,
1115
- status varchar(256),
1116
- chunks_count NUMBER,
1117
- createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
1118
- updatetime TIMESTAMP DEFAULT NULL,
1119
- error varchar(4096)
1120
- )"""
1121
- },
1122
- "LIGHTRAG_DOC_CHUNKS": {
1123
- "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
1124
- id varchar(256),
1125
- workspace varchar(1024),
1126
- full_doc_id varchar(256),
1127
- status varchar(256),
1128
- chunk_order_index NUMBER,
1129
- tokens NUMBER,
1130
- content CLOB,
1131
- content_vector VECTOR,
1132
- createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
1133
- updatetime TIMESTAMP DEFAULT NULL
1134
- )"""
1135
- },
1136
- "LIGHTRAG_GRAPH_NODES": {
1137
- "ddl": """CREATE TABLE LIGHTRAG_GRAPH_NODES (
1138
- id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
1139
- workspace varchar(1024),
1140
- name varchar(2048),
1141
- entity_type varchar(1024),
1142
- description CLOB,
1143
- source_chunk_id varchar(256),
1144
- content CLOB,
1145
- content_vector VECTOR,
1146
- createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
1147
- updatetime TIMESTAMP DEFAULT NULL
1148
- )"""
1149
- },
1150
- "LIGHTRAG_GRAPH_EDGES": {
1151
- "ddl": """CREATE TABLE LIGHTRAG_GRAPH_EDGES (
1152
- id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
1153
- workspace varchar(1024),
1154
- source_name varchar(2048),
1155
- target_name varchar(2048),
1156
- weight NUMBER,
1157
- keywords CLOB,
1158
- description CLOB,
1159
- source_chunk_id varchar(256),
1160
- content CLOB,
1161
- content_vector VECTOR,
1162
- createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
1163
- updatetime TIMESTAMP DEFAULT NULL
1164
- )"""
1165
- },
1166
- "LIGHTRAG_LLM_CACHE": {
1167
- "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
1168
- id varchar(256) PRIMARY KEY,
1169
- workspace varchar(1024),
1170
- cache_mode varchar(256),
1171
- model_name varchar(256),
1172
- original_prompt clob,
1173
- return_value clob,
1174
- embedding CLOB,
1175
- embedding_shape NUMBER,
1176
- embedding_min NUMBER,
1177
- embedding_max NUMBER,
1178
- createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
1179
- updatetime TIMESTAMP DEFAULT NULL
1180
- )"""
1181
- },
1182
- "LIGHTRAG_GRAPH": {
1183
- "ddl": """CREATE OR REPLACE PROPERTY GRAPH lightrag_graph
1184
- VERTEX TABLES (
1185
- lightrag_graph_nodes KEY (id)
1186
- LABEL entity
1187
- PROPERTIES (id,workspace,name) -- ,entity_type,description,source_chunk_id)
1188
- )
1189
- EDGE TABLES (
1190
- lightrag_graph_edges KEY (id)
1191
- SOURCE KEY (source_name) REFERENCES lightrag_graph_nodes(name)
1192
- DESTINATION KEY (target_name) REFERENCES lightrag_graph_nodes(name)
1193
- LABEL has_relation
1194
- PROPERTIES (id,workspace,source_name,target_name) -- ,weight, keywords,description,source_chunk_id)
1195
- ) OPTIONS(ALLOW MIXED PROPERTY TYPES)"""
1196
- },
1197
- }
1198
-
1199
-
1200
- SQL_TEMPLATES = {
1201
- # SQL for KVStorage
1202
- "get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
1203
- "get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
1204
- "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
1205
- FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""",
1206
- "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
1207
- FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""",
1208
- "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
1209
- FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""",
1210
- "get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})",
1211
- "get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
1212
- "get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})",
1213
- "get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})",
1214
- "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
1215
- "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
1216
- "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
1217
- "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
1218
- USING DUAL
1219
- ON (a.id = :id and a.workspace = :workspace)
1220
- WHEN NOT MATCHED THEN
1221
- INSERT(id,content,workspace) values(:id,:content,:workspace)""",
1222
- "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS
1223
- USING DUAL
1224
- ON (id = :id and workspace = :workspace)
1225
- WHEN NOT MATCHED THEN INSERT
1226
- (id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status)
1227
- values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """,
1228
- "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a
1229
- USING DUAL
1230
- ON (a.id = :id)
1231
- WHEN NOT MATCHED THEN
1232
- INSERT (workspace,id,original_prompt,return_value,cache_mode)
1233
- VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode)
1234
- WHEN MATCHED THEN UPDATE
1235
- SET original_prompt = :original_prompt,
1236
- return_value = :return_value,
1237
- cache_mode = :cache_mode,
1238
- updatetime = SYSDATE""",
1239
- # SQL for VectorStorage
1240
- "entities": """SELECT name as entity_name FROM
1241
- (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
1242
- FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace)
1243
- WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
1244
- "relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
1245
- (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
1246
- FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace)
1247
- WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
1248
- "chunks": """SELECT id FROM
1249
- (SELECT id,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
1250
- FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace)
1251
- WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
1252
- # SQL for GraphStorage
1253
- "has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
1254
- MATCH (a)
1255
- WHERE a.workspace=:workspace AND a.name=:node_id
1256
- COLUMNS (a.name))""",
1257
- "has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
1258
- MATCH (a) -[e]-> (b)
1259
- WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
1260
- AND a.name=:source_node_id AND b.name=:target_node_id
1261
- COLUMNS (e.source_name,e.target_name) )""",
1262
- "node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
1263
- MATCH (a)-[e]->(b)
1264
- WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
1265
- AND a.name=:node_id or b.name = :node_id
1266
- COLUMNS (a.name))""",
1267
- "get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
1268
- FROM GRAPH_TABLE (lightrag_graph
1269
- MATCH (a)
1270
- WHERE a.workspace=:workspace AND a.name=:node_id
1271
- COLUMNS (a.name)
1272
- ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
1273
- WHERE t2.workspace=:workspace""",
1274
- "get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
1275
- NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
1276
- FROM GRAPH_TABLE (lightrag_graph
1277
- MATCH (a)-[e]->(b)
1278
- WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
1279
- AND a.name=:source_node_id and b.name = :target_node_id
1280
- COLUMNS (e.id,a.name as source_id)
1281
- ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
1282
- "get_node_edges": """SELECT source_name,target_name
1283
- FROM GRAPH_TABLE (lightrag_graph
1284
- MATCH (a)-[e]->(b)
1285
- WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
1286
- AND a.name=:source_node_id
1287
- COLUMNS (a.name as source_name,b.name as target_name))""",
1288
- "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
1289
- USING DUAL
1290
- ON (a.workspace=:workspace and a.name=:name)
1291
- WHEN NOT MATCHED THEN
1292
- INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
1293
- values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector)
1294
- WHEN MATCHED THEN
1295
- UPDATE SET
1296
- entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
1297
- "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
1298
- USING DUAL
1299
- ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name)
1300
- WHEN NOT MATCHED THEN
1301
- INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
1302
- values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector)
1303
- WHEN MATCHED THEN
1304
- UPDATE SET
1305
- weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
1306
- "get_all_nodes": """WITH t0 AS (
1307
- SELECT name AS id, entity_type AS label, entity_type, description,
1308
- '["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
1309
- FROM lightrag_graph_nodes
1310
- WHERE workspace = :workspace
1311
- ORDER BY createtime DESC fetch first :limit rows only
1312
- ), t1 AS (
1313
- SELECT t0.id, source_chunk_id
1314
- FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) )
1315
- ), t2 AS (
1316
- SELECT t1.id, LISTAGG(t2.content, '\n') content
1317
- FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id
1318
- GROUP BY t1.id
1319
- )
1320
- SELECT t0.id, label, entity_type, description, t2.content
1321
- FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
1322
- "get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
1323
- t1.weight,t1.DESCRIPTION,t2.content
1324
- FROM LIGHTRAG_GRAPH_EDGES t1
1325
- LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
1326
- WHERE t1.workspace=:workspace
1327
- order by t1.CREATETIME DESC
1328
- fetch first :limit rows only""",
1329
- "get_statistics": """select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
1330
- count(distinct CASE WHEN type='edge' THEN id END) as edges_count
1331
- FROM (
1332
- select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
1333
- MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
1334
- UNION
1335
- select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
1336
- MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
1337
- )""",
1338
- # SQL for deletion
1339
- "delete_vectors": "DELETE FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace AND id IN ({ids})",
1340
- "delete_entity": "DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace AND name=:entity_name",
1341
- "delete_entity_relations": "DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace AND (source_name=:entity_name OR target_name=:entity_name)",
1342
- "delete_node": """DELETE FROM GRAPH_TABLE (lightrag_graph
1343
- MATCH (a)
1344
- WHERE a.workspace=:workspace AND a.name=:node_id
1345
- ACTION DELETE a)""",
1346
- }
 
lightrag/kg/postgres_impl.py CHANGED
@@ -9,7 +9,6 @@ import configparser
9
 
10
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
11
 
12
- import sys
13
  from tenacity import (
14
  retry,
15
  retry_if_exception_type,
@@ -28,11 +27,6 @@ from ..base import (
28
  from ..namespace import NameSpace, is_namespace
29
  from ..utils import logger
30
 
31
- if sys.platform.startswith("win"):
32
- import asyncio.windows_events
33
-
34
- asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
35
-
36
  import pipmaster as pm
37
 
38
  if not pm.is_installed("asyncpg"):
@@ -41,6 +35,9 @@ if not pm.is_installed("asyncpg"):
41
  import asyncpg # type: ignore
42
  from asyncpg import Pool # type: ignore
43
 
 
 
 
44
 
45
  class PostgreSQLDB:
46
  def __init__(self, config: dict[str, Any], **kwargs: Any):
@@ -118,6 +115,25 @@ class PostgreSQLDB:
118
  )
119
  raise e
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  async def query(
122
  self,
123
  sql: str,
@@ -254,8 +270,6 @@ class PGKVStorage(BaseKVStorage):
254
  db: PostgreSQLDB = field(default=None)
255
 
256
  def __post_init__(self):
257
- namespace_prefix = self.global_config.get("namespace_prefix")
258
- self.base_namespace = self.namespace.replace(namespace_prefix, "")
259
  self._max_batch_size = self.global_config["embedding_batch_num"]
260
 
261
  async def initialize(self):
@@ -271,7 +285,7 @@ class PGKVStorage(BaseKVStorage):
271
 
272
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
273
  """Get doc_full data by id."""
274
- sql = SQL_TEMPLATES["get_by_id_" + self.base_namespace]
275
  params = {"workspace": self.db.workspace, "id": id}
276
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
277
  array_res = await self.db.query(sql, params, multirows=True)
@@ -285,7 +299,7 @@ class PGKVStorage(BaseKVStorage):
285
 
286
  async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
287
  """Specifically for llm_response_cache."""
288
- sql = SQL_TEMPLATES["get_by_mode_id_" + self.base_namespace]
289
  params = {"workspace": self.db.workspace, mode: mode, "id": id}
290
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
291
  array_res = await self.db.query(sql, params, multirows=True)
@@ -299,7 +313,7 @@ class PGKVStorage(BaseKVStorage):
299
  # Query by id
300
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
301
  """Get doc_chunks data by id"""
302
- sql = SQL_TEMPLATES["get_by_ids_" + self.base_namespace].format(
303
  ids=",".join([f"'{id}'" for id in ids])
304
  )
305
  params = {"workspace": self.db.workspace}
@@ -320,7 +334,7 @@ class PGKVStorage(BaseKVStorage):
320
 
321
  async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
322
  """Specifically for llm_response_cache."""
323
- SQL = SQL_TEMPLATES["get_by_status_" + self.base_namespace]
324
  params = {"workspace": self.db.workspace, "status": status}
325
  return await self.db.query(SQL, params, multirows=True)
326
 
@@ -380,10 +394,85 @@ class PGKVStorage(BaseKVStorage):
380
  # PG handles persistence automatically
381
  pass
382
 
383
- async def drop(self) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  """Drop the storage"""
385
- drop_sql = SQL_TEMPLATES["drop_all"]
386
- await self.db.execute(drop_sql)
 
 
 
 
 
 
 
 
 
 
 
 
 
387
 
388
 
389
  @final
@@ -393,8 +482,6 @@ class PGVectorStorage(BaseVectorStorage):
393
 
394
  def __post_init__(self):
395
  self._max_batch_size = self.global_config["embedding_batch_num"]
396
- namespace_prefix = self.global_config.get("namespace_prefix")
397
- self.base_namespace = self.namespace.replace(namespace_prefix, "")
398
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
399
  cosine_threshold = config.get("cosine_better_than_threshold")
400
  if cosine_threshold is None:
@@ -523,7 +610,7 @@ class PGVectorStorage(BaseVectorStorage):
523
  else:
524
  formatted_ids = "NULL"
525
 
526
- sql = SQL_TEMPLATES[self.base_namespace].format(
527
  embedding_string=embedding_string, doc_ids=formatted_ids
528
  )
529
  params = {
@@ -552,13 +639,12 @@ class PGVectorStorage(BaseVectorStorage):
552
  logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
553
  return
554
 
555
- ids_list = ",".join([f"'{id}'" for id in ids])
556
- delete_sql = (
557
- f"DELETE FROM {table_name} WHERE workspace=$1 AND id IN ({ids_list})"
558
- )
559
 
560
  try:
561
- await self.db.execute(delete_sql, {"workspace": self.db.workspace})
 
 
562
  logger.debug(
563
  f"Successfully deleted {len(ids)} vectors from {self.namespace}"
564
  )
@@ -690,6 +776,24 @@ class PGVectorStorage(BaseVectorStorage):
690
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
691
  return []
692
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
693
 
694
  @final
695
  @dataclass
@@ -810,6 +914,35 @@ class PGDocStatusStorage(DocStatusStorage):
810
  # PG handles persistence automatically
811
  pass
812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
813
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
814
  """Update or insert document status
815
 
@@ -846,10 +979,23 @@ class PGDocStatusStorage(DocStatusStorage):
846
  },
847
  )
848
 
849
- async def drop(self) -> None:
850
  """Drop the storage"""
851
- drop_sql = SQL_TEMPLATES["drop_doc_full"]
852
- await self.db.execute(drop_sql)
 
 
 
 
 
 
 
 
 
 
 
 
 
853
 
854
 
855
  class PGGraphQueryException(Exception):
@@ -937,31 +1083,11 @@ class PGGraphStorage(BaseGraphStorage):
937
  if v.startswith("[") and v.endswith("]"):
938
  if "::vertex" in v:
939
  v = v.replace("::vertex", "")
940
- vertexes = json.loads(v)
941
- dl = []
942
- for vertex in vertexes:
943
- prop = vertex.get("properties")
944
- if not prop:
945
- prop = {}
946
- prop["label"] = PGGraphStorage._decode_graph_label(
947
- prop["node_id"]
948
- )
949
- dl.append(prop)
950
- d[k] = dl
951
 
952
  elif "::edge" in v:
953
  v = v.replace("::edge", "")
954
- edges = json.loads(v)
955
- dl = []
956
- for edge in edges:
957
- dl.append(
958
- (
959
- vertices[edge["start_id"]],
960
- edge["label"],
961
- vertices[edge["end_id"]],
962
- )
963
- )
964
- d[k] = dl
965
  else:
966
  print("WARNING: unsupported type")
967
  continue
@@ -970,32 +1096,19 @@ class PGGraphStorage(BaseGraphStorage):
970
  dtype = v.split("::")[-1]
971
  v = v.split("::")[0]
972
  if dtype == "vertex":
973
- vertex = json.loads(v)
974
- field = vertex.get("properties")
975
- if not field:
976
- field = {}
977
- field["label"] = PGGraphStorage._decode_graph_label(
978
- field["node_id"]
979
- )
980
- d[k] = field
981
- # convert edge from id-label->id by replacing id with node information
982
- # we only do this if the vertex was also returned in the query
983
- # this is an attempt to be consistent with neo4j implementation
984
  elif dtype == "edge":
985
- edge = json.loads(v)
986
- d[k] = (
987
- vertices.get(edge["start_id"], {}),
988
- edge[
989
- "label"
990
- ], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
991
- vertices.get(edge["end_id"], {}),
992
- )
993
  else:
994
- d[k] = (
995
- json.loads(v)
996
- if isinstance(v, str) and ("{" in v or "[" in v)
997
- else v
998
- )
 
 
 
 
999
 
1000
  return d
1001
 
@@ -1025,56 +1138,6 @@ class PGGraphStorage(BaseGraphStorage):
1025
  )
1026
  return "{" + ", ".join(props) + "}"
1027
 
1028
- @staticmethod
1029
- def _encode_graph_label(label: str) -> str:
1030
- """
1031
- Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
1032
-
1033
- Args:
1034
- label (str): the original label
1035
-
1036
- Returns:
1037
- str: the encoded label
1038
- """
1039
- return "x" + label.encode().hex()
1040
-
1041
- @staticmethod
1042
- def _decode_graph_label(encoded_label: str) -> str:
1043
- """
1044
- Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
1045
-
1046
- Args:
1047
- encoded_label (str): the encoded label
1048
-
1049
- Returns:
1050
- str: the decoded label
1051
- """
1052
- return bytes.fromhex(encoded_label.removeprefix("x")).decode()
1053
-
1054
- @staticmethod
1055
- def _get_col_name(field: str, idx: int) -> str:
1056
- """
1057
- Convert a cypher return field to a pgsql select field
1058
- If possible keep the cypher column name, but create a generic name if necessary
1059
-
1060
- Args:
1061
- field (str): a return field from a cypher query to be formatted for pgsql
1062
- idx (int): the position of the field in the return statement
1063
-
1064
- Returns:
1065
- str: the field to be used in the pgsql select statement
1066
- """
1067
- # remove white space
1068
- field = field.strip()
1069
- # if an alias is provided for the field, use it
1070
- if " as " in field:
1071
- return field.split(" as ")[-1].strip()
1072
- # if the return value is an unnamed primitive, give it a generic name
1073
- if field.isnumeric() or field in ("true", "false", "null"):
1074
- return f"column_{idx}"
1075
- # otherwise return the value stripping out some common special chars
1076
- return field.replace("(", "_").replace(")", "")
1077
-
1078
  async def _query(
1079
  self,
1080
  query: str,
@@ -1125,10 +1188,10 @@ class PGGraphStorage(BaseGraphStorage):
1125
  return result
1126
 
1127
  async def has_node(self, node_id: str) -> bool:
1128
- entity_name_label = self._encode_graph_label(node_id.strip('"'))
1129
 
1130
  query = """SELECT * FROM cypher('%s', $$
1131
- MATCH (n:Entity {node_id: "%s"})
1132
  RETURN count(n) > 0 AS node_exists
1133
  $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
1134
 
@@ -1137,11 +1200,11 @@ class PGGraphStorage(BaseGraphStorage):
1137
  return single_result["node_exists"]
1138
 
1139
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
1140
- src_label = self._encode_graph_label(source_node_id.strip('"'))
1141
- tgt_label = self._encode_graph_label(target_node_id.strip('"'))
1142
 
1143
  query = """SELECT * FROM cypher('%s', $$
1144
- MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"})
1145
  RETURN COUNT(r) > 0 AS edge_exists
1146
  $$) AS (edge_exists bool)""" % (
1147
  self.graph_name,
@@ -1154,30 +1217,31 @@ class PGGraphStorage(BaseGraphStorage):
1154
  return single_result["edge_exists"]
1155
 
1156
  async def get_node(self, node_id: str) -> dict[str, str] | None:
1157
- label = self._encode_graph_label(node_id.strip('"'))
 
 
1158
  query = """SELECT * FROM cypher('%s', $$
1159
- MATCH (n:Entity {node_id: "%s"})
1160
  RETURN n
1161
  $$) AS (n agtype)""" % (self.graph_name, label)
1162
  record = await self._query(query)
1163
  if record:
1164
  node = record[0]
1165
- node_dict = node["n"]
1166
 
1167
  return node_dict
1168
  return None
1169
 
1170
  async def node_degree(self, node_id: str) -> int:
1171
- label = self._encode_graph_label(node_id.strip('"'))
1172
 
1173
  query = """SELECT * FROM cypher('%s', $$
1174
- MATCH (n:Entity {node_id: "%s"})-[]->(x)
1175
  RETURN count(x) AS total_edge_count
1176
  $$) AS (total_edge_count integer)""" % (self.graph_name, label)
1177
  record = (await self._query(query))[0]
1178
  if record:
1179
  edge_count = int(record["total_edge_count"])
1180
-
1181
  return edge_count
1182
 
1183
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@@ -1195,11 +1259,13 @@ class PGGraphStorage(BaseGraphStorage):
1195
  async def get_edge(
1196
  self, source_node_id: str, target_node_id: str
1197
  ) -> dict[str, str] | None:
1198
- src_label = self._encode_graph_label(source_node_id.strip('"'))
1199
- tgt_label = self._encode_graph_label(target_node_id.strip('"'))
 
 
1200
 
1201
  query = """SELECT * FROM cypher('%s', $$
1202
- MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"})
1203
  RETURN properties(r) as edge_properties
1204
  LIMIT 1
1205
  $$) AS (edge_properties agtype)""" % (
@@ -1218,11 +1284,11 @@ class PGGraphStorage(BaseGraphStorage):
1218
  Retrieves all edges (relationships) for a particular node identified by its label.
1219
  :return: list of dictionaries containing edge information
1220
  """
1221
- label = self._encode_graph_label(source_node_id.strip('"'))
1222
 
1223
  query = """SELECT * FROM cypher('%s', $$
1224
- MATCH (n:Entity {node_id: "%s"})
1225
- OPTIONAL MATCH (n)-[]-(connected)
1226
  RETURN n, connected
1227
  $$) AS (n agtype, connected agtype)""" % (
1228
  self.graph_name,
@@ -1235,24 +1301,17 @@ class PGGraphStorage(BaseGraphStorage):
1235
  source_node = record["n"] if record["n"] else None
1236
  connected_node = record["connected"] if record["connected"] else None
1237
 
1238
- source_label = (
1239
- source_node["node_id"]
1240
- if source_node and source_node["node_id"]
1241
- else None
1242
- )
1243
- target_label = (
1244
- connected_node["node_id"]
1245
- if connected_node and connected_node["node_id"]
1246
- else None
1247
- )
1248
 
1249
- if source_label and target_label:
1250
- edges.append(
1251
- (
1252
- self._decode_graph_label(source_label),
1253
- self._decode_graph_label(target_label),
1254
- )
1255
- )
1256
 
1257
  return edges
1258
 
@@ -1262,24 +1321,36 @@ class PGGraphStorage(BaseGraphStorage):
1262
  retry=retry_if_exception_type((PGGraphQueryException,)),
1263
  )
1264
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
1265
- label = self._encode_graph_label(node_id.strip('"'))
1266
- properties = node_data
 
 
 
 
 
 
 
 
 
 
 
 
1267
 
1268
  query = """SELECT * FROM cypher('%s', $$
1269
- MERGE (n:Entity {node_id: "%s"})
1270
  SET n += %s
1271
  RETURN n
1272
  $$) AS (n agtype)""" % (
1273
  self.graph_name,
1274
  label,
1275
- self._format_properties(properties),
1276
  )
1277
 
1278
  try:
1279
  await self._query(query, readonly=False, upsert=True)
1280
 
1281
- except Exception as e:
1282
- logger.error("POSTGRES, Error during upsert: {%s}", e)
1283
  raise
1284
 
1285
  @retry(
@@ -1298,14 +1369,14 @@ class PGGraphStorage(BaseGraphStorage):
1298
  target_node_id (str): Label of the target node (used as identifier)
1299
  edge_data (dict): dictionary of properties to set on the edge
1300
  """
1301
- src_label = self._encode_graph_label(source_node_id.strip('"'))
1302
- tgt_label = self._encode_graph_label(target_node_id.strip('"'))
1303
- edge_properties = edge_data
1304
 
1305
  query = """SELECT * FROM cypher('%s', $$
1306
- MATCH (source:Entity {node_id: "%s"})
1307
  WITH source
1308
- MATCH (target:Entity {node_id: "%s"})
1309
  MERGE (source)-[r:DIRECTED]->(target)
1310
  SET r += %s
1311
  RETURN r
@@ -1313,14 +1384,16 @@ class PGGraphStorage(BaseGraphStorage):
1313
  self.graph_name,
1314
  src_label,
1315
  tgt_label,
1316
- self._format_properties(edge_properties),
1317
  )
1318
 
1319
  try:
1320
  await self._query(query, readonly=False, upsert=True)
1321
 
1322
- except Exception as e:
1323
- logger.error("Error during edge upsert: {%s}", e)
 
 
1324
  raise
1325
 
1326
  async def _node2vec_embed(self):
@@ -1333,10 +1406,10 @@ class PGGraphStorage(BaseGraphStorage):
1333
  Args:
1334
  node_id (str): The ID of the node to delete.
1335
  """
1336
- label = self._encode_graph_label(node_id.strip('"'))
1337
 
1338
  query = """SELECT * FROM cypher('%s', $$
1339
- MATCH (n:Entity {node_id: "%s"})
1340
  DETACH DELETE n
1341
  $$) AS (n agtype)""" % (self.graph_name, label)
1342
 
@@ -1353,14 +1426,12 @@ class PGGraphStorage(BaseGraphStorage):
1353
  Args:
1354
  node_ids (list[str]): A list of node IDs to remove.
1355
  """
1356
- encoded_node_ids = [
1357
- self._encode_graph_label(node_id.strip('"')) for node_id in node_ids
1358
- ]
1359
- node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
1360
 
1361
  query = """SELECT * FROM cypher('%s', $$
1362
- MATCH (n:Entity)
1363
- WHERE n.node_id IN [%s]
1364
  DETACH DELETE n
1365
  $$) AS (n agtype)""" % (self.graph_name, node_id_list)
1366
 
@@ -1377,26 +1448,21 @@ class PGGraphStorage(BaseGraphStorage):
1377
  Args:
1378
  edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
1379
  """
1380
- encoded_edges = [
1381
- (
1382
- self._encode_graph_label(src.strip('"')),
1383
- self._encode_graph_label(tgt.strip('"')),
1384
- )
1385
- for src, tgt in edges
1386
- ]
1387
- edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges])
1388
 
1389
- query = """SELECT * FROM cypher('%s', $$
1390
- MATCH (a:Entity)-[r]->(b:Entity)
1391
- WHERE [a.node_id, b.node_id] IN [%s]
1392
- DELETE r
1393
- $$) AS (r agtype)""" % (self.graph_name, edge_list)
1394
 
1395
- try:
1396
- await self._query(query, readonly=False)
1397
- except Exception as e:
1398
- logger.error("Error during edge removal: {%s}", e)
1399
- raise
 
1400
 
1401
  async def get_all_labels(self) -> list[str]:
1402
  """
@@ -1407,15 +1473,16 @@ class PGGraphStorage(BaseGraphStorage):
1407
  """
1408
  query = (
1409
  """SELECT * FROM cypher('%s', $$
1410
- MATCH (n:Entity)
1411
- RETURN DISTINCT n.node_id AS label
 
 
1412
  $$) AS (label text)"""
1413
  % self.graph_name
1414
  )
1415
 
1416
  results = await self._query(query)
1417
- labels = [self._decode_graph_label(result["label"]) for result in results]
1418
-
1419
  return labels
1420
 
1421
  async def embed_nodes(
@@ -1437,105 +1504,135 @@ class PGGraphStorage(BaseGraphStorage):
1437
  return await embed_func()
1438
 
1439
  async def get_knowledge_graph(
1440
- self, node_label: str, max_depth: int = 5
 
 
 
1441
  ) -> KnowledgeGraph:
1442
  """
1443
- Retrieve a subgraph containing the specified node and its neighbors up to the specified depth.
1444
 
1445
  Args:
1446
- node_label (str): The label of the node to start from. If "*", the entire graph is returned.
1447
- max_depth (int): The maximum depth to traverse from the starting node.
 
1448
 
1449
  Returns:
1450
- KnowledgeGraph: The retrieved subgraph.
 
1451
  """
1452
- MAX_GRAPH_NODES = 1000
1453
-
1454
- # Build the query based on whether we want the full graph or a specific subgraph.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1455
  if node_label == "*":
1456
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1457
- MATCH (n:Entity)
1458
- OPTIONAL MATCH (n)-[r]->(m:Entity)
1459
- RETURN n, r, m
1460
- LIMIT {MAX_GRAPH_NODES}
1461
- $$) AS (n agtype, r agtype, m agtype)"""
1462
  else:
1463
- encoded_label = self._encode_graph_label(node_label.strip('"'))
1464
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1465
- MATCH (n:Entity {{node_id: "{encoded_label}"}})
1466
- OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
1467
- RETURN nodes(p) AS nodes, relationships(p) AS relationships
1468
- LIMIT {MAX_GRAPH_NODES}
1469
- $$) AS (nodes agtype, relationships agtype)"""
1470
 
1471
  results = await self._query(query)
1472
 
1473
- nodes = {}
1474
- edges = []
1475
- unique_edge_ids = set()
1476
-
1477
- def add_node(node_data: dict):
1478
- node_id = self._decode_graph_label(node_data["node_id"])
1479
- if node_id not in nodes:
1480
- nodes[node_id] = node_data
1481
-
1482
- def add_edge(edge_data: list):
1483
- src_id = self._decode_graph_label(edge_data[0]["node_id"])
1484
- tgt_id = self._decode_graph_label(edge_data[2]["node_id"])
1485
- edge_key = f"{src_id},{tgt_id}"
1486
- if edge_key not in unique_edge_ids:
1487
- unique_edge_ids.add(edge_key)
1488
- edges.append(
1489
- (
1490
- edge_key,
1491
- src_id,
1492
- tgt_id,
1493
- {"source": edge_data[0], "target": edge_data[2]},
1494
  )
1495
- )
 
 
 
 
 
 
 
 
 
 
1496
 
1497
- # Process the query results.
1498
- if node_label == "*":
1499
- for result in results:
1500
- if result.get("n"):
1501
- add_node(result["n"])
1502
- if result.get("m"):
1503
- add_node(result["m"])
1504
- if result.get("r"):
1505
- add_edge(result["r"])
1506
- else:
1507
- for result in results:
1508
- for node in result.get("nodes", []):
1509
- add_node(node)
1510
- for edge in result.get("relationships", []):
1511
- add_edge(edge)
 
 
 
 
 
 
 
 
 
1512
 
1513
- # Construct and return the KnowledgeGraph.
1514
  kg = KnowledgeGraph(
1515
- nodes=[
1516
- KnowledgeGraphNode(id=node_id, labels=[node_id], properties=node_data)
1517
- for node_id, node_data in nodes.items()
1518
- ],
1519
- edges=[
1520
- KnowledgeGraphEdge(
1521
- id=edge_id,
1522
- type="DIRECTED",
1523
- source=src,
1524
- target=tgt,
1525
- properties=props,
1526
- )
1527
- for edge_id, src, tgt, props in edges
1528
- ],
1529
  )
1530
 
 
 
 
1531
  return kg
1532
 
1533
- async def drop(self) -> None:
1534
  """Drop the storage"""
1535
- drop_sql = SQL_TEMPLATES["drop_vdb_entity"]
1536
- await self.db.execute(drop_sql)
1537
- drop_sql = SQL_TEMPLATES["drop_vdb_relation"]
1538
- await self.db.execute(drop_sql)
 
 
 
 
 
 
 
1539
 
1540
 
1541
  NAMESPACE_TABLE_MAP = {
@@ -1648,7 +1745,7 @@ SQL_TEMPLATES = {
1648
  FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2
1649
  """,
1650
  "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
1651
- chunk_order_index, full_doc_id
1652
  FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
1653
  """,
1654
  "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
@@ -1661,7 +1758,7 @@ SQL_TEMPLATES = {
1661
  FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
1662
  """,
1663
  "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
1664
- chunk_order_index, full_doc_id
1665
  FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
1666
  """,
1667
  "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
@@ -1693,6 +1790,7 @@ SQL_TEMPLATES = {
1693
  file_path=EXCLUDED.file_path,
1694
  update_time = CURRENT_TIMESTAMP
1695
  """,
 
1696
  "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
1697
  content_vector, chunk_ids, file_path)
1698
  VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7)
@@ -1716,45 +1814,6 @@ SQL_TEMPLATES = {
1716
  file_path=EXCLUDED.file_path,
1717
  update_time = CURRENT_TIMESTAMP
1718
  """,
1719
- # SQL for VectorStorage
1720
- # "entities": """SELECT entity_name FROM
1721
- # (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1722
- # FROM LIGHTRAG_VDB_ENTITY where workspace=$1)
1723
- # WHERE distance>$2 ORDER BY distance DESC LIMIT $3
1724
- # """,
1725
- # "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
1726
- # (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1727
- # FROM LIGHTRAG_VDB_RELATION where workspace=$1)
1728
- # WHERE distance>$2 ORDER BY distance DESC LIMIT $3
1729
- # """,
1730
- # "chunks": """SELECT id FROM
1731
- # (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1732
- # FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
1733
- # WHERE distance>$2 ORDER BY distance DESC LIMIT $3
1734
- # """,
1735
- # DROP tables
1736
- "drop_all": """
1737
- DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
1738
- DROP TABLE IF EXISTS LIGHTRAG_DOC_CHUNKS CASCADE;
1739
- DROP TABLE IF EXISTS LIGHTRAG_LLM_CACHE CASCADE;
1740
- DROP TABLE IF EXISTS LIGHTRAG_VDB_ENTITY CASCADE;
1741
- DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
1742
- """,
1743
- "drop_doc_full": """
1744
- DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
1745
- """,
1746
- "drop_doc_chunks": """
1747
- DROP TABLE IF EXISTS LIGHTRAG_DOC_CHUNKS CASCADE;
1748
- """,
1749
- "drop_llm_cache": """
1750
- DROP TABLE IF EXISTS LIGHTRAG_LLM_CACHE CASCADE;
1751
- """,
1752
- "drop_vdb_entity": """
1753
- DROP TABLE IF EXISTS LIGHTRAG_VDB_ENTITY CASCADE;
1754
- """,
1755
- "drop_vdb_relation": """
1756
- DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
1757
- """,
1758
  "relationships": """
1759
  WITH relevant_chunks AS (
1760
  SELECT id as chunk_id
@@ -1795,9 +1854,9 @@ SQL_TEMPLATES = {
1795
  FROM LIGHTRAG_DOC_CHUNKS
1796
  WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
1797
  )
1798
- SELECT id FROM
1799
  (
1800
- SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1801
  FROM LIGHTRAG_DOC_CHUNKS
1802
  where workspace=$1
1803
  AND id IN (SELECT chunk_id FROM relevant_chunks)
@@ -1806,4 +1865,8 @@ SQL_TEMPLATES = {
1806
  ORDER BY distance DESC
1807
  LIMIT $3
1808
  """,
 
 
 
 
1809
  }
 
9
 
10
  from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
11
 
 
12
  from tenacity import (
13
  retry,
14
  retry_if_exception_type,
 
27
  from ..namespace import NameSpace, is_namespace
28
  from ..utils import logger
29
 
 
 
 
 
 
30
  import pipmaster as pm
31
 
32
  if not pm.is_installed("asyncpg"):
 
35
  import asyncpg # type: ignore
36
  from asyncpg import Pool # type: ignore
37
 
38
+ # Get maximum number of graph nodes from environment variable, default is 1000
39
+ MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
40
+
41
 
42
  class PostgreSQLDB:
43
  def __init__(self, config: dict[str, Any], **kwargs: Any):
 
115
  )
116
  raise e
117
 
118
+ # Create index for id column in each table
119
+ try:
120
+ index_name = f"idx_{k.lower()}_id"
121
+ check_index_sql = f"""
122
+ SELECT 1 FROM pg_indexes
123
+ WHERE indexname = '{index_name}'
124
+ AND tablename = '{k.lower()}'
125
+ """
126
+ index_exists = await self.query(check_index_sql)
127
+
128
+ if not index_exists:
129
+ create_index_sql = f"CREATE INDEX {index_name} ON {k}(id)"
130
+ logger.info(f"PostgreSQL, Creating index {index_name} on table {k}")
131
+ await self.execute(create_index_sql)
132
+ except Exception as e:
133
+ logger.error(
134
+ f"PostgreSQL, Failed to create index on table {k}, Got: {e}"
135
+ )
136
+
137
  async def query(
138
  self,
139
  sql: str,
 
270
  db: PostgreSQLDB = field(default=None)
271
 
272
  def __post_init__(self):
 
 
273
  self._max_batch_size = self.global_config["embedding_batch_num"]
274
 
275
  async def initialize(self):
 
285
 
286
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
287
  """Get doc_full data by id."""
288
+ sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
289
  params = {"workspace": self.db.workspace, "id": id}
290
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
291
  array_res = await self.db.query(sql, params, multirows=True)
 
299
 
300
  async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
301
  """Specifically for llm_response_cache."""
302
+ sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
303
  params = {"workspace": self.db.workspace, mode: mode, "id": id}
304
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
305
  array_res = await self.db.query(sql, params, multirows=True)
 
313
  # Query by id
314
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
315
  """Get doc_chunks data by id"""
316
+ sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
317
  ids=",".join([f"'{id}'" for id in ids])
318
  )
319
  params = {"workspace": self.db.workspace}
 
334
 
335
  async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
336
  """Specifically for llm_response_cache."""
337
+ SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
338
  params = {"workspace": self.db.workspace, "status": status}
339
  return await self.db.query(SQL, params, multirows=True)
340
 
 
394
  # PG handles persistence automatically
395
  pass
396
 
397
+ async def delete(self, ids: list[str]) -> None:
398
+ """Delete specific records from storage by their IDs
399
+
400
+ Args:
401
+ ids (list[str]): List of document IDs to be deleted from storage
402
+
403
+ Returns:
404
+ None
405
+ """
406
+ if not ids:
407
+ return
408
+
409
+ table_name = namespace_to_table_name(self.namespace)
410
+ if not table_name:
411
+ logger.error(f"Unknown namespace for deletion: {self.namespace}")
412
+ return
413
+
414
+ delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
415
+
416
+ try:
417
+ await self.db.execute(
418
+ delete_sql, {"workspace": self.db.workspace, "ids": ids}
419
+ )
420
+ logger.debug(
421
+ f"Successfully deleted {len(ids)} records from {self.namespace}"
422
+ )
423
+ except Exception as e:
424
+ logger.error(f"Error while deleting records from {self.namespace}: {e}")
425
+
426
+ async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
427
+ """Delete specific records from storage by cache mode
428
+
429
+ Args:
430
+ modes (list[str]): List of cache modes to be dropped from storage
431
+
432
+ Returns:
433
+ bool: True if successful, False otherwise
434
+ """
435
+ if not modes:
436
+ return False
437
+
438
+ try:
439
+ table_name = namespace_to_table_name(self.namespace)
440
+ if not table_name:
441
+ return False
442
+
443
+ if table_name != "LIGHTRAG_LLM_CACHE":
444
+ return False
445
+
446
+ sql = f"""
447
+ DELETE FROM {table_name}
448
+ WHERE workspace = $1 AND mode = ANY($2)
449
+ """
450
+ params = {"workspace": self.db.workspace, "modes": modes}
451
+
452
+ logger.info(f"Deleting cache by modes: {modes}")
453
+ await self.db.execute(sql, params)
454
+ return True
455
+ except Exception as e:
456
+ logger.error(f"Error deleting cache by modes {modes}: {e}")
457
+ return False
458
+
459
+ async def drop(self) -> dict[str, str]:
460
  """Drop the storage"""
461
+ try:
462
+ table_name = namespace_to_table_name(self.namespace)
463
+ if not table_name:
464
+ return {
465
+ "status": "error",
466
+ "message": f"Unknown namespace: {self.namespace}",
467
+ }
468
+
469
+ drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
470
+ table_name=table_name
471
+ )
472
+ await self.db.execute(drop_sql, {"workspace": self.db.workspace})
473
+ return {"status": "success", "message": "data dropped"}
474
+ except Exception as e:
475
+ return {"status": "error", "message": str(e)}
476
 
477
 
478
  @final
 
482
 
483
  def __post_init__(self):
484
  self._max_batch_size = self.global_config["embedding_batch_num"]
 
 
485
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
486
  cosine_threshold = config.get("cosine_better_than_threshold")
487
  if cosine_threshold is None:
 
610
  else:
611
  formatted_ids = "NULL"
612
 
613
+ sql = SQL_TEMPLATES[self.namespace].format(
614
  embedding_string=embedding_string, doc_ids=formatted_ids
615
  )
616
  params = {
 
639
  logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
640
  return
641
 
642
+ delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
 
 
 
643
 
644
  try:
645
+ await self.db.execute(
646
+ delete_sql, {"workspace": self.db.workspace, "ids": ids}
647
+ )
648
  logger.debug(
649
  f"Successfully deleted {len(ids)} vectors from {self.namespace}"
650
  )
 
776
  logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
777
  return []
778
 
779
+ async def drop(self) -> dict[str, str]:
780
+ """Drop the storage"""
781
+ try:
782
+ table_name = namespace_to_table_name(self.namespace)
783
+ if not table_name:
784
+ return {
785
+ "status": "error",
786
+ "message": f"Unknown namespace: {self.namespace}",
787
+ }
788
+
789
+ drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
790
+ table_name=table_name
791
+ )
792
+ await self.db.execute(drop_sql, {"workspace": self.db.workspace})
793
+ return {"status": "success", "message": "data dropped"}
794
+ except Exception as e:
795
+ return {"status": "error", "message": str(e)}
796
+
797
 
798
  @final
799
  @dataclass
 
914
  # PG handles persistence automatically
915
  pass
916
 
917
+ async def delete(self, ids: list[str]) -> None:
918
+ """Delete specific records from storage by their IDs
919
+
920
+ Args:
921
+ ids (list[str]): List of document IDs to be deleted from storage
922
+
923
+ Returns:
924
+ None
925
+ """
926
+ if not ids:
927
+ return
928
+
929
+ table_name = namespace_to_table_name(self.namespace)
930
+ if not table_name:
931
+ logger.error(f"Unknown namespace for deletion: {self.namespace}")
932
+ return
933
+
934
+ delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)"
935
+
936
+ try:
937
+ await self.db.execute(
938
+ delete_sql, {"workspace": self.db.workspace, "ids": ids}
939
+ )
940
+ logger.debug(
941
+ f"Successfully deleted {len(ids)} records from {self.namespace}"
942
+ )
943
+ except Exception as e:
944
+ logger.error(f"Error while deleting records from {self.namespace}: {e}")
945
+
946
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
947
  """Update or insert document status
948
 
 
979
  },
980
  )
981
 
982
+ async def drop(self) -> dict[str, str]:
983
  """Drop the storage"""
984
+ try:
985
+ table_name = namespace_to_table_name(self.namespace)
986
+ if not table_name:
987
+ return {
988
+ "status": "error",
989
+ "message": f"Unknown namespace: {self.namespace}",
990
+ }
991
+
992
+ drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
993
+ table_name=table_name
994
+ )
995
+ await self.db.execute(drop_sql, {"workspace": self.db.workspace})
996
+ return {"status": "success", "message": "data dropped"}
997
+ except Exception as e:
998
+ return {"status": "error", "message": str(e)}
999
 
1000
 
1001
  class PGGraphQueryException(Exception):
 
1083
  if v.startswith("[") and v.endswith("]"):
1084
  if "::vertex" in v:
1085
  v = v.replace("::vertex", "")
1086
+ d[k] = json.loads(v)
 
 
 
 
 
 
 
 
 
 
1087
 
1088
  elif "::edge" in v:
1089
  v = v.replace("::edge", "")
1090
+ d[k] = json.loads(v)
 
 
 
 
 
 
 
 
 
 
1091
  else:
1092
  print("WARNING: unsupported type")
1093
  continue
 
1096
  dtype = v.split("::")[-1]
1097
  v = v.split("::")[0]
1098
  if dtype == "vertex":
1099
+ d[k] = json.loads(v)
 
 
 
 
 
 
 
 
 
 
1100
  elif dtype == "edge":
1101
+ d[k] = json.loads(v)
 
 
 
 
 
 
 
1102
  else:
1103
+ try:
1104
+ d[k] = (
1105
+ json.loads(v)
1106
+ if isinstance(v, str)
1107
+ and (v.startswith("{") or v.startswith("["))
1108
+ else v
1109
+ )
1110
+ except json.JSONDecodeError:
1111
+ d[k] = v
1112
 
1113
  return d
1114
 
 
1138
  )
1139
  return "{" + ", ".join(props) + "}"
1140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1141
  async def _query(
1142
  self,
1143
  query: str,
 
1188
  return result
1189
 
1190
  async def has_node(self, node_id: str) -> bool:
1191
+ entity_name_label = node_id.strip('"')
1192
 
1193
  query = """SELECT * FROM cypher('%s', $$
1194
+ MATCH (n:base {entity_id: "%s"})
1195
  RETURN count(n) > 0 AS node_exists
1196
  $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
1197
 
 
1200
  return single_result["node_exists"]
1201
 
1202
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
1203
+ src_label = source_node_id.strip('"')
1204
+ tgt_label = target_node_id.strip('"')
1205
 
1206
  query = """SELECT * FROM cypher('%s', $$
1207
+ MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
1208
  RETURN COUNT(r) > 0 AS edge_exists
1209
  $$) AS (edge_exists bool)""" % (
1210
  self.graph_name,
 
1217
  return single_result["edge_exists"]
1218
 
1219
  async def get_node(self, node_id: str) -> dict[str, str] | None:
1220
+ """Get node by its label identifier, return only node properties"""
1221
+
1222
+ label = node_id.strip('"')
1223
  query = """SELECT * FROM cypher('%s', $$
1224
+ MATCH (n:base {entity_id: "%s"})
1225
  RETURN n
1226
  $$) AS (n agtype)""" % (self.graph_name, label)
1227
  record = await self._query(query)
1228
  if record:
1229
  node = record[0]
1230
+ node_dict = node["n"]["properties"]
1231
 
1232
  return node_dict
1233
  return None
1234
 
1235
  async def node_degree(self, node_id: str) -> int:
1236
+ label = node_id.strip('"')
1237
 
1238
  query = """SELECT * FROM cypher('%s', $$
1239
+ MATCH (n:base {entity_id: "%s"})-[]-(x)
1240
  RETURN count(x) AS total_edge_count
1241
  $$) AS (total_edge_count integer)""" % (self.graph_name, label)
1242
  record = (await self._query(query))[0]
1243
  if record:
1244
  edge_count = int(record["total_edge_count"])
 
1245
  return edge_count
1246
 
1247
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
 
1259
  async def get_edge(
1260
  self, source_node_id: str, target_node_id: str
1261
  ) -> dict[str, str] | None:
1262
+ """Get edge properties between two nodes"""
1263
+
1264
+ src_label = source_node_id.strip('"')
1265
+ tgt_label = target_node_id.strip('"')
1266
 
1267
  query = """SELECT * FROM cypher('%s', $$
1268
+ MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"})
1269
  RETURN properties(r) as edge_properties
1270
  LIMIT 1
1271
  $$) AS (edge_properties agtype)""" % (
 
1284
  Retrieves all edges (relationships) for a particular node identified by its label.
1285
  :return: list of dictionaries containing edge information
1286
  """
1287
+ label = source_node_id.strip('"')
1288
 
1289
  query = """SELECT * FROM cypher('%s', $$
1290
+ MATCH (n:base {entity_id: "%s"})
1291
+ OPTIONAL MATCH (n)-[]-(connected:base)
1292
  RETURN n, connected
1293
  $$) AS (n agtype, connected agtype)""" % (
1294
  self.graph_name,
 
1301
  source_node = record["n"] if record["n"] else None
1302
  connected_node = record["connected"] if record["connected"] else None
1303
 
1304
+ if (
1305
+ source_node
1306
+ and connected_node
1307
+ and "properties" in source_node
1308
+ and "properties" in connected_node
1309
+ ):
1310
+ source_label = source_node["properties"].get("entity_id")
1311
+ target_label = connected_node["properties"].get("entity_id")
 
 
1312
 
1313
+ if source_label and target_label:
1314
+ edges.append((source_label, target_label))
 
 
 
 
 
1315
 
1316
  return edges
1317
 
 
1321
  retry=retry_if_exception_type((PGGraphQueryException,)),
1322
  )
1323
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
1324
+ """
1325
+ Upsert a node in the PostgreSQL (Apache AGE) graph storage.
1326
+
1327
+ Args:
1328
+ node_id: The unique identifier for the node (stored as the entity_id property)
1329
+ node_data: Dictionary of node properties
1330
+ """
1331
+ if "entity_id" not in node_data:
1332
+ raise ValueError(
1333
+ "PostgreSQL: node properties must contain an 'entity_id' field"
1334
+ )
1335
+
1336
+ label = node_id.strip('"')
1337
+ properties = self._format_properties(node_data)
1338
 
1339
  query = """SELECT * FROM cypher('%s', $$
1340
+ MERGE (n:base {entity_id: "%s"})
1341
  SET n += %s
1342
  RETURN n
1343
  $$) AS (n agtype)""" % (
1344
  self.graph_name,
1345
  label,
1346
+ properties,
1347
  )
1348
 
1349
  try:
1350
  await self._query(query, readonly=False, upsert=True)
1351
 
1352
+ except Exception:
1353
+ logger.error(f"POSTGRES, upsert_node error on node_id: `{node_id}`")
1354
  raise
1355
 
1356
  @retry(
 
1369
  target_node_id (str): Label of the target node (used as identifier)
1370
  edge_data (dict): dictionary of properties to set on the edge
1371
  """
1372
+ src_label = source_node_id.strip('"')
1373
+ tgt_label = target_node_id.strip('"')
1374
+ edge_properties = self._format_properties(edge_data)
1375
 
1376
  query = """SELECT * FROM cypher('%s', $$
1377
+ MATCH (source:base {entity_id: "%s"})
1378
  WITH source
1379
+ MATCH (target:base {entity_id: "%s"})
1380
  MERGE (source)-[r:DIRECTED]->(target)
1381
  SET r += %s
1382
  RETURN r
 
1384
  self.graph_name,
1385
  src_label,
1386
  tgt_label,
1387
+ edge_properties,
1388
  )
1389
 
1390
  try:
1391
  await self._query(query, readonly=False, upsert=True)
1392
 
1393
+ except Exception:
1394
+ logger.error(
1395
+ f"POSTGRES, upsert_edge error on edge: `{source_node_id}`-`{target_node_id}`"
1396
+ )
1397
  raise
1398
 
1399
  async def _node2vec_embed(self):
 
1406
  Args:
1407
  node_id (str): The ID of the node to delete.
1408
  """
1409
+ label = node_id.strip('"')
1410
 
1411
  query = """SELECT * FROM cypher('%s', $$
1412
+ MATCH (n:base {entity_id: "%s"})
1413
  DETACH DELETE n
1414
  $$) AS (n agtype)""" % (self.graph_name, label)
1415
 
 
1426
  Args:
1427
  node_ids (list[str]): A list of node IDs to remove.
1428
  """
1429
+ node_ids = [node_id.strip('"') for node_id in node_ids]
1430
+ node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids])
 
 
1431
 
1432
  query = """SELECT * FROM cypher('%s', $$
1433
+ MATCH (n:base)
1434
+ WHERE n.entity_id IN [%s]
1435
  DETACH DELETE n
1436
  $$) AS (n agtype)""" % (self.graph_name, node_id_list)
1437
 
 
1448
  Args:
1449
  edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
1450
  """
1451
+ for source, target in edges:
1452
+ src_label = source.strip('"')
1453
+ tgt_label = target.strip('"')
 
 
 
 
 
1454
 
1455
+ query = """SELECT * FROM cypher('%s', $$
1456
+ MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"})
1457
+ DELETE r
1458
+ $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
 
1459
 
1460
+ try:
1461
+ await self._query(query, readonly=False)
1462
+ logger.debug(f"Deleted edge from '{source}' to '{target}'")
1463
+ except Exception as e:
1464
+ logger.error(f"Error during edge deletion: {str(e)}")
1465
+ raise
1466
 
1467
  async def get_all_labels(self) -> list[str]:
1468
  """
 
1473
  """
1474
  query = (
1475
  """SELECT * FROM cypher('%s', $$
1476
+ MATCH (n:base)
1477
+ WHERE n.entity_id IS NOT NULL
1478
+ RETURN DISTINCT n.entity_id AS label
1479
+ ORDER BY n.entity_id
1480
  $$) AS (label text)"""
1481
  % self.graph_name
1482
  )
1483
 
1484
  results = await self._query(query)
1485
+ labels = [result["label"] for result in results]
 
1486
  return labels
1487
 
1488
  async def embed_nodes(
 
1504
  return await embed_func()
1505
 
1506
  async def get_knowledge_graph(
1507
+ self,
1508
+ node_label: str,
1509
+ max_depth: int = 3,
1510
+ max_nodes: int = MAX_GRAPH_NODES,
1511
  ) -> KnowledgeGraph:
1512
  """
1513
+ Retrieve a connected subgraph centered on the node whose entity_id matches the specified `node_label` (or the whole graph when `node_label` is "*").
1514
 
1515
  Args:
1516
+ node_label: Label of the starting node, * means all nodes
1517
+ max_depth: Maximum depth of the subgraph, defaults to 3
1518
+ max_nodes: Maximum number of nodes to return, defaults to 1000 (neither BFS nor DFS order is guaranteed)
1519
 
1520
  Returns:
1521
+ KnowledgeGraph object containing nodes and edges, with an is_truncated flag
1522
+ indicating whether the graph was truncated due to max_nodes limit
1523
  """
1524
+ # First, count the total number of nodes that would be returned without limit
1525
+ if node_label == "*":
1526
+ count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1527
+ MATCH (n:base)
1528
+ RETURN count(distinct n) AS total_nodes
1529
+ $$) AS (total_nodes bigint)"""
1530
+ else:
1531
+ strip_label = node_label.strip('"')
1532
+ count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1533
+ MATCH (n:base {{entity_id: "{strip_label}"}})
1534
+ OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
1535
+ RETURN count(distinct m) AS total_nodes
1536
+ $$) AS (total_nodes bigint)"""
1537
+
1538
+ count_result = await self._query(count_query)
1539
+ total_nodes = count_result[0]["total_nodes"] if count_result else 0
1540
+ is_truncated = total_nodes > max_nodes
1541
+
1542
+ # Now get the actual data with limit
1543
  if node_label == "*":
1544
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1545
+ MATCH (n:base)
1546
+ OPTIONAL MATCH (n)-[r]->(target:base)
1547
+ RETURN collect(distinct n) AS n, collect(distinct r) AS r
1548
+ LIMIT {max_nodes}
1549
+ $$) AS (n agtype, r agtype)"""
1550
  else:
1551
+ strip_label = node_label.strip('"')
1552
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1553
+ MATCH (n:base {{entity_id: "{strip_label}"}})
1554
+ OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
1555
+ RETURN nodes(p) AS n, relationships(p) AS r
1556
+ LIMIT {max_nodes}
1557
+ $$) AS (n agtype, r agtype)"""
1558
 
1559
  results = await self._query(query)
1560
 
1561
+ # Process the query results with deduplication by node and edge IDs
1562
+ nodes_dict = {}
1563
+ edges_dict = {}
1564
+ for result in results:
1565
+ # Handle single node cases
1566
+ if result.get("n") and isinstance(result["n"], dict):
1567
+ node_id = str(result["n"]["id"])
1568
+ if node_id not in nodes_dict:
1569
+ nodes_dict[node_id] = KnowledgeGraphNode(
1570
+ id=node_id,
1571
+ labels=[result["n"]["properties"]["entity_id"]],
1572
+ properties=result["n"]["properties"],
 
 
 
 
 
 
 
 
 
1573
  )
1574
+ # Handle node list cases
1575
+ elif result.get("n") and isinstance(result["n"], list):
1576
+ for node in result["n"]:
1577
+ if isinstance(node, dict) and "id" in node:
1578
+ node_id = str(node["id"])
1579
+ if node_id not in nodes_dict and "properties" in node:
1580
+ nodes_dict[node_id] = KnowledgeGraphNode(
1581
+ id=node_id,
1582
+ labels=[node["properties"]["entity_id"]],
1583
+ properties=node["properties"],
1584
+ )
1585
 
1586
+ # Handle single edge cases
1587
+ if result.get("r") and isinstance(result["r"], dict):
1588
+ edge_id = str(result["r"]["id"])
1589
+ if edge_id not in edges_dict:
1590
+ edges_dict[edge_id] = KnowledgeGraphEdge(
1591
+ id=edge_id,
1592
+ type="DIRECTED",
1593
+ source=str(result["r"]["start_id"]),
1594
+ target=str(result["r"]["end_id"]),
1595
+ properties=result["r"]["properties"],
1596
+ )
1597
+ # Handle edge list cases
1598
+ elif result.get("r") and isinstance(result["r"], list):
1599
+ for edge in result["r"]:
1600
+ if isinstance(edge, dict) and "id" in edge:
1601
+ edge_id = str(edge["id"])
1602
+ if edge_id not in edges_dict:
1603
+ edges_dict[edge_id] = KnowledgeGraphEdge(
1604
+ id=edge_id,
1605
+ type="DIRECTED",
1606
+ source=str(edge["start_id"]),
1607
+ target=str(edge["end_id"]),
1608
+ properties=edge["properties"],
1609
+ )
1610
 
1611
+ # Construct and return the KnowledgeGraph with deduplicated nodes and edges
1612
  kg = KnowledgeGraph(
1613
+ nodes=list(nodes_dict.values()),
1614
+ edges=list(edges_dict.values()),
1615
+ is_truncated=is_truncated,
 
 
 
 
 
 
 
 
 
 
 
1616
  )
1617
 
1618
+ logger.info(
1619
+ f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
1620
+ )
1621
  return kg
1622
 
1623
+ async def drop(self) -> dict[str, str]:
1624
  """Drop the storage"""
1625
+ try:
1626
+ drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1627
+ MATCH (n)
1628
+ DETACH DELETE n
1629
+ $$) AS (result agtype)"""
1630
+
1631
+ await self._query(drop_query, readonly=False)
1632
+ return {"status": "success", "message": "graph data dropped"}
1633
+ except Exception as e:
1634
+ logger.error(f"Error dropping graph: {e}")
1635
+ return {"status": "error", "message": str(e)}
1636
 
1637
 
1638
  NAMESPACE_TABLE_MAP = {
 
1745
  FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2
1746
  """,
1747
  "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
1748
+ chunk_order_index, full_doc_id, file_path
1749
  FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
1750
  """,
1751
  "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
 
1758
  FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
1759
  """,
1760
  "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
1761
+ chunk_order_index, full_doc_id, file_path
1762
  FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
1763
  """,
1764
  "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode
 
1790
  file_path=EXCLUDED.file_path,
1791
  update_time = CURRENT_TIMESTAMP
1792
  """,
1793
+ # SQL for VectorStorage
1794
  "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
1795
  content_vector, chunk_ids, file_path)
1796
  VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7)
 
1814
  file_path=EXCLUDED.file_path,
1815
  update_time = CURRENT_TIMESTAMP
1816
  """,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1817
  "relationships": """
1818
  WITH relevant_chunks AS (
1819
  SELECT id as chunk_id
 
1854
  FROM LIGHTRAG_DOC_CHUNKS
1855
  WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
1856
  )
1857
+ SELECT id, content, file_path FROM
1858
  (
1859
+ SELECT id, content, file_path, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1860
  FROM LIGHTRAG_DOC_CHUNKS
1861
  where workspace=$1
1862
  AND id IN (SELECT chunk_id FROM relevant_chunks)
 
1865
  ORDER BY distance DESC
1866
  LIMIT $3
1867
  """,
1868
+ # DROP tables
1869
+ "drop_specifiy_table_workspace": """
1870
+ DELETE FROM {table_name} WHERE workspace=$1
1871
+ """,
1872
  }
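A minimal usage sketch of the reworked PostgreSQL storage API above (workspace-scoped drop() via the drop_specifiy_table_workspace template, id-list delete() using ANY($2), and get_knowledge_graph() bounded by MAX_GRAPH_NODES). It assumes already-initialized kv_storage and graph_storage instances; the variable names and ids are illustrative, not part of the diff:

    import asyncio

    async def demo(kv_storage, graph_storage):
        # delete() now removes only the listed ids within the current workspace
        await kv_storage.delete(["doc-123", "doc-456"])

        # drop() clears this workspace's rows instead of dropping whole tables
        status = await kv_storage.drop()   # {"status": "success", "message": "data dropped"}

        # Subgraph retrieval is capped by max_nodes (MAX_GRAPH_NODES env var, default 1000)
        kg = await graph_storage.get_knowledge_graph("*", max_depth=3, max_nodes=1000)
        print(status["status"], len(kg.nodes), len(kg.edges), kg.is_truncated)

    # asyncio.run(demo(kv_storage, graph_storage))  # instances assumed to exist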
lightrag/kg/qdrant_impl.py CHANGED
@@ -8,17 +8,15 @@ import uuid
8
  from ..utils import logger
9
  from ..base import BaseVectorStorage
10
  import configparser
11
-
12
-
13
- config = configparser.ConfigParser()
14
- config.read("config.ini", "utf-8")
15
-
16
  import pipmaster as pm
17
 
18
  if not pm.is_installed("qdrant-client"):
19
  pm.install("qdrant-client")
20
 
21
- from qdrant_client import QdrantClient, models
 
 
 
22
 
23
 
24
  def compute_mdhash_id_for_qdrant(
@@ -275,3 +273,92 @@ class QdrantVectorDBStorage(BaseVectorStorage):
275
  except Exception as e:
276
  logger.error(f"Error searching for prefix '{prefix}': {e}")
277
  return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from ..utils import logger
9
  from ..base import BaseVectorStorage
10
  import configparser
 
 
 
 
 
11
  import pipmaster as pm
12
 
13
  if not pm.is_installed("qdrant-client"):
14
  pm.install("qdrant-client")
15
 
16
+ from qdrant_client import QdrantClient, models # type: ignore
17
+
18
+ config = configparser.ConfigParser()
19
+ config.read("config.ini", "utf-8")
20
 
21
 
22
  def compute_mdhash_id_for_qdrant(
 
273
  except Exception as e:
274
  logger.error(f"Error searching for prefix '{prefix}': {e}")
275
  return []
276
+
277
+ async def get_by_id(self, id: str) -> dict[str, Any] | None:
278
+ """Get vector data by its ID
279
+
280
+ Args:
281
+ id: The unique identifier of the vector
282
+
283
+ Returns:
284
+ The vector data if found, or None if not found
285
+ """
286
+ try:
287
+ # Convert to Qdrant compatible ID
288
+ qdrant_id = compute_mdhash_id_for_qdrant(id)
289
+
290
+ # Retrieve the point by ID
291
+ result = self._client.retrieve(
292
+ collection_name=self.namespace,
293
+ ids=[qdrant_id],
294
+ with_payload=True,
295
+ )
296
+
297
+ if not result:
298
+ return None
299
+
300
+ return result[0].payload
301
+ except Exception as e:
302
+ logger.error(f"Error retrieving vector data for ID {id}: {e}")
303
+ return None
304
+
305
+ async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
306
+ """Get multiple vector data by their IDs
307
+
308
+ Args:
309
+ ids: List of unique identifiers
310
+
311
+ Returns:
312
+ List of vector data objects that were found
313
+ """
314
+ if not ids:
315
+ return []
316
+
317
+ try:
318
+ # Convert to Qdrant compatible IDs
319
+ qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids]
320
+
321
+ # Retrieve the points by IDs
322
+ results = self._client.retrieve(
323
+ collection_name=self.namespace,
324
+ ids=qdrant_ids,
325
+ with_payload=True,
326
+ )
327
+
328
+ return [point.payload for point in results]
329
+ except Exception as e:
330
+ logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
331
+ return []
332
+
333
+ async def drop(self) -> dict[str, str]:
334
+ """Drop all vector data from storage and clean up resources
335
+
336
+ This method will delete all data from the Qdrant collection.
337
+
338
+ Returns:
339
+ dict[str, str]: Operation status and message
340
+ - On success: {"status": "success", "message": "data dropped"}
341
+ - On failure: {"status": "error", "message": "<error details>"}
342
+ """
343
+ try:
344
+ # Delete the collection and recreate it
345
+ if self._client.collection_exists(self.namespace):
346
+ self._client.delete_collection(self.namespace)
347
+
348
+ # Recreate the collection
349
+ QdrantVectorDBStorage.create_collection_if_not_exist(
350
+ self._client,
351
+ self.namespace,
352
+ vectors_config=models.VectorParams(
353
+ size=self.embedding_func.embedding_dim,
354
+ distance=models.Distance.COSINE,
355
+ ),
356
+ )
357
+
358
+ logger.info(
359
+ f"Process {os.getpid()} drop Qdrant collection {self.namespace}"
360
+ )
361
+ return {"status": "success", "message": "data dropped"}
362
+ except Exception as e:
363
+ logger.error(f"Error dropping Qdrant collection {self.namespace}: {e}")
364
+ return {"status": "error", "message": str(e)}
lightrag/kg/redis_impl.py CHANGED
@@ -12,6 +12,7 @@ if not pm.is_installed("redis"):
12
  from redis.asyncio import Redis, ConnectionPool
13
  from redis.exceptions import RedisError, ConnectionError
14
  from lightrag.utils import logger, compute_mdhash_id
 
15
  from lightrag.base import BaseKVStorage
16
  import json
17
 
@@ -121,7 +122,11 @@ class RedisKVStorage(BaseKVStorage):
121
  except json.JSONEncodeError as e:
122
  logger.error(f"JSON encode error during upsert: {e}")
123
  raise
124
-
 
 
 
 
125
  async def delete(self, ids: list[str]) -> None:
126
  """Delete entries with specified IDs"""
127
  if not ids:
@@ -138,71 +143,52 @@ class RedisKVStorage(BaseKVStorage):
138
  f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
139
  )
140
 
141
- async def delete_entity(self, entity_name: str) -> None:
142
- """Delete an entity by name"""
143
- try:
144
- entity_id = compute_mdhash_id(entity_name, prefix="ent-")
145
- logger.debug(
146
- f"Attempting to delete entity {entity_name} with ID {entity_id}"
147
- )
148
 
149
- async with self._get_redis_connection() as redis:
150
- result = await redis.delete(f"{self.namespace}:{entity_id}")
151
 
152
- if result:
153
- logger.debug(f"Successfully deleted entity {entity_name}")
154
- else:
155
- logger.debug(f"Entity {entity_name} not found in storage")
156
- except Exception as e:
157
- logger.error(f"Error deleting entity {entity_name}: {e}")
 
 
 
158
 
159
- async def delete_entity_relation(self, entity_name: str) -> None:
160
- """Delete all relations associated with an entity"""
161
  try:
162
- async with self._get_redis_connection() as redis:
163
- cursor = 0
164
- relation_keys = []
165
- pattern = f"{self.namespace}:*"
166
-
167
- while True:
168
- cursor, keys = await redis.scan(cursor, match=pattern)
169
-
170
- # Process keys in batches
 
 
 
 
 
 
 
171
  pipe = redis.pipeline()
172
  for key in keys:
173
- pipe.get(key)
174
- values = await pipe.execute()
175
-
176
- for key, value in zip(keys, values):
177
- if value:
178
- try:
179
- data = json.loads(value)
180
- if (
181
- data.get("src_id") == entity_name
182
- or data.get("tgt_id") == entity_name
183
- ):
184
- relation_keys.append(key)
185
- except json.JSONDecodeError:
186
- logger.warning(f"Invalid JSON in key {key}")
187
- continue
188
-
189
- if cursor == 0:
190
- break
191
-
192
- # Delete relations in batches
193
- if relation_keys:
194
- # Delete in chunks to avoid too many arguments
195
- chunk_size = 1000
196
- for i in range(0, len(relation_keys), chunk_size):
197
- chunk = relation_keys[i:i + chunk_size]
198
- deleted = await redis.delete(*chunk)
199
- logger.debug(f"Deleted {deleted} relations for {entity_name} (batch {i//chunk_size + 1})")
200
  else:
201
- logger.debug(f"No relations found for entity {entity_name}")
 
202
 
203
- except Exception as e:
204
- logger.error(f"Error deleting relations for {entity_name}: {e}")
 
205
 
206
- async def index_done_callback(self) -> None:
207
- # Redis handles persistence automatically
208
- pass
 
12
  from redis.asyncio import Redis, ConnectionPool
13
  from redis.exceptions import RedisError, ConnectionError
14
  from lightrag.utils import logger, compute_mdhash_id
15
+
16
  from lightrag.base import BaseKVStorage
17
  import json
18
 
 
122
  except json.JSONEncodeError as e:
123
  logger.error(f"JSON encode error during upsert: {e}")
124
  raise
125
+
126
+ async def index_done_callback(self) -> None:
127
+ # Redis handles persistence automatically
128
+ pass
129
+
130
  async def delete(self, ids: list[str]) -> None:
131
  """Delete entries with specified IDs"""
132
  if not ids:
 
143
  f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
144
  )
145
 
146
+ async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
147
+ """Delete specific records from storage by cache mode
 
 
 
 
 
148
 
149
+ Important notes for Redis storage:
150
+ 1. This will immediately delete the specified cache modes from Redis
151
 
152
+ Args:
153
+ modes (list[str]): List of cache modes to be dropped from storage
154
+
155
+ Returns:
156
+ True: if the cache was dropped successfully
157
+ False: if the cache drop failed
158
+ """
159
+ if not modes:
160
+ return False
161
 
 
 
162
  try:
163
+ await self.delete(modes)
164
+ return True
165
+ except Exception:
166
+ return False
167
+
168
+ async def drop(self) -> dict[str, str]:
169
+ """Drop the storage by removing all keys under the current namespace.
170
+
171
+ Returns:
172
+ dict[str, str]: Status of the operation with keys 'status' and 'message'
173
+ """
174
+ async with self._get_redis_connection() as redis:
175
+ try:
176
+ keys = await redis.keys(f"{self.namespace}:*")
177
+
178
+ if keys:
179
  pipe = redis.pipeline()
180
  for key in keys:
181
+ pipe.delete(key)
182
+ results = await pipe.execute()
183
+ deleted_count = sum(results)
184
+
185
+ logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
186
+ return {"status": "success", "message": f"{deleted_count} keys dropped"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  else:
188
+ logger.info(f"No keys found to drop in {self.namespace}")
189
+ return {"status": "success", "message": "no keys to drop"}
190
 
191
+ except Exception as e:
192
+ logger.error(f"Error dropping keys from {self.namespace}: {e}")
193
+ return {"status": "error", "message": str(e)}
194
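For the Redis KV storage, `drop_cache_by_modes()` simply delegates to `delete(modes)` because LLM-cache entries in this namespace are keyed by their mode string, while `drop()` wipes every key under the namespace. A hedged usage sketch; the `kv_storage` instance and the mode names are illustrative (taken from the cache modes used elsewhere in the project):

```python
import asyncio

async def clear_llm_cache(kv_storage, modes=("local", "global", "hybrid")) -> None:
    """Clear selected cache modes; fall back to wiping the whole namespace if that fails."""
    ok = await kv_storage.drop_cache_by_modes(list(modes))
    if not ok:
        result = await kv_storage.drop()
        print(f"full drop: {result['status']} - {result['message']}")

# asyncio.run(clear_llm_cache(kv_storage))  # kv_storage: an initialized RedisKVStorage
```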
 
 
 
 
lightrag/kg/tidb_impl.py CHANGED
@@ -20,7 +20,7 @@ if not pm.is_installed("pymysql"):
20
  if not pm.is_installed("sqlalchemy"):
21
  pm.install("sqlalchemy")
22
 
23
- from sqlalchemy import create_engine, text
24
 
25
 
26
  class TiDB:
@@ -278,6 +278,86 @@ class TiDBKVStorage(BaseKVStorage):
278
  # Ti handles persistence automatically
279
  pass
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
  @final
283
  @dataclass
@@ -406,16 +486,91 @@ class TiDBVectorDBStorage(BaseVectorStorage):
406
  params = {"workspace": self.db.workspace, "status": status}
407
  return await self.db.query(SQL, params, multirows=True)
408
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
  async def delete_entity(self, entity_name: str) -> None:
410
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411
 
412
  async def delete_entity_relation(self, entity_name: str) -> None:
413
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
 
415
  async def index_done_callback(self) -> None:
416
  # Ti handles persistence automatically
417
  pass
418
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
  async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
420
  """Search for records with IDs starting with a specific prefix.
421
 
@@ -710,6 +865,18 @@ class TiDBGraphStorage(BaseGraphStorage):
710
  # Ti handles persistence automatically
711
  pass
712
 
 
 
 
 
 
 
 
 
 
 
 
 
713
  async def delete_node(self, node_id: str) -> None:
714
  """Delete a node and all its related edges
715
 
@@ -1129,4 +1296,6 @@ SQL_TEMPLATES = {
1129
  FROM LIGHTRAG_DOC_CHUNKS
1130
  WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
1131
  """,
 
 
1132
  }
 
20
  if not pm.is_installed("sqlalchemy"):
21
  pm.install("sqlalchemy")
22
 
23
+ from sqlalchemy import create_engine, text # type: ignore
24
 
25
 
26
  class TiDB:
 
278
  # Ti handles persistence automatically
279
  pass
280
 
281
+ async def delete(self, ids: list[str]) -> None:
282
+ """Delete records with specified IDs from the storage.
283
+
284
+ Args:
285
+ ids: List of record IDs to be deleted
286
+ """
287
+ if not ids:
288
+ return
289
+
290
+ try:
291
+ table_name = namespace_to_table_name(self.namespace)
292
+ id_field = namespace_to_id(self.namespace)
293
+
294
+ if not table_name or not id_field:
295
+ logger.error(f"Unknown namespace for deletion: {self.namespace}")
296
+ return
297
+
298
+ ids_list = ",".join([f"'{id}'" for id in ids])
299
+ delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
300
+
301
+ await self.db.execute(delete_sql, {"workspace": self.db.workspace})
302
+ logger.info(
303
+ f"Successfully deleted {len(ids)} records from {self.namespace}"
304
+ )
305
+ except Exception as e:
306
+ logger.error(f"Error deleting records from {self.namespace}: {e}")
307
+
308
+ async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
309
+ """Delete specific records from storage by cache mode
310
+
311
+ Args:
312
+ modes (list[str]): List of cache modes to be dropped from storage
313
+
314
+ Returns:
315
+ bool: True if successful, False otherwise
316
+ """
317
+ if not modes:
318
+ return False
319
+
320
+ try:
321
+ table_name = namespace_to_table_name(self.namespace)
322
+ if not table_name:
323
+ return False
324
+
325
+ if table_name != "LIGHTRAG_LLM_CACHE":
326
+ return False
327
+
328
+ # Build a MySQL-style IN clause
329
+ modes_list = ", ".join([f"'{mode}'" for mode in modes])
330
+ sql = f"""
331
+ DELETE FROM {table_name}
332
+ WHERE workspace = :workspace
333
+ AND mode IN ({modes_list})
334
+ """
335
+
336
+ logger.info(f"Deleting cache by modes: {modes}")
337
+ await self.db.execute(sql, {"workspace": self.db.workspace})
338
+ return True
339
+ except Exception as e:
340
+ logger.error(f"Error deleting cache by modes {modes}: {e}")
341
+ return False
342
+
343
+ async def drop(self) -> dict[str, str]:
344
+ """Drop the storage"""
345
+ try:
346
+ table_name = namespace_to_table_name(self.namespace)
347
+ if not table_name:
348
+ return {
349
+ "status": "error",
350
+ "message": f"Unknown namespace: {self.namespace}",
351
+ }
352
+
353
+ drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
354
+ table_name=table_name
355
+ )
356
+ await self.db.execute(drop_sql, {"workspace": self.db.workspace})
357
+ return {"status": "success", "message": "data dropped"}
358
+ except Exception as e:
359
+ return {"status": "error", "message": str(e)}
360
+
361
 
362
  @final
363
  @dataclass
 
486
  params = {"workspace": self.db.workspace, "status": status}
487
  return await self.db.query(SQL, params, multirows=True)
488
 
489
+ async def delete(self, ids: list[str]) -> None:
490
+ """Delete vectors with specified IDs from the storage.
491
+
492
+ Args:
493
+ ids: List of vector IDs to be deleted
494
+ """
495
+ if not ids:
496
+ return
497
+
498
+ table_name = namespace_to_table_name(self.namespace)
499
+ id_field = namespace_to_id(self.namespace)
500
+
501
+ if not table_name or not id_field:
502
+ logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
503
+ return
504
+
505
+ ids_list = ",".join([f"'{id}'" for id in ids])
506
+ delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
507
+
508
+ try:
509
+ await self.db.execute(delete_sql, {"workspace": self.db.workspace})
510
+ logger.debug(
511
+ f"Successfully deleted {len(ids)} vectors from {self.namespace}"
512
+ )
513
+ except Exception as e:
514
+ logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
515
+
516
  async def delete_entity(self, entity_name: str) -> None:
517
+ """Delete an entity by its name from the vector storage.
518
+
519
+ Args:
520
+ entity_name: The name of the entity to delete
521
+ """
522
+ try:
523
+ # Construct SQL to delete the entity
524
+ delete_sql = """DELETE FROM LIGHTRAG_GRAPH_NODES
525
+ WHERE workspace = :workspace AND name = :entity_name"""
526
+
527
+ await self.db.execute(
528
+ delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
529
+ )
530
+ logger.debug(f"Successfully deleted entity {entity_name}")
531
+ except Exception as e:
532
+ logger.error(f"Error deleting entity {entity_name}: {e}")
533
 
534
  async def delete_entity_relation(self, entity_name: str) -> None:
535
+ """Delete all relations associated with an entity.
536
+
537
+ Args:
538
+ entity_name: The name of the entity whose relations should be deleted
539
+ """
540
+ try:
541
+ # Delete relations where the entity is either the source or target
542
+ delete_sql = """DELETE FROM LIGHTRAG_GRAPH_EDGES
543
+ WHERE workspace = :workspace AND (source_name = :entity_name OR target_name = :entity_name)"""
544
+
545
+ await self.db.execute(
546
+ delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
547
+ )
548
+ logger.debug(f"Successfully deleted relations for entity {entity_name}")
549
+ except Exception as e:
550
+ logger.error(f"Error deleting relations for entity {entity_name}: {e}")
551
 
552
  async def index_done_callback(self) -> None:
553
  # Ti handles persistence automatically
554
  pass
555
 
556
+ async def drop(self) -> dict[str, str]:
557
+ """Drop the storage"""
558
+ try:
559
+ table_name = namespace_to_table_name(self.namespace)
560
+ if not table_name:
561
+ return {
562
+ "status": "error",
563
+ "message": f"Unknown namespace: {self.namespace}",
564
+ }
565
+
566
+ drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
567
+ table_name=table_name
568
+ )
569
+ await self.db.execute(drop_sql, {"workspace": self.db.workspace})
570
+ return {"status": "success", "message": "data dropped"}
571
+ except Exception as e:
572
+ return {"status": "error", "message": str(e)}
573
+
574
  async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
575
  """Search for records with IDs starting with a specific prefix.
576
 
 
865
  # Ti handles persistence automatically
866
  pass
867
 
868
+ async def drop(self) -> dict[str, str]:
869
+ """Drop the storage"""
870
+ try:
871
+ drop_sql = """
872
+ DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace;
873
+ DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace;
874
+ """
875
+ await self.db.execute(drop_sql, {"workspace": self.db.workspace})
876
+ return {"status": "success", "message": "graph data dropped"}
877
+ except Exception as e:
878
+ return {"status": "error", "message": str(e)}
879
+
880
  async def delete_node(self, node_id: str) -> None:
881
  """Delete a node and all its related edges
882
 
 
1296
  FROM LIGHTRAG_DOC_CHUNKS
1297
  WHERE chunk_id LIKE :prefix_pattern AND workspace = :workspace
1298
  """,
1299
+ # Drop all rows of a table for the current workspace
1300
+ "drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace",
1301
  }
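The `drop_specifiy_table_workspace` template interpolates the table name in Python (table names cannot be bound parameters) while keeping the workspace value as a bound parameter. A standalone sketch of that same pattern with plain SQLAlchemy; the connection URL, table name, and workspace value are placeholders, not values from this commit:

```python
from sqlalchemy import create_engine, text

WORKSPACE_DELETE = "DELETE FROM {table_name} WHERE workspace = :workspace"

def drop_workspace_rows(db_url: str, table_name: str, workspace: str) -> None:
    """Remove every row of `table_name` that belongs to `workspace`."""
    engine = create_engine(db_url)
    # Table name is filled in via format(); the workspace stays a bound parameter.
    sql = WORKSPACE_DELETE.format(table_name=table_name)
    with engine.begin() as conn:
        conn.execute(text(sql), {"workspace": workspace})

# drop_workspace_rows("mysql+pymysql://user:pass@host:4000/db", "LIGHTRAG_LLM_CACHE", "default")
```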
lightrag/lightrag.py CHANGED
@@ -13,7 +13,6 @@ import pandas as pd
13
 
14
 
15
  from lightrag.kg import (
16
- STORAGE_ENV_REQUIREMENTS,
17
  STORAGES,
18
  verify_storage_implementation,
19
  )
@@ -230,6 +229,7 @@ class LightRAG:
230
  vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
231
  """Additional parameters for vector database storage."""
232
 
 
233
  namespace_prefix: str = field(default="")
234
  """Prefix for namespacing stored data across different environments."""
235
 
@@ -510,36 +510,22 @@ class LightRAG:
510
  self,
511
  node_label: str,
512
  max_depth: int = 3,
513
- min_degree: int = 0,
514
- inclusive: bool = False,
515
  ) -> KnowledgeGraph:
516
  """Get knowledge graph for a given label
517
 
518
  Args:
519
  node_label (str): Label to get knowledge graph for
520
  max_depth (int): Maximum depth of graph
521
- min_degree (int, optional): Minimum degree of nodes to include. Defaults to 0.
522
- inclusive (bool, optional): Whether to use inclusive search mode. Defaults to False.
523
 
524
  Returns:
525
  KnowledgeGraph: Knowledge graph containing nodes and edges
526
  """
527
- # get params supported by get_knowledge_graph of specified storage
528
- import inspect
529
 
530
- storage_params = inspect.signature(
531
- self.chunk_entity_relation_graph.get_knowledge_graph
532
- ).parameters
533
-
534
- kwargs = {"node_label": node_label, "max_depth": max_depth}
535
-
536
- if "min_degree" in storage_params and min_degree > 0:
537
- kwargs["min_degree"] = min_degree
538
-
539
- if "inclusive" in storage_params:
540
- kwargs["inclusive"] = inclusive
541
-
542
- return await self.chunk_entity_relation_graph.get_knowledge_graph(**kwargs)
543
 
544
  def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
545
  import_path = STORAGES[storage_name]
@@ -1449,6 +1435,7 @@ class LightRAG:
1449
  loop = always_get_an_event_loop()
1450
  return loop.run_until_complete(self.adelete_by_entity(entity_name))
1451
 
 
1452
  async def adelete_by_entity(self, entity_name: str) -> None:
1453
  try:
1454
  await self.entities_vdb.delete_entity(entity_name)
@@ -1486,6 +1473,7 @@ class LightRAG:
1486
  self.adelete_by_relation(source_entity, target_entity)
1487
  )
1488
 
 
1489
  async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None:
1490
  """Asynchronously delete a relation between two entities.
1491
 
@@ -1494,6 +1482,7 @@ class LightRAG:
1494
  target_entity: Name of the target entity
1495
  """
1496
  try:
 
1497
  # Check if the relation exists
1498
  edge_exists = await self.chunk_entity_relation_graph.has_edge(
1499
  source_entity, target_entity
@@ -1554,6 +1543,7 @@ class LightRAG:
1554
  """
1555
  return await self.doc_status.get_docs_by_status(status)
1556
 
 
1557
  async def adelete_by_doc_id(self, doc_id: str) -> None:
1558
  """Delete a document and all its related data
1559
 
@@ -1586,6 +1576,8 @@ class LightRAG:
1586
  chunk_ids = set(related_chunks.keys())
1587
  logger.debug(f"Found {len(chunk_ids)} chunks to delete")
1588
 
 
 
1589
  # 3. Before deleting, check the related entities and relationships for these chunks
1590
  for chunk_id in chunk_ids:
1591
  # Check entities
@@ -1857,24 +1849,6 @@ class LightRAG:
1857
 
1858
  return result
1859
 
1860
- def check_storage_env_vars(self, storage_name: str) -> None:
1861
- """Check if all required environment variables for storage implementation exist
1862
-
1863
- Args:
1864
- storage_name: Storage implementation name
1865
-
1866
- Raises:
1867
- ValueError: If required environment variables are missing
1868
- """
1869
- required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
1870
- missing_vars = [var for var in required_vars if var not in os.environ]
1871
-
1872
- if missing_vars:
1873
- raise ValueError(
1874
- f"Storage implementation '{storage_name}' requires the following "
1875
- f"environment variables: {', '.join(missing_vars)}"
1876
- )
1877
-
1878
  async def aclear_cache(self, modes: list[str] | None = None) -> None:
1879
  """Clear cache data from the LLM response cache storage.
1880
 
@@ -1906,12 +1880,18 @@ class LightRAG:
1906
  try:
1907
  # Reset the cache storage for specified mode
1908
  if modes:
1909
- await self.llm_response_cache.delete(modes)
1910
- logger.info(f"Cleared cache for modes: {modes}")
 
 
 
1911
  else:
1912
  # Clear all modes
1913
- await self.llm_response_cache.delete(valid_modes)
1914
- logger.info("Cleared all cache")
 
 
 
1915
 
1916
  await self.llm_response_cache.index_done_callback()
1917
 
@@ -1922,6 +1902,7 @@ class LightRAG:
1922
  """Synchronous version of aclear_cache."""
1923
  return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
1924
 
 
1925
  async def aedit_entity(
1926
  self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
1927
  ) -> dict[str, Any]:
@@ -2134,6 +2115,7 @@ class LightRAG:
2134
  ]
2135
  )
2136
 
 
2137
  async def aedit_relation(
2138
  self, source_entity: str, target_entity: str, updated_data: dict[str, Any]
2139
  ) -> dict[str, Any]:
@@ -2448,6 +2430,7 @@ class LightRAG:
2448
  self.acreate_relation(source_entity, target_entity, relation_data)
2449
  )
2450
 
 
2451
  async def amerge_entities(
2452
  self,
2453
  source_entities: list[str],
 
13
 
14
 
15
  from lightrag.kg import (
 
16
  STORAGES,
17
  verify_storage_implementation,
18
  )
 
229
  vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
230
  """Additional parameters for vector database storage."""
231
 
232
+ # TODO: deprecated, remove in the future; use WORKSPACE instead
233
  namespace_prefix: str = field(default="")
234
  """Prefix for namespacing stored data across different environments."""
235
 
 
510
  self,
511
  node_label: str,
512
  max_depth: int = 3,
513
+ max_nodes: int = 1000,
 
514
  ) -> KnowledgeGraph:
515
  """Get knowledge graph for a given label
516
 
517
  Args:
518
  node_label (str): Label to get knowledge graph for
519
  max_depth (int): Maximum depth of graph
520
+ max_nodes (int, optional): Maximum number of nodes to return. Defaults to 1000.
 
521
 
522
  Returns:
523
  KnowledgeGraph: Knowledge graph containing nodes and edges
524
  """
 
 
525
 
526
+ return await self.chunk_entity_relation_graph.get_knowledge_graph(
527
+ node_label, max_depth, max_nodes
528
+ )
 
 
 
 
 
 
 
 
 
 
529
 
530
  def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
531
  import_path = STORAGES[storage_name]
 
1435
  loop = always_get_an_event_loop()
1436
  return loop.run_until_complete(self.adelete_by_entity(entity_name))
1437
 
1438
+ # TODO: Lock all KG-related DBs to ensure consistency across multiple processes
1439
  async def adelete_by_entity(self, entity_name: str) -> None:
1440
  try:
1441
  await self.entities_vdb.delete_entity(entity_name)
 
1473
  self.adelete_by_relation(source_entity, target_entity)
1474
  )
1475
 
1476
+ # TODO: Lock all KG-related DBs to ensure consistency across multiple processes
1477
  async def adelete_by_relation(self, source_entity: str, target_entity: str) -> None:
1478
  """Asynchronously delete a relation between two entities.
1479
 
 
1482
  target_entity: Name of the target entity
1483
  """
1484
  try:
1485
+ # TODO: check whether has_edge also detects the reverse relation
1486
  # Check if the relation exists
1487
  edge_exists = await self.chunk_entity_relation_graph.has_edge(
1488
  source_entity, target_entity
 
1543
  """
1544
  return await self.doc_status.get_docs_by_status(status)
1545
 
1546
+ # TODO: Lock all KG-related DBs to ensure consistency across multiple processes
1547
  async def adelete_by_doc_id(self, doc_id: str) -> None:
1548
  """Delete a document and all its related data
1549
 
 
1576
  chunk_ids = set(related_chunks.keys())
1577
  logger.debug(f"Found {len(chunk_ids)} chunks to delete")
1578
 
1579
+ # TODO: self.entities_vdb.client_storage only works for local storage; this needs a fix
1580
+
1581
  # 3. Before deleting, check the related entities and relationships for these chunks
1582
  for chunk_id in chunk_ids:
1583
  # Check entities
 
1849
 
1850
  return result
1851
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1852
  async def aclear_cache(self, modes: list[str] | None = None) -> None:
1853
  """Clear cache data from the LLM response cache storage.
1854
 
 
1880
  try:
1881
  # Reset the cache storage for specified mode
1882
  if modes:
1883
+ success = await self.llm_response_cache.drop_cache_by_modes(modes)
1884
+ if success:
1885
+ logger.info(f"Cleared cache for modes: {modes}")
1886
+ else:
1887
+ logger.warning(f"Failed to clear cache for modes: {modes}")
1888
  else:
1889
  # Clear all modes
1890
+ success = await self.llm_response_cache.drop_cache_by_modes(valid_modes)
1891
+ if success:
1892
+ logger.info("Cleared all cache")
1893
+ else:
1894
+ logger.warning("Failed to clear all cache")
1895
 
1896
  await self.llm_response_cache.index_done_callback()
1897
 
 
1902
  """Synchronous version of aclear_cache."""
1903
  return always_get_an_event_loop().run_until_complete(self.aclear_cache(modes))
1904
 
1905
+ # TODO: Lock all KG-related DBs to ensure consistency across multiple processes
1906
  async def aedit_entity(
1907
  self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
1908
  ) -> dict[str, Any]:
 
2115
  ]
2116
  )
2117
 
2118
+ # TODO: Lock all KG-related DBs to ensure consistency across multiple processes
2119
  async def aedit_relation(
2120
  self, source_entity: str, target_entity: str, updated_data: dict[str, Any]
2121
  ) -> dict[str, Any]:
 
2430
  self.acreate_relation(source_entity, target_entity, relation_data)
2431
  )
2432
 
2433
+ # TODO: Lock all KG-related DBs to ensure consistency across multiple processes
2434
  async def amerge_entities(
2435
  self,
2436
  source_entities: list[str],
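Two caller-visible changes in `lightrag.py` are the simplified knowledge-graph signature (`max_nodes` replaces `min_degree`/`inclusive`) and cache clearing that now routes through `drop_cache_by_modes`. A hedged sketch of both, assuming an already-configured `LightRAG` instance named `rag` and that the method keeps its `get_knowledge_graph` name (only the parameter list is visible in this hunk); the label and mode values are illustrative:

```python
import asyncio

async def demo(rag) -> None:
    # node_label, max_depth and max_nodes are forwarded directly to the graph storage
    kg = await rag.get_knowledge_graph("Scrooge", max_depth=3, max_nodes=1000)
    print(len(kg.nodes), len(kg.edges), kg.is_truncated)

    # aclear_cache now delegates to drop_cache_by_modes and logs a warning on failure
    await rag.aclear_cache(modes=["local"])

# asyncio.run(demo(rag))  # rag must be initialized and its storages ready
```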
lightrag/llm/openai.py CHANGED
@@ -44,6 +44,47 @@ class InvalidResponseError(Exception):
44
  pass
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  @retry(
48
  stop=stop_after_attempt(3),
49
  wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -61,29 +102,52 @@ async def openai_complete_if_cache(
61
  token_tracker: Any | None = None,
62
  **kwargs: Any,
63
  ) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  if history_messages is None:
65
  history_messages = []
66
- if not api_key:
67
- api_key = os.environ["OPENAI_API_KEY"]
68
-
69
- default_headers = {
70
- "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
71
- "Content-Type": "application/json",
72
- }
73
 
74
  # Set openai logger level to INFO when VERBOSE_DEBUG is off
75
  if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
76
  logging.getLogger("openai").setLevel(logging.INFO)
77
 
78
- openai_async_client = (
79
- AsyncOpenAI(default_headers=default_headers, api_key=api_key)
80
- if base_url is None
81
- else AsyncOpenAI(
82
- base_url=base_url, default_headers=default_headers, api_key=api_key
83
- )
84
  )
 
 
85
  kwargs.pop("hashing_kv", None)
86
  kwargs.pop("keyword_extraction", None)
 
 
87
  messages: list[dict[str, Any]] = []
88
  if system_prompt:
89
  messages.append({"role": "system", "content": system_prompt})
@@ -272,21 +336,32 @@ async def openai_embed(
272
  model: str = "text-embedding-3-small",
273
  base_url: str = None,
274
  api_key: str = None,
 
275
  ) -> np.ndarray:
276
- if not api_key:
277
- api_key = os.environ["OPENAI_API_KEY"]
278
-
279
- default_headers = {
280
- "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
281
- "Content-Type": "application/json",
282
- }
283
- openai_async_client = (
284
- AsyncOpenAI(default_headers=default_headers, api_key=api_key)
285
- if base_url is None
286
- else AsyncOpenAI(
287
- base_url=base_url, default_headers=default_headers, api_key=api_key
288
- )
 
 
 
 
 
 
 
 
 
289
  )
 
290
  response = await openai_async_client.embeddings.create(
291
  model=model, input=texts, encoding_format="float"
292
  )
 
44
  pass
45
 
46
 
47
+ def create_openai_async_client(
48
+ api_key: str | None = None,
49
+ base_url: str | None = None,
50
+ client_configs: dict[str, Any] = None,
51
+ ) -> AsyncOpenAI:
52
+ """Create an AsyncOpenAI client with the given configuration.
53
+
54
+ Args:
55
+ api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
56
+ base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
57
+ client_configs: Additional configuration options for the AsyncOpenAI client.
58
+ These will override any default configurations but will be overridden by
59
+ explicit parameters (api_key, base_url).
60
+
61
+ Returns:
62
+ An AsyncOpenAI client instance.
63
+ """
64
+ if not api_key:
65
+ api_key = os.environ["OPENAI_API_KEY"]
66
+
67
+ default_headers = {
68
+ "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
69
+ "Content-Type": "application/json",
70
+ }
71
+
72
+ if client_configs is None:
73
+ client_configs = {}
74
+
75
+ # Create a merged config dict with precedence: explicit params > client_configs > defaults
76
+ merged_configs = {
77
+ **client_configs,
78
+ "default_headers": default_headers,
79
+ "api_key": api_key,
80
+ }
81
+
82
+ if base_url is not None:
83
+ merged_configs["base_url"] = base_url
84
+
85
+ return AsyncOpenAI(**merged_configs)
86
+
87
+
88
  @retry(
89
  stop=stop_after_attempt(3),
90
  wait=wait_exponential(multiplier=1, min=4, max=10),
 
102
  token_tracker: Any | None = None,
103
  **kwargs: Any,
104
  ) -> str:
105
+ """Complete a prompt using OpenAI's API with caching support.
106
+
107
+ Args:
108
+ model: The OpenAI model to use.
109
+ prompt: The prompt to complete.
110
+ system_prompt: Optional system prompt to include.
111
+ history_messages: Optional list of previous messages in the conversation.
112
+ base_url: Optional base URL for the OpenAI API.
113
+ api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
114
+ **kwargs: Additional keyword arguments to pass to the OpenAI API.
115
+ Special kwargs:
116
+ - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
117
+ These will be passed to the client constructor but will be overridden by
118
+ explicit parameters (api_key, base_url).
119
+ - hashing_kv: Will be removed from kwargs before passing to OpenAI.
120
+ - keyword_extraction: Will be removed from kwargs before passing to OpenAI.
121
+
122
+ Returns:
123
+ The completed text or an async iterator of text chunks if streaming.
124
+
125
+ Raises:
126
+ InvalidResponseError: If the response from OpenAI is invalid or empty.
127
+ APIConnectionError: If there is a connection error with the OpenAI API.
128
+ RateLimitError: If the OpenAI API rate limit is exceeded.
129
+ APITimeoutError: If the OpenAI API request times out.
130
+ """
131
  if history_messages is None:
132
  history_messages = []
 
 
 
 
 
 
 
133
 
134
  # Set openai logger level to INFO when VERBOSE_DEBUG is off
135
  if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
136
  logging.getLogger("openai").setLevel(logging.INFO)
137
 
138
+ # Extract client configuration options
139
+ client_configs = kwargs.pop("openai_client_configs", {})
140
+
141
+ # Create the OpenAI client
142
+ openai_async_client = create_openai_async_client(
143
+ api_key=api_key, base_url=base_url, client_configs=client_configs
144
  )
145
+
146
+ # Remove special kwargs that shouldn't be passed to OpenAI
147
  kwargs.pop("hashing_kv", None)
148
  kwargs.pop("keyword_extraction", None)
149
+
150
+ # Prepare messages
151
  messages: list[dict[str, Any]] = []
152
  if system_prompt:
153
  messages.append({"role": "system", "content": system_prompt})
 
336
  model: str = "text-embedding-3-small",
337
  base_url: str = None,
338
  api_key: str = None,
339
+ client_configs: dict[str, Any] = None,
340
  ) -> np.ndarray:
341
+ """Generate embeddings for a list of texts using OpenAI's API.
342
+
343
+ Args:
344
+ texts: List of texts to embed.
345
+ model: The OpenAI embedding model to use.
346
+ base_url: Optional base URL for the OpenAI API.
347
+ api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
348
+ client_configs: Additional configuration options for the AsyncOpenAI client.
349
+ These will override any default configurations but will be overridden by
350
+ explicit parameters (api_key, base_url).
351
+
352
+ Returns:
353
+ A numpy array of embeddings, one per input text.
354
+
355
+ Raises:
356
+ APIConnectionError: If there is a connection error with the OpenAI API.
357
+ RateLimitError: If the OpenAI API rate limit is exceeded.
358
+ APITimeoutError: If the OpenAI API request times out.
359
+ """
360
+ # Create the OpenAI client
361
+ openai_async_client = create_openai_async_client(
362
+ api_key=api_key, base_url=base_url, client_configs=client_configs
363
  )
364
+
365
  response = await openai_async_client.embeddings.create(
366
  model=model, input=texts, encoding_format="float"
367
  )
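`create_openai_async_client` centralizes client construction, and both `openai_complete_if_cache` (via the `openai_client_configs` kwarg) and `openai_embed` (via `client_configs`) forward extra options to `AsyncOpenAI`. A hedged usage sketch: the model name and timeout values are illustrative, and `OPENAI_API_KEY` must be set in the environment:

```python
import asyncio
from lightrag.llm.openai import openai_complete_if_cache, openai_embed

async def main() -> None:
    # Extra client options are passed through to AsyncOpenAI(...); explicit
    # api_key/base_url arguments still take precedence over them.
    answer = await openai_complete_if_cache(
        "gpt-4o-mini",  # illustrative model name
        "Summarize the storage drop() contract in one sentence.",
        openai_client_configs={"timeout": 60, "max_retries": 2},
    )
    vectors = await openai_embed(
        ["hello", "world"],
        client_configs={"timeout": 30},
    )
    print(answer, vectors.shape)

# asyncio.run(main())
```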
lightrag/operate.py CHANGED
@@ -26,7 +26,6 @@ from .utils import (
26
  CacheData,
27
  statistic_data,
28
  get_conversation_turns,
29
- verbose_debug,
30
  )
31
  from .base import (
32
  BaseGraphStorage,
@@ -442,6 +441,13 @@ async def extract_entities(
442
 
443
  processed_chunks = 0
444
  total_chunks = len(ordered_chunks)
 
 
 
 
 
 
 
445
 
446
  async def _user_llm_func_with_cache(
447
  input_text: str, history_messages: list[dict[str, str]] = None
@@ -540,7 +546,7 @@ async def extract_entities(
540
  chunk_key_dp (tuple[str, TextChunkSchema]):
541
  ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
542
  """
543
- nonlocal processed_chunks
544
  chunk_key = chunk_key_dp[0]
545
  chunk_dp = chunk_key_dp[1]
546
  content = chunk_dp["content"]
@@ -598,102 +604,74 @@ async def extract_entities(
598
  async with pipeline_status_lock:
599
  pipeline_status["latest_message"] = log_message
600
  pipeline_status["history_messages"].append(log_message)
601
- return dict(maybe_nodes), dict(maybe_edges)
602
 
603
- tasks = [_process_single_content(c) for c in ordered_chunks]
604
- results = await asyncio.gather(*tasks)
 
605
 
606
- maybe_nodes = defaultdict(list)
607
- maybe_edges = defaultdict(list)
608
- for m_nodes, m_edges in results:
609
- for k, v in m_nodes.items():
610
- maybe_nodes[k].extend(v)
611
- for k, v in m_edges.items():
612
- maybe_edges[tuple(sorted(k))].extend(v)
613
-
614
- from .kg.shared_storage import get_graph_db_lock
615
-
616
- graph_db_lock = get_graph_db_lock(enable_logging=False)
617
-
618
- # Ensure that nodes and edges are merged and upserted atomically
619
- async with graph_db_lock:
620
- all_entities_data = await asyncio.gather(
621
- *[
622
- _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
623
- for k, v in maybe_nodes.items()
624
- ]
625
- )
626
-
627
- all_relationships_data = await asyncio.gather(
628
- *[
629
- _merge_edges_then_upsert(
630
- k[0], k[1], v, knowledge_graph_inst, global_config
631
  )
632
- for k, v in maybe_edges.items()
633
- ]
634
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
635
 
636
- if not (all_entities_data or all_relationships_data):
637
- log_message = "Didn't extract any entities and relationships."
638
- logger.info(log_message)
639
- if pipeline_status is not None:
640
- async with pipeline_status_lock:
641
- pipeline_status["latest_message"] = log_message
642
- pipeline_status["history_messages"].append(log_message)
643
- return
644
 
645
- if not all_entities_data:
646
- log_message = "Didn't extract any entities"
647
- logger.info(log_message)
648
- if pipeline_status is not None:
649
- async with pipeline_status_lock:
650
- pipeline_status["latest_message"] = log_message
651
- pipeline_status["history_messages"].append(log_message)
652
- if not all_relationships_data:
653
- log_message = "Didn't extract any relationships"
654
- logger.info(log_message)
655
- if pipeline_status is not None:
656
- async with pipeline_status_lock:
657
- pipeline_status["latest_message"] = log_message
658
- pipeline_status["history_messages"].append(log_message)
659
 
660
- log_message = f"Extracted {len(all_entities_data)} entities + {len(all_relationships_data)} relationships (deduplicated)"
661
  logger.info(log_message)
662
  if pipeline_status is not None:
663
  async with pipeline_status_lock:
664
  pipeline_status["latest_message"] = log_message
665
  pipeline_status["history_messages"].append(log_message)
666
- verbose_debug(
667
- f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
668
- )
669
- verbose_debug(f"New relationships:{all_relationships_data}")
670
-
671
- if entity_vdb is not None:
672
- data_for_vdb = {
673
- compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
674
- "entity_name": dp["entity_name"],
675
- "entity_type": dp["entity_type"],
676
- "content": f"{dp['entity_name']}\n{dp['description']}",
677
- "source_id": dp["source_id"],
678
- "file_path": dp.get("file_path", "unknown_source"),
679
- }
680
- for dp in all_entities_data
681
- }
682
- await entity_vdb.upsert(data_for_vdb)
683
-
684
- if relationships_vdb is not None:
685
- data_for_vdb = {
686
- compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
687
- "src_id": dp["src_id"],
688
- "tgt_id": dp["tgt_id"],
689
- "keywords": dp["keywords"],
690
- "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
691
- "source_id": dp["source_id"],
692
- "file_path": dp.get("file_path", "unknown_source"),
693
- }
694
- for dp in all_relationships_data
695
- }
696
- await relationships_vdb.upsert(data_for_vdb)
697
 
698
 
699
  async def kg_query(
@@ -720,8 +698,7 @@ async def kg_query(
720
  if cached_response is not None:
721
  return cached_response
722
 
723
- # Extract keywords using extract_keywords_only function which already supports conversation history
724
- hl_keywords, ll_keywords = await extract_keywords_only(
725
  query, query_param, global_config, hashing_kv
726
  )
727
 
@@ -817,6 +794,38 @@ async def kg_query(
817
  return response
818
 
819
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
820
  async def extract_keywords_only(
821
  text: str,
822
  param: QueryParam,
@@ -957,8 +966,7 @@ async def mix_kg_vector_query(
957
  # 2. Execute knowledge graph and vector searches in parallel
958
  async def get_kg_context():
959
  try:
960
- # Extract keywords using extract_keywords_only function which already supports conversation history
961
- hl_keywords, ll_keywords = await extract_keywords_only(
962
  query, query_param, global_config, hashing_kv
963
  )
964
 
@@ -1339,7 +1347,9 @@ async def _get_node_data(
1339
 
1340
  text_units_section_list = [["id", "content", "file_path"]]
1341
  for i, t in enumerate(use_text_units):
1342
- text_units_section_list.append([i, t["content"], t["file_path"]])
 
 
1343
  text_units_context = list_of_list_to_csv(text_units_section_list)
1344
  return entities_context, relations_context, text_units_context
1345
 
@@ -2043,16 +2053,13 @@ async def query_with_keywords(
2043
  Query response or async iterator
2044
  """
2045
  # Extract keywords
2046
- hl_keywords, ll_keywords = await extract_keywords_only(
2047
- text=query,
2048
- param=param,
2049
  global_config=global_config,
2050
  hashing_kv=hashing_kv,
2051
  )
2052
 
2053
- param.hl_keywords = hl_keywords
2054
- param.ll_keywords = ll_keywords
2055
-
2056
  # Create a new string with the prompt and the keywords
2057
  ll_keywords_str = ", ".join(ll_keywords)
2058
  hl_keywords_str = ", ".join(hl_keywords)
 
26
  CacheData,
27
  statistic_data,
28
  get_conversation_turns,
 
29
  )
30
  from .base import (
31
  BaseGraphStorage,
 
441
 
442
  processed_chunks = 0
443
  total_chunks = len(ordered_chunks)
444
+ total_entities_count = 0
445
+ total_relations_count = 0
446
+
447
+ # Get lock manager from shared storage
448
+ from .kg.shared_storage import get_graph_db_lock
449
+
450
+ graph_db_lock = get_graph_db_lock(enable_logging=False)
451
 
452
  async def _user_llm_func_with_cache(
453
  input_text: str, history_messages: list[dict[str, str]] = None
 
546
  chunk_key_dp (tuple[str, TextChunkSchema]):
547
  ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
548
  """
549
+ nonlocal processed_chunks, total_entities_count, total_relations_count
550
  chunk_key = chunk_key_dp[0]
551
  chunk_dp = chunk_key_dp[1]
552
  content = chunk_dp["content"]
 
604
  async with pipeline_status_lock:
605
  pipeline_status["latest_message"] = log_message
606
  pipeline_status["history_messages"].append(log_message)
 
607
 
608
+ # Use graph database lock to ensure atomic merges and updates
609
+ chunk_entities_data = []
610
+ chunk_relationships_data = []
611
 
612
+ async with graph_db_lock:
613
+ # Process and update entities
614
+ for entity_name, entities in maybe_nodes.items():
615
+ entity_data = await _merge_nodes_then_upsert(
616
+ entity_name, entities, knowledge_graph_inst, global_config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
617
  )
618
+ chunk_entities_data.append(entity_data)
619
+
620
+ # Process and update relationships
621
+ for edge_key, edges in maybe_edges.items():
622
+ # Ensure edge direction consistency
623
+ sorted_edge_key = tuple(sorted(edge_key))
624
+ edge_data = await _merge_edges_then_upsert(
625
+ sorted_edge_key[0],
626
+ sorted_edge_key[1],
627
+ edges,
628
+ knowledge_graph_inst,
629
+ global_config,
630
+ )
631
+ chunk_relationships_data.append(edge_data)
632
+
633
+ # Update vector database (within the same lock to ensure atomicity)
634
+ if entity_vdb is not None and chunk_entities_data:
635
+ data_for_vdb = {
636
+ compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
637
+ "entity_name": dp["entity_name"],
638
+ "entity_type": dp["entity_type"],
639
+ "content": f"{dp['entity_name']}\n{dp['description']}",
640
+ "source_id": dp["source_id"],
641
+ "file_path": dp.get("file_path", "unknown_source"),
642
+ }
643
+ for dp in chunk_entities_data
644
+ }
645
+ await entity_vdb.upsert(data_for_vdb)
646
+
647
+ if relationships_vdb is not None and chunk_relationships_data:
648
+ data_for_vdb = {
649
+ compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
650
+ "src_id": dp["src_id"],
651
+ "tgt_id": dp["tgt_id"],
652
+ "keywords": dp["keywords"],
653
+ "content": f"{dp['src_id']}\t{dp['tgt_id']}\n{dp['keywords']}\n{dp['description']}",
654
+ "source_id": dp["source_id"],
655
+ "file_path": dp.get("file_path", "unknown_source"),
656
+ }
657
+ for dp in chunk_relationships_data
658
+ }
659
+ await relationships_vdb.upsert(data_for_vdb)
660
 
661
+ # Update counters
662
+ total_entities_count += len(chunk_entities_data)
663
+ total_relations_count += len(chunk_relationships_data)
 
 
 
 
 
664
 
665
+ # Handle all chunks in parallel
666
+ tasks = [_process_single_content(c) for c in ordered_chunks]
667
+ await asyncio.gather(*tasks)
 
 
 
 
 
 
 
 
 
 
 
668
 
669
+ log_message = f"Extracted {total_entities_count} entities + {total_relations_count} relationships (total)"
670
  logger.info(log_message)
671
  if pipeline_status is not None:
672
  async with pipeline_status_lock:
673
  pipeline_status["latest_message"] = log_message
674
  pipeline_status["history_messages"].append(log_message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
675
 
676
 
677
  async def kg_query(
 
698
  if cached_response is not None:
699
  return cached_response
700
 
701
+ hl_keywords, ll_keywords = await get_keywords_from_query(
 
702
  query, query_param, global_config, hashing_kv
703
  )
704
 
 
794
  return response
795
 
796
 
797
+ async def get_keywords_from_query(
798
+ query: str,
799
+ query_param: QueryParam,
800
+ global_config: dict[str, str],
801
+ hashing_kv: BaseKVStorage | None = None,
802
+ ) -> tuple[list[str], list[str]]:
803
+ """
804
+ Retrieves high-level and low-level keywords for RAG operations.
805
+
806
+ This function checks if keywords are already provided in query parameters,
807
+ and if not, extracts them from the query text using LLM.
808
+
809
+ Args:
810
+ query: The user's query text
811
+ query_param: Query parameters that may contain pre-defined keywords
812
+ global_config: Global configuration dictionary
813
+ hashing_kv: Optional key-value storage for caching results
814
+
815
+ Returns:
816
+ A tuple containing (high_level_keywords, low_level_keywords)
817
+ """
818
+ # Check if pre-defined keywords are already provided
819
+ if query_param.hl_keywords or query_param.ll_keywords:
820
+ return query_param.hl_keywords, query_param.ll_keywords
821
+
822
+ # Extract keywords using extract_keywords_only function which already supports conversation history
823
+ hl_keywords, ll_keywords = await extract_keywords_only(
824
+ query, query_param, global_config, hashing_kv
825
+ )
826
+ return hl_keywords, ll_keywords
827
+
828
+
829
  async def extract_keywords_only(
830
  text: str,
831
  param: QueryParam,
 
966
  # 2. Execute knowledge graph and vector searches in parallel
967
  async def get_kg_context():
968
  try:
969
+ hl_keywords, ll_keywords = await get_keywords_from_query(
 
970
  query, query_param, global_config, hashing_kv
971
  )
972
 
 
1347
 
1348
  text_units_section_list = [["id", "content", "file_path"]]
1349
  for i, t in enumerate(use_text_units):
1350
+ text_units_section_list.append(
1351
+ [i, t["content"], t.get("file_path", "unknown_source")]
1352
+ )
1353
  text_units_context = list_of_list_to_csv(text_units_section_list)
1354
  return entities_context, relations_context, text_units_context
1355
 
 
2053
  Query response or async iterator
2054
  """
2055
  # Extract keywords
2056
+ hl_keywords, ll_keywords = await get_keywords_from_query(
2057
+ query=query,
2058
+ query_param=param,
2059
  global_config=global_config,
2060
  hashing_kv=hashing_kv,
2061
  )
2062
 
 
 
 
2063
  # Create a new string with the prompt and the keywords
2064
  ll_keywords_str = ", ".join(ll_keywords)
2065
  hl_keywords_str = ", ".join(hl_keywords)
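`get_keywords_from_query` adds a short-circuit: when the caller already supplies keywords on `QueryParam`, the LLM extraction step (`extract_keywords_only`) is skipped. A hedged sketch of how a caller can rely on that; the import path follows the project README and the keyword values are illustrative:

```python
from lightrag import QueryParam

param = QueryParam(mode="hybrid")
# Pre-seeding these fields makes kg_query / mix_kg_vector_query / query_with_keywords
# reuse them directly instead of calling the LLM for keyword extraction.
param.hl_keywords = ["storage backends", "cache invalidation"]
param.ll_keywords = ["Redis", "Qdrant", "TiDB"]
```

One practical consequence: keywords left on a reused `QueryParam` will suppress extraction for later queries, so shared parameter objects should be reset between unrelated requests.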
lightrag/types.py CHANGED
@@ -26,3 +26,4 @@ class KnowledgeGraphEdge(BaseModel):
26
  class KnowledgeGraph(BaseModel):
27
  nodes: list[KnowledgeGraphNode] = []
28
  edges: list[KnowledgeGraphEdge] = []
 
 
26
  class KnowledgeGraph(BaseModel):
27
  nodes: list[KnowledgeGraphNode] = []
28
  edges: list[KnowledgeGraphEdge] = []
29
+ is_truncated: bool = False
lightrag_webui/src/App.tsx CHANGED
@@ -3,12 +3,13 @@ import ThemeProvider from '@/components/ThemeProvider'
3
  import TabVisibilityProvider from '@/contexts/TabVisibilityProvider'
4
  import ApiKeyAlert from '@/components/ApiKeyAlert'
5
  import StatusIndicator from '@/components/status/StatusIndicator'
6
- import { healthCheckInterval } from '@/lib/constants'
7
  import { useBackendState, useAuthStore } from '@/stores/state'
8
  import { useSettingsStore } from '@/stores/settings'
9
  import { getAuthStatus } from '@/api/lightrag'
10
  import SiteHeader from '@/features/SiteHeader'
11
  import { InvalidApiKeyError, RequireApiKeError } from '@/api/lightrag'
 
12
 
13
  import GraphViewer from '@/features/GraphViewer'
14
  import DocumentManager from '@/features/DocumentManager'
@@ -22,6 +23,7 @@ function App() {
22
  const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
23
  const currentTab = useSettingsStore.use.currentTab()
24
  const [apiKeyAlertOpen, setApiKeyAlertOpen] = useState(false)
 
25
  const versionCheckRef = useRef(false); // Prevent duplicate calls in Vite dev mode
26
 
27
  const handleApiKeyAlertOpenChange = useCallback((open: boolean) => {
@@ -55,29 +57,48 @@ function App() {
55
 
56
  // Check if version info was already obtained in login page
57
  const versionCheckedFromLogin = sessionStorage.getItem('VERSION_CHECKED_FROM_LOGIN') === 'true';
58
- if (versionCheckedFromLogin) return;
59
-
60
- // Get version info
61
- const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
62
- if (!token) return;
63
 
64
  try {
 
 
 
 
65
  const status = await getAuthStatus();
66
- if (status.core_version || status.api_version) {
 
 
 
 
 
 
 
 
 
 
 
 
67
  const isGuestMode = status.auth_mode === 'disabled' || useAuthStore.getState().isGuestMode;
68
- // Update version info while maintaining login state
69
  useAuthStore.getState().login(
70
  token,
71
  isGuestMode,
72
  status.core_version,
73
- status.api_version
 
 
74
  );
75
-
76
- // Set flag to indicate version info has been checked
77
- sessionStorage.setItem('VERSION_CHECKED_FROM_LOGIN', 'true');
78
  }
 
 
 
79
  } catch (error) {
80
  console.error('Failed to get version info:', error);
 
 
 
81
  }
82
  };
83
 
@@ -101,31 +122,63 @@ function App() {
101
  return (
102
  <ThemeProvider>
103
  <TabVisibilityProvider>
104
- <main className="flex h-screen w-screen overflow-hidden">
105
- <Tabs
106
- defaultValue={currentTab}
107
- className="!m-0 flex grow flex-col !p-0 overflow-hidden"
108
- onValueChange={handleTabChange}
109
- >
110
- <SiteHeader />
111
- <div className="relative grow">
112
- <TabsContent value="documents" className="absolute top-0 right-0 bottom-0 left-0 overflow-auto">
113
- <DocumentManager />
114
- </TabsContent>
115
- <TabsContent value="knowledge-graph" className="absolute top-0 right-0 bottom-0 left-0 overflow-hidden">
116
- <GraphViewer />
117
- </TabsContent>
118
- <TabsContent value="retrieval" className="absolute top-0 right-0 bottom-0 left-0 overflow-hidden">
119
- <RetrievalTesting />
120
- </TabsContent>
121
- <TabsContent value="api" className="absolute top-0 right-0 bottom-0 left-0 overflow-hidden">
122
- <ApiSite />
123
- </TabsContent>
 
 
 
 
 
 
 
124
  </div>
125
- </Tabs>
126
- {enableHealthCheck && <StatusIndicator />}
127
- <ApiKeyAlert open={apiKeyAlertOpen} onOpenChange={handleApiKeyAlertOpenChange} />
128
- </main>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  </TabVisibilityProvider>
130
  </ThemeProvider>
131
  )
 
3
  import TabVisibilityProvider from '@/contexts/TabVisibilityProvider'
4
  import ApiKeyAlert from '@/components/ApiKeyAlert'
5
  import StatusIndicator from '@/components/status/StatusIndicator'
6
+ import { healthCheckInterval, SiteInfo, webuiPrefix } from '@/lib/constants'
7
  import { useBackendState, useAuthStore } from '@/stores/state'
8
  import { useSettingsStore } from '@/stores/settings'
9
  import { getAuthStatus } from '@/api/lightrag'
10
  import SiteHeader from '@/features/SiteHeader'
11
  import { InvalidApiKeyError, RequireApiKeError } from '@/api/lightrag'
12
+ import { ZapIcon } from 'lucide-react'
13
 
14
  import GraphViewer from '@/features/GraphViewer'
15
  import DocumentManager from '@/features/DocumentManager'
 
23
  const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
24
  const currentTab = useSettingsStore.use.currentTab()
25
  const [apiKeyAlertOpen, setApiKeyAlertOpen] = useState(false)
26
+ const [initializing, setInitializing] = useState(true) // Add initializing state
27
  const versionCheckRef = useRef(false); // Prevent duplicate calls in Vite dev mode
28
 
29
  const handleApiKeyAlertOpenChange = useCallback((open: boolean) => {
 
57
 
58
  // Check if version info was already obtained in login page
59
  const versionCheckedFromLogin = sessionStorage.getItem('VERSION_CHECKED_FROM_LOGIN') === 'true';
60
+ if (versionCheckedFromLogin) {
61
+ setInitializing(false); // Skip initialization if already checked
62
+ return;
63
+ }
 
64
 
65
  try {
66
+ setInitializing(true); // Start initialization
67
+
68
+ // Get version info
69
+ const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
70
  const status = await getAuthStatus();
71
+
72
+ // If auth is not configured and a new token is returned, use the new token
73
+ if (!status.auth_configured && status.access_token) {
74
+ useAuthStore.getState().login(
75
+ status.access_token, // Use the new token
76
+ true, // Guest mode
77
+ status.core_version,
78
+ status.api_version,
79
+ status.webui_title || null,
80
+ status.webui_description || null
81
+ );
82
+ } else if (token && (status.core_version || status.api_version || status.webui_title || status.webui_description)) {
83
+ // Otherwise use the old token (if it exists)
84
  const isGuestMode = status.auth_mode === 'disabled' || useAuthStore.getState().isGuestMode;
 
85
  useAuthStore.getState().login(
86
  token,
87
  isGuestMode,
88
  status.core_version,
89
+ status.api_version,
90
+ status.webui_title || null,
91
+ status.webui_description || null
92
  );
 
 
 
93
  }
94
+
95
+ // Set flag to indicate version info has been checked
96
+ sessionStorage.setItem('VERSION_CHECKED_FROM_LOGIN', 'true');
97
  } catch (error) {
98
  console.error('Failed to get version info:', error);
99
+ } finally {
100
+ // Ensure initializing is set to false even if there's an error
101
+ setInitializing(false);
102
  }
103
  };
104
 
 
122
  return (
123
  <ThemeProvider>
124
  <TabVisibilityProvider>
125
+ {initializing ? (
126
+ // Loading state while initializing with simplified header
127
+ <div className="flex h-screen w-screen flex-col">
128
+ {/* Simplified header during initialization - matches SiteHeader structure */}
129
+ <header className="border-border/40 bg-background/95 supports-[backdrop-filter]:bg-background/60 sticky top-0 z-50 flex h-10 w-full border-b px-4 backdrop-blur">
130
+ <div className="min-w-[200px] w-auto flex items-center">
131
+ <a href={webuiPrefix} className="flex items-center gap-2">
132
+ <ZapIcon className="size-4 text-emerald-400" aria-hidden="true" />
133
+ <span className="font-bold md:inline-block">{SiteInfo.name}</span>
134
+ </a>
135
+ </div>
136
+
137
+ {/* Empty middle section to maintain layout */}
138
+ <div className="flex h-10 flex-1 items-center justify-center">
139
+ </div>
140
+
141
+ {/* Empty right section to maintain layout */}
142
+ <nav className="w-[200px] flex items-center justify-end">
143
+ </nav>
144
+ </header>
145
+
146
+ {/* Loading indicator in content area */}
147
+ <div className="flex flex-1 items-center justify-center">
148
+ <div className="text-center">
149
+ <div className="mb-2 h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent"></div>
150
+ <p>Initializing...</p>
151
+ </div>
152
  </div>
153
+ </div>
154
+ ) : (
155
+ // Main content after initialization
156
+ <main className="flex h-screen w-screen overflow-hidden">
157
+ <Tabs
158
+ defaultValue={currentTab}
159
+ className="!m-0 flex grow flex-col !p-0 overflow-hidden"
160
+ onValueChange={handleTabChange}
161
+ >
162
+ <SiteHeader />
163
+ <div className="relative grow">
164
+ <TabsContent value="documents" className="absolute top-0 right-0 bottom-0 left-0 overflow-auto">
165
+ <DocumentManager />
166
+ </TabsContent>
167
+ <TabsContent value="knowledge-graph" className="absolute top-0 right-0 bottom-0 left-0 overflow-hidden">
168
+ <GraphViewer />
169
+ </TabsContent>
170
+ <TabsContent value="retrieval" className="absolute top-0 right-0 bottom-0 left-0 overflow-hidden">
171
+ <RetrievalTesting />
172
+ </TabsContent>
173
+ <TabsContent value="api" className="absolute top-0 right-0 bottom-0 left-0 overflow-hidden">
174
+ <ApiSite />
175
+ </TabsContent>
176
+ </div>
177
+ </Tabs>
178
+ {enableHealthCheck && <StatusIndicator />}
179
+ <ApiKeyAlert open={apiKeyAlertOpen} onOpenChange={handleApiKeyAlertOpenChange} />
180
+ </main>
181
+ )}
182
  </TabVisibilityProvider>
183
  </ThemeProvider>
184
  )
lightrag_webui/src/AppRouter.tsx CHANGED
@@ -80,7 +80,12 @@ const AppRouter = () => {
80
  <ThemeProvider>
81
  <Router>
82
  <AppContent />
83
- <Toaster position="bottom-center" />
 
 
 
 
 
84
  </Router>
85
  </ThemeProvider>
86
  )
 
80
  <ThemeProvider>
81
  <Router>
82
  <AppContent />
83
+ <Toaster
84
+ position="bottom-center"
85
+ theme="system"
86
+ closeButton
87
+ richColors
88
+ />
89
  </Router>
90
  </ThemeProvider>
91
  )
lightrag_webui/src/api/lightrag.ts CHANGED
@@ -46,6 +46,8 @@ export type LightragStatus = {
46
  api_version?: string
47
  auth_mode?: 'enabled' | 'disabled'
48
  pipeline_busy: boolean
 
 
49
  }
50
 
51
  export type LightragDocumentsScanProgress = {
@@ -140,6 +142,8 @@ export type AuthStatusResponse = {
140
  message?: string
141
  core_version?: string
142
  api_version?: string
 
 
143
  }
144
 
145
  export type PipelineStatusResponse = {
@@ -163,6 +167,8 @@ export type LoginResponse = {
163
  message?: string // Optional message
164
  core_version?: string
165
  api_version?: string
 
 
166
  }
167
 
168
  export const InvalidApiKeyError = 'Invalid API Key'
@@ -221,9 +227,9 @@ axiosInstance.interceptors.response.use(
221
  export const queryGraphs = async (
222
  label: string,
223
  maxDepth: number,
224
- minDegree: number
225
  ): Promise<LightragGraphType> => {
226
- const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&min_degree=${minDegree}`)
227
  return response.data
228
  }
229
 
@@ -382,6 +388,14 @@ export const clearDocuments = async (): Promise<DocActionResponse> => {
382
  return response.data
383
  }
384
 
 
 
 
 
 
 
 
 
385
  export const getAuthStatus = async (): Promise<AuthStatusResponse> => {
386
  try {
387
  // Add a timeout to the request to prevent hanging
 
46
  api_version?: string
47
  auth_mode?: 'enabled' | 'disabled'
48
  pipeline_busy: boolean
49
+ webui_title?: string
50
+ webui_description?: string
51
  }
52
 
53
  export type LightragDocumentsScanProgress = {
 
142
  message?: string
143
  core_version?: string
144
  api_version?: string
145
+ webui_title?: string
146
+ webui_description?: string
147
  }
148
 
149
  export type PipelineStatusResponse = {
 
167
  message?: string // Optional message
168
  core_version?: string
169
  api_version?: string
170
+ webui_title?: string
171
+ webui_description?: string
172
  }
173
 
174
  export const InvalidApiKeyError = 'Invalid API Key'
 
227
  export const queryGraphs = async (
228
  label: string,
229
  maxDepth: number,
230
+ maxNodes: number
231
  ): Promise<LightragGraphType> => {
232
+ const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&max_nodes=${maxNodes}`)
233
  return response.data
234
  }
235
 
 
388
  return response.data
389
  }
390
 
391
+ export const clearCache = async (modes?: string[]): Promise<{
392
+ status: 'success' | 'fail'
393
+ message: string
394
+ }> => {
395
+ const response = await axiosInstance.post('/documents/clear_cache', { modes })
396
+ return response.data
397
+ }
398
+
399
  export const getAuthStatus = async (): Promise<AuthStatusResponse> => {
400
  try {
401
  // Add a timeout to the request to prevent hanging
lightrag_webui/src/components/documents/ClearDocumentsDialog.tsx CHANGED
@@ -1,4 +1,4 @@
-import { useState, useCallback } from 'react'
+import { useState, useCallback, useEffect } from 'react'
 import Button from '@/components/ui/Button'
 import {
   Dialog,
@@ -6,32 +6,88 @@ import {
   DialogDescription,
   DialogHeader,
   DialogTitle,
-  DialogTrigger
+  DialogTrigger,
+  DialogFooter
 } from '@/components/ui/Dialog'
+import Input from '@/components/ui/Input'
+import Checkbox from '@/components/ui/Checkbox'
 import { toast } from 'sonner'
 import { errorMessage } from '@/lib/utils'
-import { clearDocuments } from '@/api/lightrag'
+import { clearDocuments, clearCache } from '@/api/lightrag'

-import { EraserIcon } from 'lucide-react'
+import { EraserIcon, AlertTriangleIcon } from 'lucide-react'
 import { useTranslation } from 'react-i18next'

-export default function ClearDocumentsDialog() {
+// Simple Label component
+const Label = ({
+  htmlFor,
+  className,
+  children,
+  ...props
+}: React.LabelHTMLAttributes<HTMLLabelElement>) => (
+  <label
+    htmlFor={htmlFor}
+    className={className}
+    {...props}
+  >
+    {children}
+  </label>
+)
+
+interface ClearDocumentsDialogProps {
+  onDocumentsCleared?: () => Promise<void>
+}
+
+export default function ClearDocumentsDialog({ onDocumentsCleared }: ClearDocumentsDialogProps) {
   const { t } = useTranslation()
   const [open, setOpen] = useState(false)
+  const [confirmText, setConfirmText] = useState('')
+  const [clearCacheOption, setClearCacheOption] = useState(false)
+  const isConfirmEnabled = confirmText.toLowerCase() === 'yes'
+
+  // Reset state when the dialog closes
+  useEffect(() => {
+    if (!open) {
+      setConfirmText('')
+      setClearCacheOption(false)
+    }
+  }, [open])

   const handleClear = useCallback(async () => {
+    if (!isConfirmEnabled) return
+
     try {
       const result = await clearDocuments()
-      if (result.status === 'success') {
-        toast.success(t('documentPanel.clearDocuments.success'))
-        setOpen(false)
-      } else {
+
+      if (result.status !== 'success') {
         toast.error(t('documentPanel.clearDocuments.failed', { message: result.message }))
+        setConfirmText('')
+        return
+      }
+
+      toast.success(t('documentPanel.clearDocuments.success'))
+
+      if (clearCacheOption) {
+        try {
+          await clearCache()
+          toast.success(t('documentPanel.clearDocuments.cacheCleared'))
+        } catch (cacheErr) {
+          toast.error(t('documentPanel.clearDocuments.cacheClearFailed', { error: errorMessage(cacheErr) }))
+        }
+      }
+
+      // Refresh document list if provided
+      if (onDocumentsCleared) {
+        onDocumentsCleared().catch(console.error)
       }
+
+      // Close the dialog after all operations succeed
+      setOpen(false)
     } catch (err) {
       toast.error(t('documentPanel.clearDocuments.error', { error: errorMessage(err) }))
+      setConfirmText('')
     }
-  }, [setOpen, t])
+  }, [isConfirmEnabled, clearCacheOption, setOpen, t, onDocumentsCleared])

   return (
     <Dialog open={open} onOpenChange={setOpen}>
@@ -42,12 +98,60 @@ export default function ClearDocumentsDialog() {
       </DialogTrigger>
       <DialogContent className="sm:max-w-xl" onCloseAutoFocus={(e) => e.preventDefault()}>
         <DialogHeader>
-          <DialogTitle>{t('documentPanel.clearDocuments.title')}</DialogTitle>
-          <DialogDescription>{t('documentPanel.clearDocuments.confirm')}</DialogDescription>
+          <DialogTitle className="flex items-center gap-2 text-red-500 dark:text-red-400 font-bold">
+            <AlertTriangleIcon className="h-5 w-5" />
+            {t('documentPanel.clearDocuments.title')}
+          </DialogTitle>
+          <DialogDescription className="pt-2">
+            {t('documentPanel.clearDocuments.description')}
+          </DialogDescription>
         </DialogHeader>
-        <Button variant="destructive" onClick={handleClear}>
-          {t('documentPanel.clearDocuments.confirmButton')}
-        </Button>
+
+        <div className="text-red-500 dark:text-red-400 font-semibold mb-4">
+          {t('documentPanel.clearDocuments.warning')}
+        </div>
+        <div className="mb-4">
+          {t('documentPanel.clearDocuments.confirm')}
+        </div>
+
+        <div className="space-y-4">
+          <div className="space-y-2">
+            <Label htmlFor="confirm-text" className="text-sm font-medium">
+              {t('documentPanel.clearDocuments.confirmPrompt')}
+            </Label>
+            <Input
+              id="confirm-text"
+              value={confirmText}
+              onChange={(e: React.ChangeEvent<HTMLInputElement>) => setConfirmText(e.target.value)}
+              placeholder={t('documentPanel.clearDocuments.confirmPlaceholder')}
+              className="w-full"
+            />
+          </div>
+
+          <div className="flex items-center space-x-2">
+            <Checkbox
+              id="clear-cache"
+              checked={clearCacheOption}
+              onCheckedChange={(checked: boolean | 'indeterminate') => setClearCacheOption(checked === true)}
+            />
+            <Label htmlFor="clear-cache" className="text-sm font-medium cursor-pointer">
+              {t('documentPanel.clearDocuments.clearCache')}
+            </Label>
+          </div>
+        </div>
+
+        <DialogFooter>
+          <Button variant="outline" onClick={() => setOpen(false)}>
+            {t('common.cancel')}
+          </Button>
+          <Button
+            variant="destructive"
+            onClick={handleClear}
+            disabled={!isConfirmEnabled}
+          >
+            {t('documentPanel.clearDocuments.confirmButton')}
+          </Button>
+        </DialogFooter>
       </DialogContent>
     </Dialog>
   )
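
Purely for illustration, a sketch of how the revised dialog might be mounted now that it accepts an `onDocumentsCleared` callback; the `DocumentToolbar` component and its `reloadDocuments` prop are assumptions, not part of this diff.

```tsx
// Hypothetical parent component; only the ClearDocumentsDialog import comes from this repo.
import ClearDocumentsDialog from '@/components/documents/ClearDocumentsDialog'

export default function DocumentToolbar({ reloadDocuments }: { reloadDocuments: () => Promise<void> }) {
  return (
    <div className="flex items-center gap-2">
      {/* The dialog awaits this callback after a successful clear so the document list refreshes. */}
      <ClearDocumentsDialog onDocumentsCleared={reloadDocuments} />
    </div>
  )
}
```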