Merge branch 'main' into add-env-settings
- .dockerignore +63 -1
- .gitignore +19 -15
- README.md +4 -4
- examples/lightrag_oracle_demo.py +18 -1
- examples/test_chromadb.py +49 -23
- external_bindings/OpenWebuiTool/openwebui_tool.py +0 -358
- lightrag/api/README.md +2 -1
- lightrag/base.py +52 -23
- lightrag/exceptions.py +2 -0
- lightrag/kg/chroma_impl.py +41 -27
- lightrag/kg/faiss_impl.py +1 -1
- lightrag/kg/json_kv_impl.py +5 -0
- lightrag/kg/mongo_impl.py +451 -59
- lightrag/kg/nano_vector_db_impl.py +1 -1
- lightrag/kg/neo4j_impl.py +98 -34
- lightrag/lightrag.py +140 -178
- lightrag/llm.py +10 -4
- lightrag/namespace.py +2 -0
- lightrag/operate.py +74 -32
- lightrag/prompt.py +2 -0
- lightrag/types.py +11 -9
- lightrag/utils.py +30 -19
- lightrag_webui/src/components/PropertiesView.tsx +1 -1
- lightrag_webui/src/hooks/useLightragGraph.tsx +10 -2
- lightrag_webui/src/stores/graph.ts +1 -1
.dockerignore
CHANGED
@@ -1 +1,63 @@
+# Python-related files and directories
+__pycache__
+.cache
+
+# Virtual environment directories
+*.venv
+
+# Env
+env/
+*.env*
+.env_example
+
+# Distribution / build files
+site
+dist/
+build/
+.eggs/
+*.egg-info/
+*.tgz
+*.tar.gz
+
+# Exclude siles and folders
+*.yml
+.dockerignore
+Dockerfile
+Makefile
+
+# Exclude other projects
+/tests
+/scripts
+
+# Python version manager file
+.python-version
+
+# Reports
+*.coverage/
+*.log
+log/
+*.logfire
+
+# Cache
+.cache/
+.mypy_cache
+.pytest_cache
+.ruff_cache
+.gradio
+.logfire
+temp/
+
+# MacOS-related files
+.DS_Store
+
+# VS Code settings (local configuration files)
+.vscode
+
+# file
+TODO.md
+
+# Exclude Git-related files
+.git
+.github
+.gitignore
+.pre-commit-config.yaml
.gitignore
CHANGED
@@ -35,23 +35,27 @@ temp/
 
 # IDE / Editor Files
 .idea/
+.vscode/
+.vscode/settings.json
+
+# Framework-specific files
 local_neo4jWorkDir/
 neo4jWorkDir/
-gui/
-*.log
-.vscode
-inputs
-rag_storage
-.env
-venv/
+
+# Data & Storage
+inputs/
+rag_storage/
 examples/input/
 examples/output/
+
+# Miscellaneous
 .DS_Store
+TODO.md
+ignore_this.txt
+*.ignore.*
+
+# Project-specific files
+dickens/
+book.txt
+lightrag-dev/
+gui/
README.md
CHANGED
@@ -237,7 +237,7 @@ rag = LightRAG(
 
 * If you want to use Hugging Face models, you only need to set LightRAG as follows:
 ```python
-from lightrag.llm import hf_model_complete,
+from lightrag.llm import hf_model_complete, hf_embed
 from transformers import AutoModel, AutoTokenizer
 from lightrag.utils import EmbeddingFunc
 
@@ -250,7 +250,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=384,
         max_token_size=5000,
-        func=lambda texts:
+        func=lambda texts: hf_embed(
            texts,
            tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
            embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
@@ -428,9 +428,9 @@ And using a routine to process news documents.
 
 ```python
 rag = LightRAG(..)
-await rag.apipeline_enqueue_documents(
+await rag.apipeline_enqueue_documents(input)
 # Your routine in loop
-await rag.apipeline_process_enqueue_documents(
+await rag.apipeline_process_enqueue_documents(input)
 ```
 
 ### Separate Keyword Extraction
examples/lightrag_oracle_demo.py
CHANGED
@@ -113,7 +113,24 @@ async def main():
     )
 
     # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
+
+    for storage in [
+        rag.vector_db_storage_cls,
+        rag.graph_storage_cls,
+        rag.doc_status,
+        rag.full_docs,
+        rag.text_chunks,
+        rag.llm_response_cache,
+        rag.key_string_value_json_storage_cls,
+        rag.chunks_vdb,
+        rag.relationships_vdb,
+        rag.entities_vdb,
+        rag.graph_storage_cls,
+        rag.chunk_entity_relation_graph,
+        rag.llm_response_cache,
+    ]:
+        # set client
+        storage.db = oracle_db
 
     # Extract and Insert into LightRAG storage
     with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
examples/test_chromadb.py
CHANGED
@@ -15,6 +15,12 @@ if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
 # ChromaDB Configuration
+CHROMADB_USE_LOCAL_PERSISTENT = False
+# Local PersistentClient Configuration
+CHROMADB_LOCAL_PATH = os.environ.get(
+    "CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data")
+)
+# Remote HttpClient Configuration
 CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
 CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
 CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
@@ -60,30 +66,50 @@ async def create_embedding_function_instance():
 
 async def initialize_rag():
     embedding_func_instance = await create_embedding_function_instance()
-                    "hnsw:M": 16,
-                    "hnsw:batch_size": 100,
-                    "hnsw:sync_threshold": 1000,
+    if CHROMADB_USE_LOCAL_PERSISTENT:
+        return LightRAG(
+            working_dir=WORKING_DIR,
+            llm_model_func=gpt_4o_mini_complete,
+            embedding_func=embedding_func_instance,
+            vector_storage="ChromaVectorDBStorage",
+            log_level="DEBUG",
+            embedding_batch_num=32,
+            vector_db_storage_cls_kwargs={
+                "local_path": CHROMADB_LOCAL_PATH,
+                "collection_settings": {
+                    "hnsw:space": "cosine",
+                    "hnsw:construction_ef": 128,
+                    "hnsw:search_ef": 128,
+                    "hnsw:M": 16,
+                    "hnsw:batch_size": 100,
+                    "hnsw:sync_threshold": 1000,
+                },
             },
+        )
+    else:
+        return LightRAG(
+            working_dir=WORKING_DIR,
+            llm_model_func=gpt_4o_mini_complete,
+            embedding_func=embedding_func_instance,
+            vector_storage="ChromaVectorDBStorage",
+            log_level="DEBUG",
+            embedding_batch_num=32,
+            vector_db_storage_cls_kwargs={
+                "host": CHROMADB_HOST,
+                "port": CHROMADB_PORT,
+                "auth_token": CHROMADB_AUTH_TOKEN,
+                "auth_provider": CHROMADB_AUTH_PROVIDER,
+                "auth_header_name": CHROMADB_AUTH_HEADER,
+                "collection_settings": {
+                    "hnsw:space": "cosine",
+                    "hnsw:construction_ef": 128,
+                    "hnsw:search_ef": 128,
+                    "hnsw:M": 16,
+                    "hnsw:batch_size": 100,
+                    "hnsw:sync_threshold": 1000,
+                },
+            },
+        )
 
 
 # Run the initialization
external_bindings/OpenWebuiTool/openwebui_tool.py
DELETED
@@ -1,358 +0,0 @@
-"""
-OpenWebui Lightrag Integration Tool
-==================================
-
-This tool enables the integration and use of Lightrag within the OpenWebui environment,
-providing a seamless interface for RAG (Retrieval-Augmented Generation) operations.
-
-Author: ParisNeo (parisneoai@gmail.com)
-Social:
-    - Twitter: @ParisNeo_AI
-    - Reddit: r/lollms
-    - Instagram: https://www.instagram.com/parisneo_ai/
-
-License: Apache 2.0
-Copyright (c) 2024-2025 ParisNeo
-
-This tool is part of the LoLLMs project (Lord of Large Language and Multimodal Systems).
-For more information, visit: https://github.com/ParisNeo/lollms
-
-Requirements:
-    - Python 3.8+
-    - OpenWebui
-    - Lightrag
-"""
-
-# Tool version
-__version__ = "1.0.0"
-__author__ = "ParisNeo"
-__author_email__ = "parisneoai@gmail.com"
-__description__ = "Lightrag integration for OpenWebui"
-
-
-import requests
-import json
-from pydantic import BaseModel, Field
-from typing import Callable, Any, Literal, Union, List, Tuple
-
-
-class StatusEventEmitter:
-    def __init__(self, event_emitter: Callable[[dict], Any] = None):
-        self.event_emitter = event_emitter
-
-    async def emit(self, description="Unknown State", status="in_progress", done=False):
-        if self.event_emitter:
-            await self.event_emitter(
-                {
-                    "type": "status",
-                    "data": {
-                        "status": status,
-                        "description": description,
-                        "done": done,
-                    },
-                }
-            )
-
-
-class MessageEventEmitter:
-    def __init__(self, event_emitter: Callable[[dict], Any] = None):
-        self.event_emitter = event_emitter
-
-    async def emit(self, content="Some message"):
-        if self.event_emitter:
-            await self.event_emitter(
-                {
-                    "type": "message",
-                    "data": {
-                        "content": content,
-                    },
-                }
-            )
-
-
-class Tools:
-    class Valves(BaseModel):
-        LIGHTRAG_SERVER_URL: str = Field(
-            default="http://localhost:9621/query",
-            description="The base URL for the LightRag server",
-        )
-        MODE: Literal["naive", "local", "global", "hybrid"] = Field(
-            default="hybrid",
-            description="The mode to use for the LightRag query. Options: naive, local, global, hybrid",
-        )
-        ONLY_NEED_CONTEXT: bool = Field(
-            default=False,
-            description="If True, only the context is needed from the LightRag response",
-        )
-        DEBUG_MODE: bool = Field(
-            default=False,
-            description="If True, debugging information will be emitted",
-        )
-        KEY: str = Field(
-            default="",
-            description="Optional Bearer Key for authentication",
-        )
-        MAX_ENTITIES: int = Field(
-            default=5,
-            description="Maximum number of entities to keep",
-        )
-        MAX_RELATIONSHIPS: int = Field(
-            default=5,
-            description="Maximum number of relationships to keep",
-        )
-        MAX_SOURCES: int = Field(
-            default=3,
-            description="Maximum number of sources to keep",
-        )
-
-    def __init__(self):
-        self.valves = self.Valves()
-        self.headers = {
-            "Content-Type": "application/json",
-            "User-Agent": "LightRag-Tool/1.0",
-        }
-
-    async def query_lightrag(
-        self,
-        query: str,
-        __event_emitter__: Callable[[dict], Any] = None,
-    ) -> str:
-        """
-        Query the LightRag server and retrieve information.
-        This function must be called before answering the user question
-        :params query: The query string to send to the LightRag server.
-        :return: The response from the LightRag server in Markdown format or raw response.
-        """
-        self.status_emitter = StatusEventEmitter(__event_emitter__)
-        self.message_emitter = MessageEventEmitter(__event_emitter__)
-
-        lightrag_url = self.valves.LIGHTRAG_SERVER_URL
-        payload = {
-            "query": query,
-            "mode": str(self.valves.MODE),
-            "stream": False,
-            "only_need_context": self.valves.ONLY_NEED_CONTEXT,
-        }
-        await self.status_emitter.emit("Initializing Lightrag query..")
-
-        if self.valves.DEBUG_MODE:
-            await self.message_emitter.emit(
-                "### Debug Mode Active\n\nDebugging information will be displayed.\n"
-            )
-            await self.message_emitter.emit(
-                "#### Payload Sent to LightRag Server\n```json\n"
-                + json.dumps(payload, indent=4)
-                + "\n```\n"
-            )
-
-        # Add Bearer Key to headers if provided
-        if self.valves.KEY:
-            self.headers["Authorization"] = f"Bearer {self.valves.KEY}"
-
-        try:
-            await self.status_emitter.emit("Sending request to LightRag server")
-
-            response = requests.post(
-                lightrag_url, json=payload, headers=self.headers, timeout=120
-            )
-            response.raise_for_status()
-            data = response.json()
-            await self.status_emitter.emit(
-                status="complete",
-                description="LightRag query Succeeded",
-                done=True,
-            )
-
-            # Return parsed Markdown if ONLY_NEED_CONTEXT is True, otherwise return raw response
-            if self.valves.ONLY_NEED_CONTEXT:
-                try:
-                    if self.valves.DEBUG_MODE:
-                        await self.message_emitter.emit(
-                            "#### LightRag Server Response\n```json\n"
-                            + data["response"]
-                            + "\n```\n"
-                        )
-                except Exception as ex:
-                    if self.valves.DEBUG_MODE:
-                        await self.message_emitter.emit(
-                            "#### Exception\n" + str(ex) + "\n"
-                        )
-                    return f"Exception: {ex}"
-                return data["response"]
-            else:
-                if self.valves.DEBUG_MODE:
-                    await self.message_emitter.emit(
-                        "#### LightRag Server Response\n```json\n"
-                        + data["response"]
-                        + "\n```\n"
-                    )
-                await self.status_emitter.emit("Lightrag query success")
-                return data["response"]
-
-        except requests.exceptions.RequestException as e:
-            await self.status_emitter.emit(
-                status="error",
-                description=f"Error during LightRag query: {str(e)}",
-                done=True,
-            )
-            return json.dumps({"error": str(e)})
-
-    def extract_code_blocks(
-        self, text: str, return_remaining_text: bool = False
-    ) -> Union[List[dict], Tuple[List[dict], str]]:
-        """
-        This function extracts code blocks from a given text and optionally returns the text without code blocks.
-
-        Parameters:
-        text (str): The text from which to extract code blocks. Code blocks are identified by triple backticks (```).
-        return_remaining_text (bool): If True, also returns the text with code blocks removed.
-
-        Returns:
-        Union[List[dict], Tuple[List[dict], str]]:
-            - If return_remaining_text is False: Returns only the list of code block dictionaries
-            - If return_remaining_text is True: Returns a tuple containing:
-                * List of code block dictionaries
-                * String containing the text with all code blocks removed
-
-        Each code block dictionary contains:
-            - 'index' (int): The index of the code block in the text
-            - 'file_name' (str): The name of the file extracted from the preceding line, if available
-            - 'content' (str): The content of the code block
-            - 'type' (str): The type of the code block
-            - 'is_complete' (bool): True if the block has a closing tag, False otherwise
-        """
-        remaining = text
-        bloc_index = 0
-        first_index = 0
-        indices = []
-        text_without_blocks = text
-
-        # Find all code block delimiters
-        while len(remaining) > 0:
-            try:
-                index = remaining.index("```")
-                indices.append(index + first_index)
-                remaining = remaining[index + 3 :]
-                first_index += index + 3
-                bloc_index += 1
-            except Exception:
-                if bloc_index % 2 == 1:
-                    index = len(remaining)
-                    indices.append(index)
-                remaining = ""
-
-        code_blocks = []
-        is_start = True
-
-        # Process code blocks and build text without blocks if requested
-        if return_remaining_text:
-            text_parts = []
-            last_end = 0
-
-        for index, code_delimiter_position in enumerate(indices):
-            if is_start:
-                block_infos = {
-                    "index": len(code_blocks),
-                    "file_name": "",
-                    "section": "",
-                    "content": "",
-                    "type": "",
-                    "is_complete": False,
-                }
-
-                # Store text before code block if returning remaining text
-                if return_remaining_text:
-                    text_parts.append(text[last_end:code_delimiter_position].strip())
-
-                # Check the preceding line for file name
-                preceding_text = text[:code_delimiter_position].strip().splitlines()
-                if preceding_text:
-                    last_line = preceding_text[-1].strip()
-                    if last_line.startswith("<file_name>") and last_line.endswith(
-                        "</file_name>"
-                    ):
-                        file_name = last_line[
-                            len("<file_name>") : -len("</file_name>")
-                        ].strip()
-                        block_infos["file_name"] = file_name
-                    elif last_line.startswith("## filename:"):
-                        file_name = last_line[len("## filename:") :].strip()
-                        block_infos["file_name"] = file_name
-                    if last_line.startswith("<section>") and last_line.endswith(
-                        "</section>"
-                    ):
-                        section = last_line[
-                            len("<section>") : -len("</section>")
-                        ].strip()
-                        block_infos["section"] = section
-
-                sub_text = text[code_delimiter_position + 3 :]
-                if len(sub_text) > 0:
-                    try:
-                        find_space = sub_text.index(" ")
-                    except Exception:
-                        find_space = int(1e10)
-                    try:
-                        find_return = sub_text.index("\n")
-                    except Exception:
-                        find_return = int(1e10)
-                    next_index = min(find_return, find_space)
-                    if "{" in sub_text[:next_index]:
-                        next_index = 0
-                    start_pos = next_index
-
-                    if code_delimiter_position + 3 < len(text) and text[
-                        code_delimiter_position + 3
-                    ] in ["\n", " ", "\t"]:
-                        block_infos["type"] = "language-specific"
-                    else:
-                        block_infos["type"] = sub_text[:next_index]
-
-                    if index + 1 < len(indices):
-                        next_pos = indices[index + 1] - code_delimiter_position
-                        if (
-                            next_pos - 3 < len(sub_text)
-                            and sub_text[next_pos - 3] == "`"
-                        ):
-                            block_infos["content"] = sub_text[
-                                start_pos : next_pos - 3
-                            ].strip()
-                            block_infos["is_complete"] = True
-                        else:
-                            block_infos["content"] = sub_text[
-                                start_pos:next_pos
-                            ].strip()
-                            block_infos["is_complete"] = False
-
-                        if return_remaining_text:
-                            last_end = indices[index + 1] + 3
-                    else:
-                        block_infos["content"] = sub_text[start_pos:].strip()
-                        block_infos["is_complete"] = False
-
-                        if return_remaining_text:
-                            last_end = len(text)
-
-                code_blocks.append(block_infos)
-                is_start = False
-            else:
-                is_start = True
-
-        if return_remaining_text:
-            # Add any remaining text after the last code block
-            if last_end < len(text):
-                text_parts.append(text[last_end:].strip())
-            # Join all non-code parts with newlines
-            text_without_blocks = "\n".join(filter(None, text_parts))
-            return code_blocks, text_without_blocks
-
-        return code_blocks
-
-    def clean(self, csv_content: str):
-        lines = csv_content.splitlines()
-        if lines:
-            # Remove spaces around headers and ensure no spaces between commas
-            header = ",".join([col.strip() for col in lines[0].split(",")])
-            lines[0] = header  # Replace the first line with the cleaned header
-            csv_content = "\n".join(lines)
-        return csv_content
lightrag/api/README.md
CHANGED
@@ -185,7 +185,8 @@ TiDBVectorDBStorage TiDB
 PGVectorStorage Postgres
 FaissVectorDBStorage Faiss
 QdrantVectorDBStorage Qdrant
+OracleVectorDBStorage Oracle
+MongoVectorDBStorage MongoDB
 ```
 
 * DOC_STATUS_STORAGE:supported implement-name
lightrag/base.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 from dotenv import load_dotenv
 from dataclasses import dataclass, field
@@ -5,10 +7,8 @@ from enum import Enum
 from typing import (
     Any,
     Literal,
-    Optional,
     TypedDict,
     TypeVar,
-    Union,
 )
 import numpy as np
 from .utils import EmbeddingFunc
@@ -72,7 +72,7 @@ class QueryParam:
     ll_keywords: list[str] = field(default_factory=list)
     """List of low-level keywords to refine retrieval focus."""
 
-    conversation_history: list[dict[str,
+    conversation_history: list[dict[str, str]] = field(default_factory=list)
     """Stores past conversation history to maintain context.
     Format: [{"role": "user/assistant", "content": "message"}].
    """
@@ -86,19 +86,15 @@ class StorageNameSpace:
     namespace: str
     global_config: dict[str, Any]
 
-    async def index_done_callback(self):
+    async def index_done_callback(self) -> None:
         """Commit the storage operations after indexing"""
         pass
 
-    async def query_done_callback(self):
-        """Commit the storage operations after querying"""
-        pass
 
 @dataclass
 class BaseVectorStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc
-    meta_fields: set = field(default_factory=set)
+    meta_fields: set[str] = field(default_factory=set)
 
     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         raise NotImplementedError
@@ -109,12 +105,20 @@ class BaseVectorStorage(StorageNameSpace):
         """
         raise NotImplementedError
 
+    async def delete_entity(self, entity_name: str) -> None:
+        """Delete a single entity by its name"""
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        """Delete relations for a given entity by scanning metadata"""
+        raise NotImplementedError
+
 
 @dataclass
 class BaseKVStorage(StorageNameSpace):
-    embedding_func: EmbeddingFunc
+    embedding_func: EmbeddingFunc | None = None
 
-    async def get_by_id(self, id: str) ->
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         raise NotImplementedError
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -133,50 +137,75 @@ class BaseKVStorage(StorageNameSpace):
 
 @dataclass
 class BaseGraphStorage(StorageNameSpace):
-    embedding_func: EmbeddingFunc = None
+    embedding_func: EmbeddingFunc | None = None
+    """Check if a node exists in the graph."""
 
     async def has_node(self, node_id: str) -> bool:
         raise NotImplementedError
 
+    """Check if an edge exists in the graph."""
+
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         raise NotImplementedError
 
+    """Get the degree of a node."""
+
     async def node_degree(self, node_id: str) -> int:
         raise NotImplementedError
 
+    """Get the degree of an edge."""
+
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
         raise NotImplementedError
 
+    """Get a node by its id."""
+
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         raise NotImplementedError
 
+    """Get an edge by its source and target node ids."""
+
     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) ->
+    ) -> dict[str, str] | None:
         raise NotImplementedError
 
-    ) ->
+    """Get all edges connected to a node."""
+
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         raise NotImplementedError
 
+    """Upsert a node into the graph."""
+
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         raise NotImplementedError
 
+    """Upsert an edge into the graph."""
+
     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
-    ):
+    ) -> None:
         raise NotImplementedError
 
+    """Delete a node from the graph."""
+
+    async def delete_node(self, node_id: str) -> None:
         raise NotImplementedError
 
+    """Embed nodes using an algorithm."""
+
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
         raise NotImplementedError("Node embedding is not used in lightrag.")
 
+    """Get all labels in the graph."""
+
     async def get_all_labels(self) -> list[str]:
         raise NotImplementedError
 
+    """Get a knowledge graph of a node."""
+
     async def get_knowledge_graph(
         self, node_label: str, max_depth: int = 5
     ) -> KnowledgeGraph:
@@ -208,9 +237,9 @@ class DocProcessingStatus:
     """ISO format timestamp when document was created"""
     updated_at: str
     """ISO format timestamp when document was last updated"""
-    chunks_count:
+    chunks_count: int | None = None
     """Number of chunks after splitting, used for processing"""
-    error:
+    error: str | None = None
     """Error message if failed"""
     metadata: dict[str, Any] = field(default_factory=dict)
     """Additional metadata"""
lightrag/exceptions.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import httpx
 from typing import Literal
 
lightrag/kg/chroma_impl.py
CHANGED
@@ -2,7 +2,7 @@ import asyncio
 from dataclasses import dataclass
 from typing import Union
 import numpy as np
-from chromadb import HttpClient
+from chromadb import HttpClient, PersistentClient
 from chromadb.config import Settings
 from lightrag.base import BaseVectorStorage
 from lightrag.utils import logger
@@ -49,31 +49,43 @@ class ChromaVectorDBStorage(BaseVectorStorage):
                 **user_collection_settings,
             }
 
+            local_path = config.get("local_path", None)
+            if local_path:
+                self._client = PersistentClient(
+                    path=local_path,
+                    settings=Settings(
+                        allow_reset=True,
+                        anonymized_telemetry=False,
+                    ),
+                )
+            else:
+                auth_provider = config.get(
+                    "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
+                )
+                auth_credentials = config.get("auth_token", "secret-token")
+                headers = {}
+
+                if "token_authn" in auth_provider:
+                    headers = {
+                        config.get(
+                            "auth_header_name", "X-Chroma-Token"
+                        ): auth_credentials
+                    }
+                elif "basic_authn" in auth_provider:
+                    auth_credentials = config.get("auth_credentials", "admin:admin")
+
+                self._client = HttpClient(
+                    host=config.get("host", "localhost"),
+                    port=config.get("port", 8000),
+                    headers=headers,
+                    settings=Settings(
+                        chroma_api_impl="rest",
+                        chroma_client_auth_provider=auth_provider,
+                        chroma_client_auth_credentials=auth_credentials,
+                        allow_reset=True,
+                        anonymized_telemetry=False,
+                    ),
+                )
 
             self._collection = self._client.get_or_create_collection(
                 name=self.namespace,
@@ -144,7 +156,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
             embedding = await self.embedding_func([query])
 
             results = self._collection.query(
-                query_embeddings=embedding.tolist()
+                query_embeddings=embedding.tolist()
+                if not isinstance(embedding, list)
+                else embedding,
                 n_results=top_k * 2,  # Request more results to allow for filtering
                 include=["metadatas", "distances", "documents"],
             )
lightrag/kg/faiss_impl.py
CHANGED
@@ -219,7 +219,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
         logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
         await self.delete([entity_id])
 
-    async def delete_entity_relation(self, entity_name: str):
+    async def delete_entity_relation(self, entity_name: str) -> None:
         """
         Delete relations for a given entity by scanning metadata.
         """
lightrag/kg/json_kv_impl.py
CHANGED
@@ -47,3 +47,8 @@ class JsonKVStorage(BaseKVStorage):
 
     async def drop(self) -> None:
         self._data = {}
+
+    async def delete(self, ids: list[str]) -> None:
+        for doc_id in ids:
+            self._data.pop(doc_id, None)
+        await self.index_done_callback()
lightrag/kg/mongo_impl.py
CHANGED
|
@@ -4,6 +4,7 @@ import numpy as np
|
|
| 4 |
import pipmaster as pm
|
| 5 |
import configparser
|
| 6 |
from tqdm.asyncio import tqdm as tqdm_async
|
|
|
|
| 7 |
|
| 8 |
if not pm.is_installed("pymongo"):
|
| 9 |
pm.install("pymongo")
|
|
@@ -14,16 +15,20 @@ if not pm.is_installed("motor"):
|
|
| 14 |
from typing import Any, List, Tuple, Union
|
| 15 |
from motor.motor_asyncio import AsyncIOMotorClient
|
| 16 |
from pymongo import MongoClient
|
|
|
|
|
|
|
| 17 |
|
| 18 |
from ..base import (
|
| 19 |
BaseGraphStorage,
|
| 20 |
BaseKVStorage,
|
|
|
|
| 21 |
DocProcessingStatus,
|
| 22 |
DocStatus,
|
| 23 |
DocStatusStorage,
|
| 24 |
)
|
| 25 |
from ..namespace import NameSpace, is_namespace
|
| 26 |
from ..utils import logger
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
config = configparser.ConfigParser()
|
|
@@ -33,56 +38,66 @@ config.read("config.ini", "utf-8")
|
|
| 33 |
@dataclass
|
| 34 |
class MongoKVStorage(BaseKVStorage):
|
| 35 |
def __post_init__(self):
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
),
|
| 42 |
-
)
|
| 43 |
)
|
|
|
|
| 44 |
database = client.get_database(
|
| 45 |
os.environ.get(
|
| 46 |
"MONGO_DATABASE",
|
| 47 |
config.get("mongodb", "database", fallback="LightRAG"),
|
| 48 |
)
|
| 49 |
)
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
| 54 |
-
return self._data.find_one({"_id": id})
|
| 55 |
|
| 56 |
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
| 57 |
-
|
|
|
|
| 58 |
|
| 59 |
async def filter_keys(self, data: set[str]) -> set[str]:
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
]
|
| 64 |
-
return set([s for s in data if s not in existing_ids])
|
| 65 |
|
| 66 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
| 67 |
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
|
|
|
| 68 |
for mode, items in data.items():
|
| 69 |
-
for k, v in
|
| 70 |
key = f"{mode}_{k}"
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
| 73 |
)
|
| 74 |
-
|
| 75 |
-
logger.debug(f"\nInserted new document with key: {key}")
|
| 76 |
-
data[mode][k]["_id"] = key
|
| 77 |
else:
|
| 78 |
-
|
| 79 |
-
|
| 80 |
data[k]["_id"] = k
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
| 83 |
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
| 84 |
res = {}
|
| 85 |
-
v = self._data.find_one({"_id": mode + "_" + id})
|
| 86 |
if v:
|
| 87 |
res[id] = v
|
| 88 |
logger.debug(f"llm_response_cache find one by:{id}")
|
|
@@ -100,30 +115,48 @@ class MongoKVStorage(BaseKVStorage):
|
|
| 100 |
@dataclass
|
| 101 |
class MongoDocStatusStorage(DocStatusStorage):
|
| 102 |
def __post_init__(self):
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
)
|
| 106 |
-
|
| 107 |
-
self.
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
| 111 |
-
return self._data.find_one({"_id": id})
|
| 112 |
|
| 113 |
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
| 114 |
-
|
|
|
|
| 115 |
|
| 116 |
async def filter_keys(self, data: set[str]) -> set[str]:
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
]
|
| 121 |
-
return set([s for s in data if s not in existing_ids])
|
| 122 |
|
| 123 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
|
|
|
| 124 |
for k, v in data.items():
|
| 125 |
-
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
|
| 126 |
data[k]["_id"] = k
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
async def drop(self) -> None:
|
| 129 |
"""Drop the collection"""
|
|
@@ -132,7 +165,8 @@ class MongoDocStatusStorage(DocStatusStorage):
|
|
| 132 |
async def get_status_counts(self) -> dict[str, int]:
|
| 133 |
"""Get counts of documents in each status"""
|
| 134 |
pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
|
| 135 |
-
|
|
|
|
| 136 |
counts = {}
|
| 137 |
for doc in result:
|
| 138 |
counts[doc["_id"]] = doc["count"]
|
|
@@ -142,7 +176,8 @@ class MongoDocStatusStorage(DocStatusStorage):
|
|
| 142 |
self, status: DocStatus
|
| 143 |
) -> dict[str, DocProcessingStatus]:
|
| 144 |
"""Get all documents by status"""
|
| 145 |
-
|
|
|
|
| 146 |
return {
|
| 147 |
doc["_id"]: DocProcessingStatus(
|
| 148 |
content=doc["content"],
|
|
@@ -185,26 +220,27 @@ class MongoGraphStorage(BaseGraphStorage):
|
|
| 185 |
global_config=global_config,
|
| 186 |
embedding_func=embedding_func,
|
| 187 |
)
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
),
|
| 194 |
-
)
|
| 195 |
)
|
| 196 |
-
|
|
|
|
| 197 |
os.environ.get(
|
| 198 |
"MONGO_DATABASE",
|
| 199 |
-
|
| 200 |
-
)
|
| 201 |
-
]
|
| 202 |
-
self.collection = self.db[
|
| 203 |
-
os.environ.get(
|
| 204 |
-
"MONGO_KG_COLLECTION",
|
| 205 |
-
config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"),
|
| 206 |
)
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
#
|
| 210 |
# -------------------------------------------------------------------------
|
|
@@ -451,7 +487,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|
| 451 |
self, source_node_id: str
|
| 452 |
) -> Union[List[Tuple[str, str]], None]:
|
| 453 |
"""
|
| 454 |
-
Return a list of (
|
| 455 |
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
|
| 456 |
"""
|
| 457 |
pipeline = [
|
|
@@ -475,7 +511,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|
| 475 |
return None
|
| 476 |
|
| 477 |
edges = result[0].get("edges", [])
|
| 478 |
-
return [(
|
| 479 |
|
| 480 |
#
|
| 481 |
# -------------------------------------------------------------------------
|
|
@@ -522,7 +558,7 @@ class MongoGraphStorage(BaseGraphStorage):
|
|
| 522 |
|
| 523 |
async def delete_node(self, node_id: str):
|
| 524 |
"""
|
| 525 |
-
1) Remove node
|
| 526 |
2) Remove inbound edges from any doc that references node_id.
|
| 527 |
"""
|
| 528 |
# Remove inbound edges from all other docs
|
|
@@ -542,3 +578,359 @@ class MongoGraphStorage(BaseGraphStorage):
|
|
| 542 |
Placeholder for demonstration, raises NotImplementedError.
|
| 543 |
"""
|
| 544 |
raise NotImplementedError("Node embedding is not used in lightrag.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import pipmaster as pm
|
| 5 |
import configparser
|
| 6 |
from tqdm.asyncio import tqdm as tqdm_async
|
| 7 |
+
import asyncio
|
| 8 |
|
| 9 |
if not pm.is_installed("pymongo"):
|
| 10 |
pm.install("pymongo")
|
|
|
|
| 15 |
from typing import Any, List, Tuple, Union
|
| 16 |
from motor.motor_asyncio import AsyncIOMotorClient
|
| 17 |
from pymongo import MongoClient
|
| 18 |
+
from pymongo.operations import SearchIndexModel
|
| 19 |
+
from pymongo.errors import PyMongoError
|
| 20 |
|
| 21 |
from ..base import (
|
| 22 |
BaseGraphStorage,
|
| 23 |
BaseKVStorage,
|
| 24 |
+
BaseVectorStorage,
|
| 25 |
DocProcessingStatus,
|
| 26 |
DocStatus,
|
| 27 |
DocStatusStorage,
|
| 28 |
)
|
| 29 |
from ..namespace import NameSpace, is_namespace
|
| 30 |
from ..utils import logger
|
| 31 |
+
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
| 32 |
|
| 33 |
|
| 34 |
config = configparser.ConfigParser()
|
|
|
|
| 38 |
@dataclass
|
| 39 |
class MongoKVStorage(BaseKVStorage):
|
| 40 |
def __post_init__(self):
|
| 41 |
+
uri = os.environ.get(
|
| 42 |
+
"MONGO_URI",
|
| 43 |
+
config.get(
|
| 44 |
+
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
| 45 |
+
),
|
|
|
|
|
|
|
| 46 |
)
|
| 47 |
+
client = AsyncIOMotorClient(uri)
|
| 48 |
database = client.get_database(
|
| 49 |
os.environ.get(
|
| 50 |
"MONGO_DATABASE",
|
| 51 |
config.get("mongodb", "database", fallback="LightRAG"),
|
| 52 |
)
|
| 53 |
)
|
| 54 |
+
|
| 55 |
+
self._collection_name = self.namespace
|
| 56 |
+
|
| 57 |
+
self._data = database.get_collection(self._collection_name)
|
| 58 |
+
logger.debug(f"Use MongoDB as KV {self._collection_name}")
|
| 59 |
+
|
| 60 |
+
# Ensure collection exists
|
| 61 |
+
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
| 62 |
|
| 63 |
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
| 64 |
+
return await self._data.find_one({"_id": id})
|
| 65 |
|
| 66 |
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
| 67 |
+
cursor = self._data.find({"_id": {"$in": ids}})
|
| 68 |
+
return await cursor.to_list()
|
| 69 |
|
| 70 |
async def filter_keys(self, data: set[str]) -> set[str]:
|
| 71 |
+
cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
|
| 72 |
+
existing_ids = {str(x["_id"]) async for x in cursor}
|
| 73 |
+
return data - existing_ids
|
|
|
|
|
|
|
| 74 |
|
| 75 |
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
| 76 |
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
| 77 |
+
update_tasks = []
|
| 78 |
for mode, items in data.items():
|
| 79 |
+
for k, v in items.items():
|
| 80 |
key = f"{mode}_{k}"
|
| 81 |
+
data[mode][k]["_id"] = f"{mode}_{k}"
|
| 82 |
+
update_tasks.append(
|
| 83 |
+
self._data.update_one(
|
| 84 |
+
{"_id": key}, {"$setOnInsert": v}, upsert=True
|
| 85 |
+
)
|
| 86 |
)
|
| 87 |
+
await asyncio.gather(*update_tasks)
|
|
|
|
|
|
|
| 88 |
else:
|
| 89 |
+
update_tasks = []
|
| 90 |
+
for k, v in data.items():
|
| 91 |
data[k]["_id"] = k
|
| 92 |
+
update_tasks.append(
|
| 93 |
+
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
|
| 94 |
+
)
|
| 95 |
+
await asyncio.gather(*update_tasks)
|
| 96 |
|
| 97 |
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
| 98 |
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
| 99 |
res = {}
|
| 100 |
+            v = await self._data.find_one({"_id": mode + "_" + id})
| 101 |             if v:
| 102 |                 res[id] = v
| 103 |                 logger.debug(f"llm_response_cache find one by:{id}")
| 115 | @dataclass
| 116 | class MongoDocStatusStorage(DocStatusStorage):
| 117 |     def __post_init__(self):
| 118 | +        uri = os.environ.get(
| 119 | +            "MONGO_URI",
| 120 | +            config.get(
| 121 | +                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
| 122 | +            ),
| 123 | +        )
| 124 | +        client = AsyncIOMotorClient(uri)
| 125 | +        database = client.get_database(
| 126 | +            os.environ.get(
| 127 | +                "MONGO_DATABASE",
| 128 | +                config.get("mongodb", "database", fallback="LightRAG"),
| 129 | +            )
| 130 |         )
| 131 | +
| 132 | +        self._collection_name = self.namespace
| 133 | +        self._data = database.get_collection(self._collection_name)
| 134 | +
| 135 | +        logger.debug(f"Use MongoDB as doc status {self._collection_name}")
| 136 | +
| 137 | +        # Ensure collection exists
| 138 | +        create_collection_if_not_exists(uri, database.name, self._collection_name)
| 139 |
| 140 |     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
| 141 | +        return await self._data.find_one({"_id": id})
| 142 |
| 143 |     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
| 144 | +        cursor = self._data.find({"_id": {"$in": ids}})
| 145 | +        return await cursor.to_list()
| 146 |
| 147 |     async def filter_keys(self, data: set[str]) -> set[str]:
| 148 | +        cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
| 149 | +        existing_ids = {str(x["_id"]) async for x in cursor}
| 150 | +        return data - existing_ids
| 151 |
| 152 |     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
| 153 | +        update_tasks = []
| 154 |         for k, v in data.items():
| 155 |             data[k]["_id"] = k
| 156 | +            update_tasks.append(
| 157 | +                self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
| 158 | +            )
| 159 | +        await asyncio.gather(*update_tasks)
| 160 |
| 161 |     async def drop(self) -> None:
| 162 |         """Drop the collection"""
| 165 |     async def get_status_counts(self) -> dict[str, int]:
| 166 |         """Get counts of documents in each status"""
| 167 |         pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
| 168 | +        cursor = self._data.aggregate(pipeline)
| 169 | +        result = await cursor.to_list()
| 170 |         counts = {}
| 171 |         for doc in result:
| 172 |             counts[doc["_id"]] = doc["count"]
| 176 |         self, status: DocStatus
| 177 |     ) -> dict[str, DocProcessingStatus]:
| 178 |         """Get all documents by status"""
| 179 | +        cursor = self._data.find({"status": status.value})
| 180 | +        result = await cursor.to_list()
| 181 |         return {
| 182 |             doc["_id"]: DocProcessingStatus(
| 183 |                 content=doc["content"],
| 220 |             global_config=global_config,
| 221 |             embedding_func=embedding_func,
| 222 |         )
| 223 | +        uri = os.environ.get(
| 224 | +            "MONGO_URI",
| 225 | +            config.get(
| 226 | +                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
| 227 | +            ),
| 228 |         )
| 229 | +        client = AsyncIOMotorClient(uri)
| 230 | +        database = client.get_database(
| 231 |             os.environ.get(
| 232 |                 "MONGO_DATABASE",
| 233 | +                config.get("mongodb", "database", fallback="LightRAG"),
| 234 |             )
| 235 | +        )
| 236 | +
| 237 | +        self._collection_name = self.namespace
| 238 | +        self.collection = database.get_collection(self._collection_name)
| 239 | +
| 240 | +        logger.debug(f"Use MongoDB as KG {self._collection_name}")
| 241 | +
| 242 | +        # Ensure collection exists
| 243 | +        create_collection_if_not_exists(uri, database.name, self._collection_name)
| 244 |
| 245 |     #
| 246 |     # -------------------------------------------------------------------------
| 487 |         self, source_node_id: str
| 488 |     ) -> Union[List[Tuple[str, str]], None]:
| 489 |         """
| 490 | +        Return a list of (source_id, target_id) for direct edges from source_node_id.
| 491 |         Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
| 492 |         """
| 493 |         pipeline = [
| 511 |             return None
| 512 |
| 513 |         edges = result[0].get("edges", [])
| 514 | +        return [(source_node_id, e["target"]) for e in edges]
| 515 |
| 516 |     #
| 517 |     # -------------------------------------------------------------------------
| 558 |
| 559 |     async def delete_node(self, node_id: str):
| 560 |         """
| 561 | +        1) Remove node's doc entirely.
| 562 |         2) Remove inbound edges from any doc that references node_id.
| 563 |         """
| 564 |         # Remove inbound edges from all other docs
| 578 |         Placeholder for demonstration, raises NotImplementedError.
| 579 |         """
| 580 |         raise NotImplementedError("Node embedding is not used in lightrag.")
|
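For context, a minimal sketch (not part of the diff) of the `$graphLookup` shape that the `get_node_edges` docstring above refers to: starting from one node document, the aggregation follows `edges.target` back into the same collection with `maxDepth=0` to collect direct neighbours. Database and collection names here are placeholders.

```python
# Illustrative only: mirrors the direct-edge lookup described above.
from motor.motor_asyncio import AsyncIOMotorClient

async def direct_edges(uri: str, db_name: str, coll_name: str, source_node_id: str):
    collection = AsyncIOMotorClient(uri)[db_name][coll_name]
    pipeline = [
        {"$match": {"_id": source_node_id}},
        {
            "$graphLookup": {
                "from": coll_name,              # traverse within the same collection
                "startWith": "$edges.target",   # follow outgoing edge targets
                "connectFromField": "edges.target",
                "connectToField": "_id",
                "maxDepth": 0,                  # depth 0 == direct neighbours only
                "as": "connected_nodes",
            }
        },
    ]
    result = await collection.aggregate(pipeline).to_list(length=1)
    if not result:
        return None
    return [(source_node_id, e["target"]) for e in result[0].get("edges", [])]
```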
| 581 |
+
|
| 582 |
+
#
|
| 583 |
+
# -------------------------------------------------------------------------
|
| 584 |
+
# QUERY
|
| 585 |
+
# -------------------------------------------------------------------------
|
| 586 |
+
#
|
| 587 |
+
|
| 588 |
+
async def get_all_labels(self) -> list[str]:
|
| 589 |
+
"""
|
| 590 |
+
Get all existing node _id in the database
|
| 591 |
+
Returns:
|
| 592 |
+
[id1, id2, ...] # Alphabetically sorted id list
|
| 593 |
+
"""
|
| 594 |
+
# Use MongoDB's distinct and aggregation to get all unique labels
|
| 595 |
+
pipeline = [
|
| 596 |
+
{"$group": {"_id": "$_id"}}, # Group by _id
|
| 597 |
+
{"$sort": {"_id": 1}}, # Sort alphabetically
|
| 598 |
+
]
|
| 599 |
+
|
| 600 |
+
cursor = self.collection.aggregate(pipeline)
|
| 601 |
+
labels = []
|
| 602 |
+
async for doc in cursor:
|
| 603 |
+
labels.append(doc["_id"])
|
| 604 |
+
return labels
|
| 605 |
+
|
| 606 |
+
async def get_knowledge_graph(
|
| 607 |
+
self, node_label: str, max_depth: int = 5
|
| 608 |
+
) -> KnowledgeGraph:
|
| 609 |
+
"""
|
| 610 |
+
Get complete connected subgraph for specified node (including the starting node itself)
|
| 611 |
+
|
| 612 |
+
Args:
|
| 613 |
+
node_label: Label of the nodes to start from
|
| 614 |
+
max_depth: Maximum depth of traversal (default: 5)
|
| 615 |
+
|
| 616 |
+
Returns:
|
| 617 |
+
KnowledgeGraph object containing nodes and edges of the subgraph
|
| 618 |
+
"""
|
| 619 |
+
label = node_label
|
| 620 |
+
result = KnowledgeGraph()
|
| 621 |
+
seen_nodes = set()
|
| 622 |
+
seen_edges = set()
|
| 623 |
+
|
| 624 |
+
try:
|
| 625 |
+
if label == "*":
|
| 626 |
+
# Get all nodes and edges
|
| 627 |
+
async for node_doc in self.collection.find({}):
|
| 628 |
+
node_id = str(node_doc["_id"])
|
| 629 |
+
if node_id not in seen_nodes:
|
| 630 |
+
result.nodes.append(
|
| 631 |
+
KnowledgeGraphNode(
|
| 632 |
+
id=node_id,
|
| 633 |
+
labels=[node_doc.get("_id")],
|
| 634 |
+
properties={
|
| 635 |
+
k: v
|
| 636 |
+
for k, v in node_doc.items()
|
| 637 |
+
if k not in ["_id", "edges"]
|
| 638 |
+
},
|
| 639 |
+
)
|
| 640 |
+
)
|
| 641 |
+
seen_nodes.add(node_id)
|
| 642 |
+
|
| 643 |
+
# Process edges
|
| 644 |
+
for edge in node_doc.get("edges", []):
|
| 645 |
+
edge_id = f"{node_id}-{edge['target']}"
|
| 646 |
+
if edge_id not in seen_edges:
|
| 647 |
+
result.edges.append(
|
| 648 |
+
KnowledgeGraphEdge(
|
| 649 |
+
id=edge_id,
|
| 650 |
+
type=edge.get("relation", ""),
|
| 651 |
+
source=node_id,
|
| 652 |
+
target=edge["target"],
|
| 653 |
+
properties={
|
| 654 |
+
k: v
|
| 655 |
+
for k, v in edge.items()
|
| 656 |
+
if k not in ["target", "relation"]
|
| 657 |
+
},
|
| 658 |
+
)
|
| 659 |
+
)
|
| 660 |
+
seen_edges.add(edge_id)
|
| 661 |
+
else:
|
| 662 |
+
# Verify if starting node exists
|
| 663 |
+
start_nodes = self.collection.find({"_id": label})
|
| 664 |
+
start_nodes_exist = await start_nodes.to_list(length=1)
|
| 665 |
+
if not start_nodes_exist:
|
| 666 |
+
logger.warning(f"Starting node with label {label} does not exist!")
|
| 667 |
+
return result
|
| 668 |
+
|
| 669 |
+
# Use $graphLookup for traversal
|
| 670 |
+
pipeline = [
|
| 671 |
+
{
|
| 672 |
+
"$match": {"_id": label}
|
| 673 |
+
}, # Start with nodes having the specified label
|
| 674 |
+
{
|
| 675 |
+
"$graphLookup": {
|
| 676 |
+
"from": self._collection_name,
|
| 677 |
+
"startWith": "$edges.target",
|
| 678 |
+
"connectFromField": "edges.target",
|
| 679 |
+
"connectToField": "_id",
|
| 680 |
+
"maxDepth": max_depth,
|
| 681 |
+
"depthField": "depth",
|
| 682 |
+
"as": "connected_nodes",
|
| 683 |
+
}
|
| 684 |
+
},
|
| 685 |
+
]
|
| 686 |
+
|
| 687 |
+
async for doc in self.collection.aggregate(pipeline):
|
| 688 |
+
# Add the start node
|
| 689 |
+
node_id = str(doc["_id"])
|
| 690 |
+
if node_id not in seen_nodes:
|
| 691 |
+
result.nodes.append(
|
| 692 |
+
KnowledgeGraphNode(
|
| 693 |
+
id=node_id,
|
| 694 |
+
labels=[
|
| 695 |
+
doc.get(
|
| 696 |
+
"_id",
|
| 697 |
+
)
|
| 698 |
+
],
|
| 699 |
+
properties={
|
| 700 |
+
k: v
|
| 701 |
+
for k, v in doc.items()
|
| 702 |
+
if k
|
| 703 |
+
not in [
|
| 704 |
+
"_id",
|
| 705 |
+
"edges",
|
| 706 |
+
"connected_nodes",
|
| 707 |
+
"depth",
|
| 708 |
+
]
|
| 709 |
+
},
|
| 710 |
+
)
|
| 711 |
+
)
|
| 712 |
+
seen_nodes.add(node_id)
|
| 713 |
+
|
| 714 |
+
# Add edges from start node
|
| 715 |
+
for edge in doc.get("edges", []):
|
| 716 |
+
edge_id = f"{node_id}-{edge['target']}"
|
| 717 |
+
if edge_id not in seen_edges:
|
| 718 |
+
result.edges.append(
|
| 719 |
+
KnowledgeGraphEdge(
|
| 720 |
+
id=edge_id,
|
| 721 |
+
type=edge.get("relation", ""),
|
| 722 |
+
source=node_id,
|
| 723 |
+
target=edge["target"],
|
| 724 |
+
properties={
|
| 725 |
+
k: v
|
| 726 |
+
for k, v in edge.items()
|
| 727 |
+
if k not in ["target", "relation"]
|
| 728 |
+
},
|
| 729 |
+
)
|
| 730 |
+
)
|
| 731 |
+
seen_edges.add(edge_id)
|
| 732 |
+
|
| 733 |
+
# Add connected nodes and their edges
|
| 734 |
+
for connected in doc.get("connected_nodes", []):
|
| 735 |
+
node_id = str(connected["_id"])
|
| 736 |
+
if node_id not in seen_nodes:
|
| 737 |
+
result.nodes.append(
|
| 738 |
+
KnowledgeGraphNode(
|
| 739 |
+
id=node_id,
|
| 740 |
+
labels=[connected.get("_id")],
|
| 741 |
+
properties={
|
| 742 |
+
k: v
|
| 743 |
+
for k, v in connected.items()
|
| 744 |
+
if k not in ["_id", "edges", "depth"]
|
| 745 |
+
},
|
| 746 |
+
)
|
| 747 |
+
)
|
| 748 |
+
seen_nodes.add(node_id)
|
| 749 |
+
|
| 750 |
+
# Add edges from connected nodes
|
| 751 |
+
for edge in connected.get("edges", []):
|
| 752 |
+
edge_id = f"{node_id}-{edge['target']}"
|
| 753 |
+
if edge_id not in seen_edges:
|
| 754 |
+
result.edges.append(
|
| 755 |
+
KnowledgeGraphEdge(
|
| 756 |
+
id=edge_id,
|
| 757 |
+
type=edge.get("relation", ""),
|
| 758 |
+
source=node_id,
|
| 759 |
+
target=edge["target"],
|
| 760 |
+
properties={
|
| 761 |
+
k: v
|
| 762 |
+
for k, v in edge.items()
|
| 763 |
+
if k not in ["target", "relation"]
|
| 764 |
+
},
|
| 765 |
+
)
|
| 766 |
+
)
|
| 767 |
+
seen_edges.add(edge_id)
|
| 768 |
+
|
| 769 |
+
logger.info(
|
| 770 |
+
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
+
except PyMongoError as e:
|
| 774 |
+
logger.error(f"MongoDB query failed: {str(e)}")
|
| 775 |
+
|
| 776 |
+
return result
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
@dataclass
|
| 780 |
+
class MongoVectorDBStorage(BaseVectorStorage):
|
| 781 |
+
cosine_better_than_threshold: float = None
|
| 782 |
+
|
| 783 |
+
def __post_init__(self):
|
| 784 |
+
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
| 785 |
+
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
| 786 |
+
if cosine_threshold is None:
|
| 787 |
+
raise ValueError(
|
| 788 |
+
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
| 789 |
+
)
|
| 790 |
+
self.cosine_better_than_threshold = cosine_threshold
|
| 791 |
+
|
| 792 |
+
uri = os.environ.get(
|
| 793 |
+
"MONGO_URI",
|
| 794 |
+
config.get(
|
| 795 |
+
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
| 796 |
+
),
|
| 797 |
+
)
|
| 798 |
+
client = AsyncIOMotorClient(uri)
|
| 799 |
+
database = client.get_database(
|
| 800 |
+
os.environ.get(
|
| 801 |
+
"MONGO_DATABASE",
|
| 802 |
+
config.get("mongodb", "database", fallback="LightRAG"),
|
| 803 |
+
)
|
| 804 |
+
)
|
| 805 |
+
|
| 806 |
+
self._collection_name = self.namespace
|
| 807 |
+
self._data = database.get_collection(self._collection_name)
|
| 808 |
+
self._max_batch_size = self.global_config["embedding_batch_num"]
|
| 809 |
+
|
| 810 |
+
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
|
| 811 |
+
|
| 812 |
+
# Ensure collection exists
|
| 813 |
+
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
| 814 |
+
|
| 815 |
+
# Ensure vector index exists
|
| 816 |
+
self.create_vector_index(uri, database.name, self._collection_name)
|
| 817 |
+
|
| 818 |
+
def create_vector_index(self, uri: str, database_name: str, collection_name: str):
|
| 819 |
+
"""Creates an Atlas Vector Search index."""
|
| 820 |
+
client = MongoClient(uri)
|
| 821 |
+
collection = client.get_database(database_name).get_collection(
|
| 822 |
+
self._collection_name
|
| 823 |
+
)
|
| 824 |
+
|
| 825 |
+
try:
|
| 826 |
+
search_index_model = SearchIndexModel(
|
| 827 |
+
definition={
|
| 828 |
+
"fields": [
|
| 829 |
+
{
|
| 830 |
+
"type": "vector",
|
| 831 |
+
"numDimensions": self.embedding_func.embedding_dim, # Ensure correct dimensions
|
| 832 |
+
"path": "vector",
|
| 833 |
+
"similarity": "cosine", # Options: euclidean, cosine, dotProduct
|
| 834 |
+
}
|
| 835 |
+
]
|
| 836 |
+
},
|
| 837 |
+
name="vector_knn_index",
|
| 838 |
+
type="vectorSearch",
|
| 839 |
+
)
|
| 840 |
+
|
| 841 |
+
collection.create_search_index(search_index_model)
|
| 842 |
+
logger.info("Vector index created successfully.")
|
| 843 |
+
|
| 844 |
+
except PyMongoError as _:
|
| 845 |
+
logger.debug("vector index already exist")
|
| 846 |
+
|
| 847 |
+
async def upsert(self, data: dict[str, dict]):
|
| 848 |
+
logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
|
| 849 |
+
if not data:
|
| 850 |
+
logger.warning("You are inserting an empty data set to vector DB")
|
| 851 |
+
return []
|
| 852 |
+
|
| 853 |
+
list_data = [
|
| 854 |
+
{
|
| 855 |
+
"_id": k,
|
| 856 |
+
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
|
| 857 |
+
}
|
| 858 |
+
for k, v in data.items()
|
| 859 |
+
]
|
| 860 |
+
contents = [v["content"] for v in data.values()]
|
| 861 |
+
batches = [
|
| 862 |
+
contents[i : i + self._max_batch_size]
|
| 863 |
+
for i in range(0, len(contents), self._max_batch_size)
|
| 864 |
+
]
|
| 865 |
+
|
| 866 |
+
async def wrapped_task(batch):
|
| 867 |
+
result = await self.embedding_func(batch)
|
| 868 |
+
pbar.update(1)
|
| 869 |
+
return result
|
| 870 |
+
|
| 871 |
+
embedding_tasks = [wrapped_task(batch) for batch in batches]
|
| 872 |
+
pbar = tqdm_async(
|
| 873 |
+
total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
|
| 874 |
+
)
|
| 875 |
+
embeddings_list = await asyncio.gather(*embedding_tasks)
|
| 876 |
+
|
| 877 |
+
embeddings = np.concatenate(embeddings_list)
|
| 878 |
+
for i, d in enumerate(list_data):
|
| 879 |
+
d["vector"] = np.array(embeddings[i], dtype=np.float32).tolist()
|
| 880 |
+
|
| 881 |
+
update_tasks = []
|
| 882 |
+
for doc in list_data:
|
| 883 |
+
update_tasks.append(
|
| 884 |
+
self._data.update_one({"_id": doc["_id"]}, {"$set": doc}, upsert=True)
|
| 885 |
+
)
|
| 886 |
+
await asyncio.gather(*update_tasks)
|
| 887 |
+
|
| 888 |
+
return list_data
|
| 889 |
+
|
| 890 |
+
async def query(self, query, top_k=5):
|
| 891 |
+
"""Queries the vector database using Atlas Vector Search."""
|
| 892 |
+
# Generate the embedding
|
| 893 |
+
embedding = await self.embedding_func([query])
|
| 894 |
+
|
| 895 |
+
# Convert numpy array to a list to ensure compatibility with MongoDB
|
| 896 |
+
query_vector = embedding[0].tolist()
|
| 897 |
+
|
| 898 |
+
# Define the aggregation pipeline with the converted query vector
|
| 899 |
+
pipeline = [
|
| 900 |
+
{
|
| 901 |
+
"$vectorSearch": {
|
| 902 |
+
"index": "vector_knn_index", # Ensure this matches the created index name
|
| 903 |
+
"path": "vector",
|
| 904 |
+
"queryVector": query_vector,
|
| 905 |
+
"numCandidates": 100, # Adjust for performance
|
| 906 |
+
"limit": top_k,
|
| 907 |
+
}
|
| 908 |
+
},
|
| 909 |
+
{"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
|
| 910 |
+
{"$match": {"score": {"$gte": self.cosine_better_than_threshold}}},
|
| 911 |
+
{"$project": {"vector": 0}},
|
| 912 |
+
]
|
| 913 |
+
|
| 914 |
+
# Execute the aggregation pipeline
|
| 915 |
+
cursor = self._data.aggregate(pipeline)
|
| 916 |
+
results = await cursor.to_list()
|
| 917 |
+
|
| 918 |
+
# Format and return the results
|
| 919 |
+
return [
|
| 920 |
+
{**doc, "id": doc["_id"], "distance": doc.get("score", None)}
|
| 921 |
+
for doc in results
|
| 922 |
+
]
|
| 923 |
+
|
| 924 |
+
|
| 925 |
+
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
|
| 926 |
+
"""Check if the collection exists. if not, create it."""
|
| 927 |
+
client = MongoClient(uri)
|
| 928 |
+
database = client.get_database(database_name)
|
| 929 |
+
|
| 930 |
+
collection_names = database.list_collection_names()
|
| 931 |
+
|
| 932 |
+
if collection_name not in collection_names:
|
| 933 |
+
database.create_collection(collection_name)
|
| 934 |
+
logger.info(f"Created collection: {collection_name}")
|
| 935 |
+
else:
|
| 936 |
+
logger.debug(f"Collection '{collection_name}' already exists.")
|
lightrag/kg/nano_vector_db_impl.py
CHANGED
|
@@ -191,7 +191,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
| 191 |         except Exception as e:
| 192 |             logger.error(f"Error deleting entity {entity_name}: {e}")
| 193 |
| 194 | -    async def delete_entity_relation(self, entity_name: str):
| 195 |         try:
| 196 |             relations = [
| 197 |                 dp
| 191 |         except Exception as e:
| 192 |             logger.error(f"Error deleting entity {entity_name}: {e}")
| 193 |
| 194 | +    async def delete_entity_relation(self, entity_name: str) -> None:
| 195 |         try:
| 196 |             relations = [
| 197 |                 dp
lightrag/kg/neo4j_impl.py
CHANGED
|
@@ -143,9 +143,27 @@ class Neo4JStorage(BaseGraphStorage):
| 143 |     async def index_done_callback(self):
| 144 |         print("KG successfully indexed.")
| 145 |
| 146 | -    async def
| 147 | -
| 148 |
| 149 |         async with self._driver.session(database=self._DATABASE) as session:
| 150 |             query = (
| 151 |                 f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
@@ -174,8 +192,17 @@ class Neo4JStorage(BaseGraphStorage):
| 174 |             return single_result["edgeExists"]
| 175 |
| 176 |     async def get_node(self, node_id: str) -> Union[dict, None]:
| 177 |         async with self._driver.session(database=self._DATABASE) as session:
| 178 | -            entity_name_label =
| 179 |             query = f"MATCH (n:`{entity_name_label}`) RETURN n"
| 180 |             result = await session.run(query)
| 181 |             record = await result.single()
@@ -226,38 +253,73 @@ class Neo4JStorage(BaseGraphStorage):
| 226 |     async def get_edge(
| 227 |         self, source_node_id: str, target_node_id: str
| 228 |     ) -> Union[dict, None]:
| 229 | -
| 230 | -        entity_name_label_target = target_node_id.strip('"')
| 231 | -        """
| 232 | -        Find all edges between nodes of two given labels
| 233 |
| 234 |         Args:
| 235 | -
| 236 | -
| 237 |
| 238 |         Returns:
| 239 | -
| 240 |         """
| 241 | -
| 242 | -
| 243 | -
| 244 | -
| 245 | -
| 246 | -
| 247 | -        entity_name_label_source
| 248 | -
| 249 | -
| 250 |
| 251 | -        result = await session.run(query)
| 252 | -        record = await result.single()
| 253 | -        if record:
| 254 | -            result = dict(record["edge_properties"])
| 255 |         logger.debug(
| 256 | -            f"{inspect.currentframe().f_code.co_name}:
| 257 |         )
| 258 | -
| 259 | -
| 260 | -
| 261 |
| 262 |     async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
| 263 |         node_label = source_node_id.strip('"')
@@ -310,7 +372,7 @@ class Neo4JStorage(BaseGraphStorage):
| 310 |             node_id: The unique identifier for the node (used as label)
| 311 |             node_data: Dictionary of node properties
| 312 |         """
| 313 | -        label =
| 314 |         properties = node_data
| 315 |
| 316 |         async def _do_upsert(tx: AsyncManagedTransaction):
@@ -338,6 +400,7 @@ class Neo4JStorage(BaseGraphStorage):
| 338 |                 neo4jExceptions.ServiceUnavailable,
| 339 |                 neo4jExceptions.TransientError,
| 340 |                 neo4jExceptions.WriteServiceUnavailable,
| 341 |             )
| 342 |         ),
| 343 |     )
@@ -352,22 +415,23 @@ class Neo4JStorage(BaseGraphStorage):
| 352 |             target_node_id (str): Label of the target node (used as identifier)
| 353 |             edge_data (dict): Dictionary of properties to set on the edge
| 354 |         """
| 355 | -
| 356 | -
| 357 |         edge_properties = edge_data
| 358 |
| 359 |         async def _do_upsert_edge(tx: AsyncManagedTransaction):
| 360 |             query = f"""
| 361 | -            MATCH (source:`{
| 362 |             WITH source
| 363 | -            MATCH (target:`{
| 364 |             MERGE (source)-[r:DIRECTED]->(target)
| 365 |             SET r += $properties
| 366 |             RETURN r
| 367 |             """
| 368 | -            await tx.run(query, properties=edge_properties)
| 369 |             logger.debug(
| 370 | -                f"Upserted edge from '{
| 371 |             )
| 372 |
| 373 |         try:
|
| 143 |     async def index_done_callback(self):
| 144 |         print("KG successfully indexed.")
| 145 |
| 146 | +    async def _label_exists(self, label: str) -> bool:
| 147 | +        """Check if a label exists in the Neo4j database."""
| 148 | +        query = "CALL db.labels() YIELD label RETURN label"
| 149 | +        try:
| 150 | +            async with self._driver.session(database=self._DATABASE) as session:
| 151 | +                result = await session.run(query)
| 152 | +                labels = [record["label"] for record in await result.data()]
| 153 | +                return label in labels
| 154 | +        except Exception as e:
| 155 | +            logger.error(f"Error checking label existence: {e}")
| 156 | +            return False
| 157 |
| 158 | +    async def _ensure_label(self, label: str) -> str:
| 159 | +        """Ensure a label exists by validating it."""
| 160 | +        clean_label = label.strip('"')
| 161 | +        if not await self._label_exists(clean_label):
| 162 | +            logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
| 163 | +        return clean_label
| 164 | +
| 165 | +    async def has_node(self, node_id: str) -> bool:
| 166 | +        entity_name_label = await self._ensure_label(node_id)
| 167 |         async with self._driver.session(database=self._DATABASE) as session:
| 168 |             query = (
| 169 |                 f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
| 192 |             return single_result["edgeExists"]
| 193 |
| 194 |     async def get_node(self, node_id: str) -> Union[dict, None]:
| 195 | +        """Get node by its label identifier.
| 196 | +
| 197 | +        Args:
| 198 | +            node_id: The node label to look up
| 199 | +
| 200 | +        Returns:
| 201 | +            dict: Node properties if found
| 202 | +            None: If node not found
| 203 | +        """
| 204 |         async with self._driver.session(database=self._DATABASE) as session:
| 205 | +            entity_name_label = await self._ensure_label(node_id)
| 206 |             query = f"MATCH (n:`{entity_name_label}`) RETURN n"
| 207 |             result = await session.run(query)
| 208 |             record = await result.single()
| 253 |     async def get_edge(
| 254 |         self, source_node_id: str, target_node_id: str
| 255 |     ) -> Union[dict, None]:
| 256 | +        """Find edge between two nodes identified by their labels.
| 257 |
| 258 |         Args:
| 259 | +            source_node_id (str): Label of the source node
| 260 | +            target_node_id (str): Label of the target node
| 261 |
| 262 |         Returns:
| 263 | +            dict: Edge properties if found, with at least {"weight": 0.0}
| 264 | +            None: If error occurs
| 265 |         """
| 266 | +        try:
| 267 | +            entity_name_label_source = source_node_id.strip('"')
| 268 | +            entity_name_label_target = target_node_id.strip('"')
| 269 | +
| 270 | +            async with self._driver.session(database=self._DATABASE) as session:
| 271 | +                query = f"""
| 272 | +                MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
| 273 | +                RETURN properties(r) as edge_properties
| 274 | +                LIMIT 1
| 275 | +                """.format(
| 276 | +                    entity_name_label_source=entity_name_label_source,
| 277 | +                    entity_name_label_target=entity_name_label_target,
| 278 | +                )
| 279 | +
| 280 | +                result = await session.run(query)
| 281 | +                record = await result.single()
| 282 | +                if record and "edge_properties" in record:
| 283 | +                    try:
| 284 | +                        result = dict(record["edge_properties"])
| 285 | +                        # Ensure required keys exist with defaults
| 286 | +                        required_keys = {
| 287 | +                            "weight": 0.0,
| 288 | +                            "source_id": None,
| 289 | +                            "target_id": None,
| 290 | +                        }
| 291 | +                        for key, default_value in required_keys.items():
| 292 | +                            if key not in result:
| 293 | +                                result[key] = default_value
| 294 | +                                logger.warning(
| 295 | +                                    f"Edge between {entity_name_label_source} and {entity_name_label_target} "
| 296 | +                                    f"missing {key}, using default: {default_value}"
| 297 | +                                )
| 298 | +
| 299 | +                        logger.debug(
| 300 | +                            f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
| 301 | +                        )
| 302 | +                        return result
| 303 | +                    except (KeyError, TypeError, ValueError) as e:
| 304 | +                        logger.error(
| 305 | +                            f"Error processing edge properties between {entity_name_label_source} "
| 306 | +                            f"and {entity_name_label_target}: {str(e)}"
| 307 | +                        )
| 308 | +                        # Return default edge properties on error
| 309 | +                        return {"weight": 0.0, "source_id": None, "target_id": None}
| 310 |
| 311 |                 logger.debug(
| 312 | +                    f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
| 313 |                 )
| 314 | +                # Return default edge properties when no edge found
| 315 | +                return {"weight": 0.0, "source_id": None, "target_id": None}
| 316 | +
| 317 | +        except Exception as e:
| 318 | +            logger.error(
| 319 | +                f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
| 320 | +            )
| 321 | +            # Return default edge properties on error
| 322 | +            return {"weight": 0.0, "source_id": None, "target_id": None}
| 323 |
| 324 |     async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
| 325 |         node_label = source_node_id.strip('"')
| 372 |             node_id: The unique identifier for the node (used as label)
| 373 |             node_data: Dictionary of node properties
| 374 |         """
| 375 | +        label = await self._ensure_label(node_id)
| 376 |         properties = node_data
| 377 |
| 378 |         async def _do_upsert(tx: AsyncManagedTransaction):
| 400 |                 neo4jExceptions.ServiceUnavailable,
| 401 |                 neo4jExceptions.TransientError,
| 402 |                 neo4jExceptions.WriteServiceUnavailable,
| 403 | +                neo4jExceptions.ClientError,
| 404 |             )
| 405 |         ),
| 406 |     )
| 415 |             target_node_id (str): Label of the target node (used as identifier)
| 416 |             edge_data (dict): Dictionary of properties to set on the edge
| 417 |         """
| 418 | +        source_label = await self._ensure_label(source_node_id)
| 419 | +        target_label = await self._ensure_label(target_node_id)
| 420 |         edge_properties = edge_data
| 421 |
| 422 |         async def _do_upsert_edge(tx: AsyncManagedTransaction):
| 423 |             query = f"""
| 424 | +            MATCH (source:`{source_label}`)
| 425 |             WITH source
| 426 | +            MATCH (target:`{target_label}`)
| 427 |             MERGE (source)-[r:DIRECTED]->(target)
| 428 |             SET r += $properties
| 429 |             RETURN r
| 430 |             """
| 431 | +            result = await tx.run(query, properties=edge_properties)
| 432 | +            record = await result.single()
| 433 |             logger.debug(
| 434 | +                f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
| 435 |             )
| 436 |
| 437 |         try:
try:
|
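A small illustration (not part of the diff) of the behaviour change in `get_edge`: rather than returning None when an edge is missing, it now falls back to a dictionary with default keys, so callers can index it directly. The storage instance below is hypothetical and the call must run inside an async function.

```python
# Sketch of the new get_edge contract: "weight", "source_id" and "target_id"
# are always present, even when the edge does not exist or an error occurred.
edge = await graph_storage.get_edge('"ENTITY_A"', '"ENTITY_B"')  # hypothetical instance
weight = edge["weight"]  # 0.0 when the edge is missing or lacks a weight property
```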
lightrag/lightrag.py
CHANGED
|
@@ -1,10 +1,12 @@
| 1 | import asyncio
| 2 | import os
| 3 | import configparser
| 4 | from dataclasses import asdict, dataclass, field
| 5 | from datetime import datetime
| 6 | from functools import partial
| 7 | -from typing import Any,
| 8 |
| 9 | from .base import (
| 10 |     BaseGraphStorage,
|
|
@@ -76,6 +78,7 @@ STORAGE_IMPLEMENTATIONS = {
|
|
| 76 |
"FaissVectorDBStorage",
|
| 77 |
"QdrantVectorDBStorage",
|
| 78 |
"OracleVectorDBStorage",
|
|
|
|
| 79 |
],
|
| 80 |
"required_methods": ["query", "upsert"],
|
| 81 |
},
|
|
@@ -91,7 +94,7 @@ STORAGE_IMPLEMENTATIONS = {
|
|
| 91 |
}
|
| 92 |
|
| 93 |
# Storage implementation environment variable without default value
|
| 94 |
-
STORAGE_ENV_REQUIREMENTS = {
|
| 95 |
# KV Storage Implementations
|
| 96 |
"JsonKVStorage": [],
|
| 97 |
"MongoKVStorage": [],
|
|
@@ -140,6 +143,7 @@ STORAGE_ENV_REQUIREMENTS = {
|
|
| 140 |
"ORACLE_PASSWORD",
|
| 141 |
"ORACLE_CONFIG_DIR",
|
| 142 |
],
|
|
|
|
| 143 |
# Document Status Storage Implementations
|
| 144 |
"JsonDocStatusStorage": [],
|
| 145 |
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
|
@@ -160,6 +164,7 @@ STORAGES = {
|
|
| 160 |
"MongoKVStorage": ".kg.mongo_impl",
|
| 161 |
"MongoDocStatusStorage": ".kg.mongo_impl",
|
| 162 |
"MongoGraphStorage": ".kg.mongo_impl",
|
|
|
|
| 163 |
"RedisKVStorage": ".kg.redis_impl",
|
| 164 |
"ChromaVectorDBStorage": ".kg.chroma_impl",
|
| 165 |
"TiDBKVStorage": ".kg.tidb_impl",
|
|
@@ -176,7 +181,7 @@ STORAGES = {
|
|
| 176 |
}
|
| 177 |
|
| 178 |
|
| 179 |
-
def lazy_external_import(module_name: str, class_name: str):
|
| 180 |
"""Lazily import a class from an external module based on the package of the caller."""
|
| 181 |
# Get the caller's module and package
|
| 182 |
import inspect
|
|
@@ -185,7 +190,7 @@ def lazy_external_import(module_name: str, class_name: str):
|
|
| 185 |
module = inspect.getmodule(caller_frame)
|
| 186 |
package = module.__package__ if module else None
|
| 187 |
|
| 188 |
-
def import_class(*args, **kwargs):
|
| 189 |
import importlib
|
| 190 |
|
| 191 |
module = importlib.import_module(module_name, package=package)
|
|
@@ -302,7 +307,7 @@ class LightRAG:
|
|
| 302 |
- random_seed: Seed value for reproducibility.
|
| 303 |
"""
|
| 304 |
|
| 305 |
-
embedding_func: EmbeddingFunc = None
|
| 306 |
"""Function for computing text embeddings. Must be set before use."""
|
| 307 |
|
| 308 |
embedding_batch_num: int = 32
|
|
@@ -312,7 +317,7 @@ class LightRAG:
|
|
| 312 |
"""Maximum number of concurrent embedding function calls."""
|
| 313 |
|
| 314 |
# LLM Configuration
|
| 315 |
-
llm_model_func:
|
| 316 |
"""Function for interacting with the large language model (LLM). Must be set before use."""
|
| 317 |
|
| 318 |
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
|
|
@@ -342,10 +347,8 @@ class LightRAG:
|
|
| 342 |
|
| 343 |
# Extensions
|
| 344 |
addon_params: dict[str, Any] = field(default_factory=dict)
|
| 345 |
-
"""Dictionary for additional parameters and extensions."""
|
| 346 |
|
| 347 |
-
|
| 348 |
-
addon_params: dict[str, Any] = field(default_factory=dict)
|
| 349 |
convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
|
| 350 |
convert_response_to_json
|
| 351 |
)
|
|
@@ -354,7 +357,7 @@ class LightRAG:
|
|
| 354 |
chunking_func: Callable[
|
| 355 |
[
|
| 356 |
str,
|
| 357 |
-
|
| 358 |
bool,
|
| 359 |
int,
|
| 360 |
int,
|
|
@@ -443,77 +446,74 @@ class LightRAG:
|
|
| 443 |
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
| 444 |
|
| 445 |
# Init LLM
|
| 446 |
-
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
|
| 447 |
self.embedding_func
|
| 448 |
)
|
| 449 |
|
| 450 |
# Initialize all storages
|
| 451 |
-
self.key_string_value_json_storage_cls:
|
| 452 |
self._get_storage_class(self.kv_storage)
|
| 453 |
-
)
|
| 454 |
-
self.vector_db_storage_cls:
|
| 455 |
self.vector_storage
|
| 456 |
-
)
|
| 457 |
-
self.graph_storage_cls:
|
| 458 |
self.graph_storage
|
| 459 |
-
)
|
| 460 |
-
|
| 461 |
-
self.key_string_value_json_storage_cls = partial(
|
| 462 |
self.key_string_value_json_storage_cls, global_config=global_config
|
| 463 |
)
|
| 464 |
-
|
| 465 |
-
self.vector_db_storage_cls = partial(
|
| 466 |
self.vector_db_storage_cls, global_config=global_config
|
| 467 |
)
|
| 468 |
-
|
| 469 |
-
self.graph_storage_cls = partial(
|
| 470 |
self.graph_storage_cls, global_config=global_config
|
| 471 |
)
|
| 472 |
|
| 473 |
# Initialize document status storage
|
| 474 |
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
|
| 475 |
|
| 476 |
-
self.llm_response_cache = self.key_string_value_json_storage_cls(
|
| 477 |
namespace=make_namespace(
|
| 478 |
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
|
| 479 |
),
|
| 480 |
embedding_func=self.embedding_func,
|
| 481 |
)
|
| 482 |
|
| 483 |
-
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
|
| 484 |
namespace=make_namespace(
|
| 485 |
self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
|
| 486 |
),
|
| 487 |
embedding_func=self.embedding_func,
|
| 488 |
)
|
| 489 |
-
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
|
| 490 |
namespace=make_namespace(
|
| 491 |
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
|
| 492 |
),
|
| 493 |
embedding_func=self.embedding_func,
|
| 494 |
)
|
| 495 |
-
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(
|
| 496 |
namespace=make_namespace(
|
| 497 |
self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
|
| 498 |
),
|
| 499 |
embedding_func=self.embedding_func,
|
| 500 |
)
|
| 501 |
|
| 502 |
-
self.entities_vdb = self.vector_db_storage_cls(
|
| 503 |
namespace=make_namespace(
|
| 504 |
self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
|
| 505 |
),
|
| 506 |
embedding_func=self.embedding_func,
|
| 507 |
meta_fields={"entity_name"},
|
| 508 |
)
|
| 509 |
-
self.relationships_vdb = self.vector_db_storage_cls(
|
| 510 |
namespace=make_namespace(
|
| 511 |
self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
|
| 512 |
),
|
| 513 |
embedding_func=self.embedding_func,
|
| 514 |
meta_fields={"src_id", "tgt_id"},
|
| 515 |
)
|
| 516 |
-
self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
|
| 517 |
namespace=make_namespace(
|
| 518 |
self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
|
| 519 |
),
|
|
@@ -527,13 +527,12 @@ class LightRAG:
|
|
| 527 |
embedding_func=None,
|
| 528 |
)
|
| 529 |
|
| 530 |
-
# What's for, Is this nessisary ?
|
| 531 |
if self.llm_response_cache and hasattr(
|
| 532 |
self.llm_response_cache, "global_config"
|
| 533 |
):
|
| 534 |
hashing_kv = self.llm_response_cache
|
| 535 |
else:
|
| 536 |
-
hashing_kv = self.key_string_value_json_storage_cls(
|
| 537 |
namespace=make_namespace(
|
| 538 |
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
|
| 539 |
),
|
|
@@ -542,7 +541,7 @@ class LightRAG:
|
|
| 542 |
|
| 543 |
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
| 544 |
partial(
|
| 545 |
-
self.llm_model_func,
|
| 546 |
hashing_kv=hashing_kv,
|
| 547 |
**self.llm_model_kwargs,
|
| 548 |
)
|
|
@@ -559,68 +558,45 @@ class LightRAG:
|
|
| 559 |
node_label=nodel_label, max_depth=max_depth
|
| 560 |
)
|
| 561 |
|
| 562 |
-
def _get_storage_class(self, storage_name: str) ->
|
| 563 |
import_path = STORAGES[storage_name]
|
| 564 |
storage_class = lazy_external_import(import_path, storage_name)
|
| 565 |
return storage_class
|
| 566 |
|
| 567 |
-
def set_storage_client(self, db_client):
|
| 568 |
-
# Deprecated, seting correct value to *_storage of LightRAG insteaded
|
| 569 |
-
# Inject db to storage implementation (only tested on Oracle Database)
|
| 570 |
-
for storage in [
|
| 571 |
-
self.vector_db_storage_cls,
|
| 572 |
-
self.graph_storage_cls,
|
| 573 |
-
self.doc_status,
|
| 574 |
-
self.full_docs,
|
| 575 |
-
self.text_chunks,
|
| 576 |
-
self.llm_response_cache,
|
| 577 |
-
self.key_string_value_json_storage_cls,
|
| 578 |
-
self.chunks_vdb,
|
| 579 |
-
self.relationships_vdb,
|
| 580 |
-
self.entities_vdb,
|
| 581 |
-
self.graph_storage_cls,
|
| 582 |
-
self.chunk_entity_relation_graph,
|
| 583 |
-
self.llm_response_cache,
|
| 584 |
-
]:
|
| 585 |
-
# set client
|
| 586 |
-
storage.db = db_client
|
| 587 |
-
|
| 588 |
def insert(
|
| 589 |
self,
|
| 590 |
-
|
| 591 |
split_by_character: str | None = None,
|
| 592 |
split_by_character_only: bool = False,
|
| 593 |
):
|
| 594 |
"""Sync Insert documents with checkpoint support
|
| 595 |
|
| 596 |
Args:
|
| 597 |
-
|
| 598 |
split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
|
| 599 |
-
chunk_size, split the sub chunk by token size.
|
| 600 |
split_by_character_only: if split_by_character_only is True, split the string by character only, when
|
| 601 |
split_by_character is None, this parameter is ignored.
|
| 602 |
"""
|
| 603 |
loop = always_get_an_event_loop()
|
| 604 |
return loop.run_until_complete(
|
| 605 |
-
self.ainsert(
|
| 606 |
)
|
| 607 |
|
| 608 |
async def ainsert(
|
| 609 |
self,
|
| 610 |
-
|
| 611 |
split_by_character: str | None = None,
|
| 612 |
split_by_character_only: bool = False,
|
| 613 |
):
|
| 614 |
"""Async Insert documents with checkpoint support
|
| 615 |
|
| 616 |
Args:
|
| 617 |
-
|
| 618 |
split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
|
| 619 |
-
chunk_size, split the sub chunk by token size.
|
| 620 |
split_by_character_only: if split_by_character_only is True, split the string by character only, when
|
| 621 |
split_by_character is None, this parameter is ignored.
|
| 622 |
"""
|
| 623 |
-
await self.apipeline_enqueue_documents(
|
| 624 |
await self.apipeline_process_enqueue_documents(
|
| 625 |
split_by_character, split_by_character_only
|
| 626 |
)
|
|
@@ -677,7 +653,7 @@ class LightRAG:
|
|
| 677 |
if update_storage:
|
| 678 |
await self._insert_done()
|
| 679 |
|
| 680 |
-
async def apipeline_enqueue_documents(self,
|
| 681 |
"""
|
| 682 |
Pipeline for Processing Documents
|
| 683 |
|
|
@@ -686,11 +662,11 @@ class LightRAG:
|
|
| 686 |
3. Filter out already processed documents
|
| 687 |
4. Enqueue document in status
|
| 688 |
"""
|
| 689 |
-
if isinstance(
|
| 690 |
-
|
| 691 |
|
| 692 |
# 1. Remove duplicate contents from the list
|
| 693 |
-
unique_contents = list(set(doc.strip() for doc in
|
| 694 |
|
| 695 |
# 2. Generate document IDs and initial status
|
| 696 |
new_docs: dict[str, Any] = {
|
|
@@ -857,32 +833,32 @@ class LightRAG:
|
|
| 857 |
raise e
|
| 858 |
|
| 859 |
async def _insert_done(self):
|
| 860 |
-
tasks = [
|
| 861 |
-
|
| 862 |
-
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
|
| 866 |
-
|
| 867 |
-
|
| 868 |
-
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
|
| 873 |
await asyncio.gather(*tasks)
|
| 874 |
|
| 875 |
-
def insert_custom_kg(self, custom_kg: dict):
|
| 876 |
loop = always_get_an_event_loop()
|
| 877 |
return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
|
| 878 |
|
| 879 |
-
async def ainsert_custom_kg(self, custom_kg: dict):
|
| 880 |
update_storage = False
|
| 881 |
try:
|
| 882 |
# Insert chunks into vector storage
|
| 883 |
-
all_chunks_data = {}
|
| 884 |
-
chunk_to_source_map = {}
|
| 885 |
-
for chunk_data in custom_kg.get("chunks",
|
| 886 |
chunk_content = chunk_data["content"]
|
| 887 |
source_id = chunk_data["source_id"]
|
| 888 |
chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
|
|
@@ -892,13 +868,13 @@ class LightRAG:
|
|
| 892 |
chunk_to_source_map[source_id] = chunk_id
|
| 893 |
update_storage = True
|
| 894 |
|
| 895 |
-
if
|
| 896 |
await self.chunks_vdb.upsert(all_chunks_data)
|
| 897 |
-
if
|
| 898 |
await self.text_chunks.upsert(all_chunks_data)
|
| 899 |
|
| 900 |
# Insert entities into knowledge graph
|
| 901 |
-
all_entities_data = []
|
| 902 |
for entity_data in custom_kg.get("entities", []):
|
| 903 |
entity_name = f'"{entity_data["entity_name"].upper()}"'
|
| 904 |
entity_type = entity_data.get("entity_type", "UNKNOWN")
|
|
@@ -914,7 +890,7 @@ class LightRAG:
|
|
| 914 |
)
|
| 915 |
|
| 916 |
# Prepare node data
|
| 917 |
-
node_data = {
|
| 918 |
"entity_type": entity_type,
|
| 919 |
"description": description,
|
| 920 |
"source_id": source_id,
|
|
@@ -928,7 +904,7 @@ class LightRAG:
|
|
| 928 |
update_storage = True
|
| 929 |
|
| 930 |
# Insert relationships into knowledge graph
|
| 931 |
-
all_relationships_data = []
|
| 932 |
for relationship_data in custom_kg.get("relationships", []):
|
| 933 |
src_id = f'"{relationship_data["src_id"].upper()}"'
|
| 934 |
tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
|
|
@@ -970,7 +946,7 @@ class LightRAG:
|
|
| 970 |
"source_id": source_id,
|
| 971 |
},
|
| 972 |
)
|
| 973 |
-
edge_data = {
|
| 974 |
"src_id": src_id,
|
| 975 |
"tgt_id": tgt_id,
|
| 976 |
"description": description,
|
|
@@ -980,41 +956,68 @@ class LightRAG:
|
|
| 980 |
update_storage = True
|
| 981 |
|
| 982 |
# Insert entities into vector storage if needed
|
| 983 |
-
|
| 984 |
-
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
"entity_name": dp["entity_name"],
|
| 988 |
-
}
|
| 989 |
-
for dp in all_entities_data
|
| 990 |
}
|
| 991 |
-
|
|
|
|
|
|
|
| 992 |
|
| 993 |
# Insert relationships into vector storage if needed
|
| 994 |
-
|
| 995 |
-
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
|
| 999 |
-
|
| 1000 |
-
|
| 1001 |
-
|
| 1002 |
-
+ dp["description"],
|
| 1003 |
-
}
|
| 1004 |
-
for dp in all_relationships_data
|
| 1005 |
}
|
| 1006 |
-
|
|
|
|
|
|
|
|
|
|
| 1007 |
finally:
|
| 1008 |
if update_storage:
|
| 1009 |
await self._insert_done()
|
| 1010 |
|
| 1011 |
-
def query(
|
|
|
|
|
|
|
|
|
| 1012 |
loop = always_get_an_event_loop()
|
| 1013 |
-
|
|
|
|
| 1014 |
|
| 1015 |
async def aquery(
|
| 1016 |
-
self,
|
| 1017 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1018 |
if param.mode in ["local", "global", "hybrid"]:
|
| 1019 |
response = await kg_query(
|
| 1020 |
query,
|
|
@@ -1094,7 +1097,7 @@ class LightRAG:
|
|
| 1094 |
|
| 1095 |
async def aquery_with_separate_keyword_extraction(
|
| 1096 |
self, query: str, prompt: str, param: QueryParam = QueryParam()
|
| 1097 |
-
):
|
| 1098 |
"""
|
| 1099 |
1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
|
| 1100 |
2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
|
|
@@ -1117,8 +1120,8 @@ class LightRAG:
|
|
| 1117 |
),
|
| 1118 |
)
|
| 1119 |
|
| 1120 |
-
param.hl_keywords =
|
| 1121 |
-
param.ll_keywords =
|
| 1122 |
|
| 1123 |
# ---------------------
|
| 1124 |
# STEP 2: Final Query Logic
|
|
@@ -1146,7 +1149,7 @@ class LightRAG:
|
|
| 1146 |
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
|
| 1147 |
),
|
| 1148 |
global_config=asdict(self),
|
| 1149 |
-
embedding_func=self.
|
| 1150 |
),
|
| 1151 |
)
|
| 1152 |
elif param.mode == "naive":
|
|
@@ -1195,12 +1198,7 @@ class LightRAG:
|
|
| 1195 |
return response
|
| 1196 |
|
| 1197 |
async def _query_done(self):
|
| 1198 |
-
|
| 1199 |
-
for storage_inst in [self.llm_response_cache]:
|
| 1200 |
-
if storage_inst is None:
|
| 1201 |
-
continue
|
| 1202 |
-
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
| 1203 |
-
await asyncio.gather(*tasks)
|
| 1204 |
|
| 1205 |
def delete_by_entity(self, entity_name: str):
|
| 1206 |
loop = always_get_an_event_loop()
|
|
@@ -1222,16 +1220,16 @@ class LightRAG:
|
|
| 1222 |
logger.error(f"Error while deleting entity '{entity_name}': {e}")
|
| 1223 |
|
| 1224 |
async def _delete_by_entity_done(self):
|
| 1225 |
-
|
| 1226 |
-
|
| 1227 |
-
|
| 1228 |
-
|
| 1229 |
-
|
| 1230 |
-
|
| 1231 |
-
|
| 1232 |
-
|
| 1233 |
-
|
| 1234 |
-
|
| 1235 |
|
| 1236 |
def _get_content_summary(self, content: str, max_length: int = 100) -> str:
|
| 1237 |
"""Get summary of document content
|
|
@@ -1256,7 +1254,7 @@ class LightRAG:
|
|
| 1256 |
"""
|
| 1257 |
return await self.doc_status.get_status_counts()
|
| 1258 |
|
| 1259 |
-
async def adelete_by_doc_id(self, doc_id: str):
|
| 1260 |
"""Delete a document and all its related data
|
| 1261 |
|
| 1262 |
Args:
|
|
@@ -1273,6 +1271,9 @@ class LightRAG:
|
|
| 1273 |
|
| 1274 |
# 2. Get all related chunks
|
| 1275 |
chunks = await self.text_chunks.get_by_id(doc_id)
|
|
|
|
|
|
|
|
|
|
| 1276 |
chunk_ids = list(chunks.keys())
|
| 1277 |
logger.debug(f"Found {len(chunk_ids)} chunks to delete")
|
| 1278 |
|
|
@@ -1443,13 +1444,9 @@ class LightRAG:
|
|
| 1443 |
except Exception as e:
|
| 1444 |
logger.error(f"Error while deleting document {doc_id}: {e}")
|
| 1445 |
|
| 1446 |
-
def delete_by_doc_id(self, doc_id: str):
|
| 1447 |
-
"""Synchronous version of adelete"""
|
| 1448 |
-
return asyncio.run(self.adelete_by_doc_id(doc_id))
|
| 1449 |
-
|
| 1450 |
async def get_entity_info(
|
| 1451 |
self, entity_name: str, include_vector_data: bool = False
|
| 1452 |
-
):
|
| 1453 |
"""Get detailed information of an entity
|
| 1454 |
|
| 1455 |
Args:
|
|
@@ -1469,7 +1466,7 @@ class LightRAG:
|
|
| 1469 |
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
|
| 1470 |
source_id = node_data.get("source_id") if node_data else None
|
| 1471 |
|
| 1472 |
-
result = {
|
| 1473 |
"entity_name": entity_name,
|
| 1474 |
"source_id": source_id,
|
| 1475 |
"graph_data": node_data,
|
|
@@ -1483,21 +1480,6 @@ class LightRAG:
|
|
| 1483 |
|
| 1484 |
return result
|
| 1485 |
|
| 1486 |
-
def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False):
|
| 1487 |
-
"""Synchronous version of getting entity information
|
| 1488 |
-
|
| 1489 |
-
Args:
|
| 1490 |
-
entity_name: Entity name (no need for quotes)
|
| 1491 |
-
include_vector_data: Whether to include data from the vector database
|
| 1492 |
-
"""
|
| 1493 |
-
try:
|
| 1494 |
-
import tracemalloc
|
| 1495 |
-
|
| 1496 |
-
tracemalloc.start()
|
| 1497 |
-
return asyncio.run(self.get_entity_info(entity_name, include_vector_data))
|
| 1498 |
-
finally:
|
| 1499 |
-
tracemalloc.stop()
|
| 1500 |
-
|
| 1501 |
async def get_relation_info(
|
| 1502 |
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
|
| 1503 |
):
|
|
@@ -1525,7 +1507,7 @@ class LightRAG:
|
|
| 1525 |
)
|
| 1526 |
source_id = edge_data.get("source_id") if edge_data else None
|
| 1527 |
|
| 1528 |
-
result = {
|
| 1529 |
"src_entity": src_entity,
|
| 1530 |
"tgt_entity": tgt_entity,
|
| 1531 |
"source_id": source_id,
|
|
@@ -1539,23 +1521,3 @@ class LightRAG:
|
|
| 1539 |
result["vector_data"] = vector_data[0] if vector_data else None
|
| 1540 |
|
| 1541 |
return result
|
| 1542 |
-
|
| 1543 |
-
def get_relation_info_sync(
|
| 1544 |
-
self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
|
| 1545 |
-
):
|
| 1546 |
-
"""Synchronous version of getting relationship information
|
| 1547 |
-
|
| 1548 |
-
Args:
|
| 1549 |
-
src_entity: Source entity name (no need for quotes)
|
| 1550 |
-
tgt_entity: Target entity name (no need for quotes)
|
| 1551 |
-
include_vector_data: Whether to include data from the vector database
|
| 1552 |
-
"""
|
| 1553 |
-
try:
|
| 1554 |
-
import tracemalloc
|
| 1555 |
-
|
| 1556 |
-
tracemalloc.start()
|
| 1557 |
-
return asyncio.run(
|
| 1558 |
-
self.get_relation_info(src_entity, tgt_entity, include_vector_data)
|
| 1559 |
-
)
|
| 1560 |
-
finally:
|
| 1561 |
-
tracemalloc.stop()
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
import asyncio
|
| 4 |
import os
|
| 5 |
import configparser
|
| 6 |
from dataclasses import asdict, dataclass, field
|
| 7 |
from datetime import datetime
|
| 8 |
from functools import partial
|
| 9 |
+
from typing import Any, AsyncIterator, Callable, Iterator, cast
|
| 10 |
|
| 11 |
from .base import (
|
| 12 |
BaseGraphStorage,
|
|
|
|
| 78 |
"FaissVectorDBStorage",
|
| 79 |
"QdrantVectorDBStorage",
|
| 80 |
"OracleVectorDBStorage",
|
| 81 |
+
"MongoVectorDBStorage",
|
| 82 |
],
|
| 83 |
"required_methods": ["query", "upsert"],
|
| 84 |
},
|
|
|
|
| 94 |
}
|
| 95 |
|
| 96 |
# Storage implementation environment variable without default value
|
| 97 |
+
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
|
| 98 |
# KV Storage Implementations
|
| 99 |
"JsonKVStorage": [],
|
| 100 |
"MongoKVStorage": [],
|
|
|
|
| 143 |
"ORACLE_PASSWORD",
|
| 144 |
"ORACLE_CONFIG_DIR",
|
| 145 |
],
|
| 146 |
+
"MongoVectorDBStorage": [],
|
| 147 |
# Document Status Storage Implementations
|
| 148 |
"JsonDocStatusStorage": [],
|
| 149 |
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
|
|
|
|
| 164 |
"MongoKVStorage": ".kg.mongo_impl",
|
| 165 |
"MongoDocStatusStorage": ".kg.mongo_impl",
|
| 166 |
"MongoGraphStorage": ".kg.mongo_impl",
|
| 167 |
+
"MongoVectorDBStorage": ".kg.mongo_impl",
|
| 168 |
"RedisKVStorage": ".kg.redis_impl",
|
| 169 |
"ChromaVectorDBStorage": ".kg.chroma_impl",
|
| 170 |
"TiDBKVStorage": ".kg.tidb_impl",
|
|
|
|
| 181 |
}
|
| 182 |
|
| 183 |
|
| 184 |
+
def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
|
| 185 |
"""Lazily import a class from an external module based on the package of the caller."""
|
| 186 |
# Get the caller's module and package
|
| 187 |
import inspect
|
|
|
|
| 190 |
module = inspect.getmodule(caller_frame)
|
| 191 |
package = module.__package__ if module else None
|
| 192 |
|
| 193 |
+
def import_class(*args: Any, **kwargs: Any):
|
| 194 |
import importlib
|
| 195 |
|
| 196 |
module = importlib.import_module(module_name, package=package)
|
|
|
|
| 307 |
- random_seed: Seed value for reproducibility.
|
| 308 |
"""
|
| 309 |
|
| 310 |
+
embedding_func: EmbeddingFunc | None = None
|
| 311 |
"""Function for computing text embeddings. Must be set before use."""
|
| 312 |
|
| 313 |
embedding_batch_num: int = 32
|
|
|
|
| 317 |
"""Maximum number of concurrent embedding function calls."""
|
| 318 |
|
| 319 |
# LLM Configuration
|
| 320 |
+
llm_model_func: Callable[..., object] | None = None
|
| 321 |
"""Function for interacting with the large language model (LLM). Must be set before use."""
|
| 322 |
|
| 323 |
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
|
|
|
|
| 347 |
|
| 348 |
# Extensions
|
| 349 |
addon_params: dict[str, Any] = field(default_factory=dict)
|
|
|
|
| 350 |
|
| 351 |
+
"""Dictionary for additional parameters and extensions."""
|
|
|
|
| 352 |
convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
|
| 353 |
convert_response_to_json
|
| 354 |
)
|
|
|
|
| 357 |
chunking_func: Callable[
|
| 358 |
[
|
| 359 |
str,
|
| 360 |
+
str | None,
|
| 361 |
bool,
|
| 362 |
int,
|
| 363 |
int,
|
|
|
|
| 446 |
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
| 447 |
|
| 448 |
# Init LLM
|
| 449 |
+
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
|
| 450 |
self.embedding_func
|
| 451 |
)
|
| 452 |
|
| 453 |
# Initialize all storages
|
| 454 |
+
self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
|
| 455 |
self._get_storage_class(self.kv_storage)
|
| 456 |
+
) # type: ignore
|
| 457 |
+
self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
|
| 458 |
self.vector_storage
|
| 459 |
+
) # type: ignore
|
| 460 |
+
self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
|
| 461 |
self.graph_storage
|
| 462 |
+
) # type: ignore
|
| 463 |
+
self.key_string_value_json_storage_cls = partial( # type: ignore
|
|
|
|
| 464 |
self.key_string_value_json_storage_cls, global_config=global_config
|
| 465 |
)
|
| 466 |
+
self.vector_db_storage_cls = partial( # type: ignore
|
|
|
|
| 467 |
self.vector_db_storage_cls, global_config=global_config
|
| 468 |
)
|
| 469 |
+
self.graph_storage_cls = partial( # type: ignore
|
|
|
|
| 470 |
self.graph_storage_cls, global_config=global_config
|
| 471 |
)
|
| 472 |
|
| 473 |
# Initialize document status storage
|
| 474 |
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
|
| 475 |
|
| 476 |
+
self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
|
| 477 |
namespace=make_namespace(
|
| 478 |
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
|
| 479 |
),
|
| 480 |
embedding_func=self.embedding_func,
|
| 481 |
)
|
| 482 |
|
| 483 |
+
self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
|
| 484 |
namespace=make_namespace(
|
| 485 |
self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
|
| 486 |
),
|
| 487 |
embedding_func=self.embedding_func,
|
| 488 |
)
|
| 489 |
+
self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
|
| 490 |
namespace=make_namespace(
|
| 491 |
self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
|
| 492 |
),
|
| 493 |
embedding_func=self.embedding_func,
|
| 494 |
)
|
| 495 |
+
self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
|
| 496 |
namespace=make_namespace(
|
| 497 |
self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
|
| 498 |
),
|
| 499 |
embedding_func=self.embedding_func,
|
| 500 |
)
|
| 501 |
|
| 502 |
+
self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
|
| 503 |
namespace=make_namespace(
|
| 504 |
self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
|
| 505 |
),
|
| 506 |
embedding_func=self.embedding_func,
|
| 507 |
meta_fields={"entity_name"},
|
| 508 |
)
|
| 509 |
+
self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
|
| 510 |
namespace=make_namespace(
|
| 511 |
self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
|
| 512 |
),
|
| 513 |
embedding_func=self.embedding_func,
|
| 514 |
meta_fields={"src_id", "tgt_id"},
|
| 515 |
)
|
| 516 |
+
self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
|
| 517 |
namespace=make_namespace(
|
| 518 |
self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
|
| 519 |
),
|
|
|
|
| 527 |
embedding_func=None,
|
| 528 |
)
|
| 529 |
|
|
|
|
| 530 |
if self.llm_response_cache and hasattr(
|
| 531 |
self.llm_response_cache, "global_config"
|
| 532 |
):
|
| 533 |
hashing_kv = self.llm_response_cache
|
| 534 |
else:
|
| 535 |
+
hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
|
| 536 |
namespace=make_namespace(
|
| 537 |
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
|
| 538 |
),
|
|
|
|
| 541 |
|
| 542 |
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
| 543 |
partial(
|
| 544 |
+
self.llm_model_func, # type: ignore
|
| 545 |
hashing_kv=hashing_kv,
|
| 546 |
**self.llm_model_kwargs,
|
| 547 |
)
|
|
|
|
| 558 |
node_label=nodel_label, max_depth=max_depth
|
| 559 |
)
|
| 560 |
|
| 561 |
+
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
|
| 562 |
import_path = STORAGES[storage_name]
|
| 563 |
storage_class = lazy_external_import(import_path, storage_name)
|
| 564 |
return storage_class
|
| 565 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
def insert(
|
| 567 |
self,
|
| 568 |
+
input: str | list[str],
|
| 569 |
split_by_character: str | None = None,
|
| 570 |
split_by_character_only: bool = False,
|
| 571 |
):
|
| 572 |
"""Sync Insert documents with checkpoint support
|
| 573 |
|
| 574 |
Args:
|
| 575 |
+
input: Single document string or list of document strings
|
| 576 |
split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
|
|
|
|
| 577 |
split_by_character_only: if split_by_character_only is True, split the string by character only, when
|
| 578 |
split_by_character is None, this parameter is ignored.
|
| 579 |
"""
|
| 580 |
loop = always_get_an_event_loop()
|
| 581 |
return loop.run_until_complete(
|
| 582 |
+
self.ainsert(input, split_by_character, split_by_character_only)
|
| 583 |
)
|
| 584 |
|
| 585 |
async def ainsert(
|
| 586 |
self,
|
| 587 |
+
input: str | list[str],
|
| 588 |
split_by_character: str | None = None,
|
| 589 |
split_by_character_only: bool = False,
|
| 590 |
):
|
| 591 |
"""Async Insert documents with checkpoint support
|
| 592 |
|
| 593 |
Args:
|
| 594 |
+
input: Single document string or list of document strings
|
| 595 |
split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
|
|
|
|
| 596 |
split_by_character_only: if split_by_character_only is True, split the string by character only, when
|
| 597 |
split_by_character is None, this parameter is ignored.
|
| 598 |
"""
|
| 599 |
+
await self.apipeline_enqueue_documents(input)
|
| 600 |
await self.apipeline_process_enqueue_documents(
|
| 601 |
split_by_character, split_by_character_only
|
| 602 |
)
|
|
|
|
| 653 |
if update_storage:
|
| 654 |
await self._insert_done()
|
| 655 |
|
| 656 |
+
async def apipeline_enqueue_documents(self, input: str | list[str]):
|
| 657 |
"""
|
| 658 |
Pipeline for Processing Documents
|
| 659 |
|
|
|
|
| 662 |
3. Filter out already processed documents
|
| 663 |
4. Enqueue document in status
|
| 664 |
"""
|
| 665 |
+
if isinstance(input, str):
|
| 666 |
+
input = [input]
|
| 667 |
|
| 668 |
# 1. Remove duplicate contents from the list
|
| 669 |
+
unique_contents = list(set(doc.strip() for doc in input))
|
| 670 |
|
| 671 |
# 2. Generate document IDs and initial status
|
| 672 |
new_docs: dict[str, Any] = {
|
|
|
|
| 833 |         raise e
| 834 |
| 835 |     async def _insert_done(self):
| 836 | +       tasks = [
| 837 | +           cast(StorageNameSpace, storage_inst).index_done_callback()
| 838 | +           for storage_inst in [ # type: ignore
| 839 | +               self.full_docs,
| 840 | +               self.text_chunks,
| 841 | +               self.llm_response_cache,
| 842 | +               self.entities_vdb,
| 843 | +               self.relationships_vdb,
| 844 | +               self.chunks_vdb,
| 845 | +               self.chunk_entity_relation_graph,
| 846 | +           ]
| 847 | +           if storage_inst is not None
| 848 | +       ]
| 849 |         await asyncio.gather(*tasks)
| 850 |
| 851 | +   def insert_custom_kg(self, custom_kg: dict[str, Any]):
| 852 |         loop = always_get_an_event_loop()
| 853 |         return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
| 854 |
| 855 | +   async def ainsert_custom_kg(self, custom_kg: dict[str, Any]):
| 856 |         update_storage = False
| 857 |         try:
| 858 |             # Insert chunks into vector storage
| 859 | +           all_chunks_data: dict[str, dict[str, str]] = {}
| 860 | +           chunk_to_source_map: dict[str, str] = {}
| 861 | +           for chunk_data in custom_kg.get("chunks", {}):
| 862 |                 chunk_content = chunk_data["content"]
| 863 |                 source_id = chunk_data["source_id"]
| 864 |                 chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")

| 868 |                 chunk_to_source_map[source_id] = chunk_id
| 869 |                 update_storage = True
| 870 |
| 871 | +           if all_chunks_data:
| 872 |                 await self.chunks_vdb.upsert(all_chunks_data)
| 873 | +           if all_chunks_data:
| 874 |                 await self.text_chunks.upsert(all_chunks_data)
| 875 |
| 876 |             # Insert entities into knowledge graph
| 877 | +           all_entities_data: list[dict[str, str]] = []
| 878 |             for entity_data in custom_kg.get("entities", []):
| 879 |                 entity_name = f'"{entity_data["entity_name"].upper()}"'
| 880 |                 entity_type = entity_data.get("entity_type", "UNKNOWN")

| 890 |                 )
| 891 |
| 892 |                 # Prepare node data
| 893 | +               node_data: dict[str, str] = {
| 894 |                     "entity_type": entity_type,
| 895 |                     "description": description,
| 896 |                     "source_id": source_id,

| 904 |                 update_storage = True
| 905 |
| 906 |             # Insert relationships into knowledge graph
| 907 | +           all_relationships_data: list[dict[str, str]] = []
| 908 |             for relationship_data in custom_kg.get("relationships", []):
| 909 |                 src_id = f'"{relationship_data["src_id"].upper()}"'
| 910 |                 tgt_id = f'"{relationship_data["tgt_id"].upper()}"'

| 946 |                         "source_id": source_id,
| 947 |                     },
| 948 |                 )
| 949 | +               edge_data: dict[str, str] = {
| 950 |                     "src_id": src_id,
| 951 |                     "tgt_id": tgt_id,
| 952 |                     "description": description,
| 956 |                 update_storage = True
| 957 |
| 958 |             # Insert entities into vector storage if needed
| 959 | +           data_for_vdb = {
| 960 | +               compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
| 961 | +                   "content": dp["entity_name"] + dp["description"],
| 962 | +                   "entity_name": dp["entity_name"],
| 963 |                 }
| 964 | +               for dp in all_entities_data
| 965 | +           }
| 966 | +           await self.entities_vdb.upsert(data_for_vdb)
| 967 |
| 968 |             # Insert relationships into vector storage if needed
| 969 | +           data_for_vdb = {
| 970 | +               compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
| 971 | +                   "src_id": dp["src_id"],
| 972 | +                   "tgt_id": dp["tgt_id"],
| 973 | +                   "content": dp["keywords"]
| 974 | +                   + dp["src_id"]
| 975 | +                   + dp["tgt_id"]
| 976 | +                   + dp["description"],
| 977 |                 }
| 978 | +               for dp in all_relationships_data
| 979 | +           }
| 980 | +           await self.relationships_vdb.upsert(data_for_vdb)
| 981 | +
| 982 |         finally:
| 983 |             if update_storage:
| 984 |                 await self._insert_done()
| 985 |
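For orientation, a hypothetical `custom_kg` payload that matches the fields read above ("chunks", "entities", "relationships" and their keys); the real schema may accept or require additional fields, and the values are purely illustrative:

```python
# Illustrative only: key names mirror what ainsert_custom_kg reads in this diff.
custom_kg = {
    "chunks": [
        {"content": "Acme Corp acquired Widget Labs in 2021.", "source_id": "doc-acme-1"},
    ],
    "entities": [
        {
            "entity_name": "Acme Corp",
            "entity_type": "ORGANIZATION",
            "description": "A fictional manufacturing company.",
            "source_id": "doc-acme-1",
        },
    ],
    "relationships": [
        {
            "src_id": "Acme Corp",
            "tgt_id": "Widget Labs",
            "description": "Acme Corp acquired Widget Labs.",
            "keywords": "acquisition",
            "weight": 1.0,  # assumption: may be optional
            "source_id": "doc-acme-1",
        },
    ],
}

# rag.insert_custom_kg(custom_kg)          # sync wrapper shown above
# await rag.ainsert_custom_kg(custom_kg)   # async variant
```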
| 986  | + def query(
| 987  | +     self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
| 988  | + ) -> str | Iterator[str]:
| 989  | +     """
| 990  | +     Perform a sync query.
| 991  | +
| 992  | +     Args:
| 993  | +         query (str): The query to be executed.
| 994  | +         param (QueryParam): Configuration parameters for query execution.
| 995  | +         prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
| 996  | +
| 997  | +     Returns:
| 998  | +         str: The result of the query execution.
| 999  | +     """
| 1000 |       loop = always_get_an_event_loop()
| 1001 | +
| 1002 | +     return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore
| 1003 |
| 1004 |   async def aquery(
| 1005 | +     self,
| 1006 | +     query: str,
| 1007 | +     param: QueryParam = QueryParam(),
| 1008 | +     prompt: str | None = None,
| 1009 | + ) -> str | AsyncIterator[str]:
| 1010 | +     """
| 1011 | +     Perform a async query.
| 1012 | +
| 1013 | +     Args:
| 1014 | +         query (str): The query to be executed.
| 1015 | +         param (QueryParam): Configuration parameters for query execution.
| 1016 | +         prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
| 1017 | +
| 1018 | +     Returns:
| 1019 | +         str: The result of the query execution.
| 1020 | +     """
| 1021 |       if param.mode in ["local", "global", "hybrid"]:
| 1022 |           response = await kg_query(
| 1023 |               query,
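The new `prompt` argument threads a custom system prompt through to response generation. A minimal sketch, assuming a configured `LightRAG` instance `rag`; the prompt text is illustrative and the import path reflects common usage of the package:

```python
from lightrag import QueryParam

custom_prompt = (
    "You are a cautious analyst. Answer only from the retrieved context "
    "and say 'insufficient context' when the answer is not supported."
)

# Synchronous wrapper; internally runs aquery on an event loop.
answer = rag.query(
    "What are the top themes in this corpus?",
    param=QueryParam(mode="hybrid"),
    prompt=custom_prompt,  # None falls back to PROMPTS["rag_response"]
)
print(answer)
```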
| 1097 |
| 1098 |   async def aquery_with_separate_keyword_extraction(
| 1099 |       self, query: str, prompt: str, param: QueryParam = QueryParam()
| 1100 | + ) -> str | AsyncIterator[str]:
| 1101 |       """
| 1102 |       1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
| 1103 |       2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.

| 1120 |           ),
| 1121 |       )
| 1122 |
| 1123 | +     param.hl_keywords = hl_keywords
| 1124 | +     param.ll_keywords = ll_keywords
| 1125 |
| 1126 |       # ---------------------
| 1127 |       # STEP 2: Final Query Logic
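A sketch of calling this two-step entry point directly, again assuming a configured `rag` instance; the query and prompt strings are placeholders:

```python
import asyncio
from lightrag import QueryParam

async def ask(rag) -> str:
    # Keywords are extracted first, stored on the QueryParam, then the main query runs.
    return await rag.aquery_with_separate_keyword_extraction(
        query="How did the merger affect staffing levels?",
        prompt="Answer in three bullet points, citing the source chunks.",
        param=QueryParam(mode="hybrid"),
    )

# asyncio.run(ask(rag))
```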
| 1149 |               self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
| 1150 |           ),
| 1151 |           global_config=asdict(self),
| 1152 | +         embedding_func=self.embedding_func,
| 1153 |       ),
| 1154 |   )
| 1155 |   elif param.mode == "naive":

| 1198 |       return response
| 1199 |
| 1200 |   async def _query_done(self):
| 1201 | +     await self.llm_response_cache.index_done_callback()
| 1202 |
| 1203 |   def delete_by_entity(self, entity_name: str):
| 1204 |       loop = always_get_an_event_loop()

| 1220 |       logger.error(f"Error while deleting entity '{entity_name}': {e}")
| 1221 |
| 1222 |   async def _delete_by_entity_done(self):
| 1223 | +     await asyncio.gather(
| 1224 | +         *[
| 1225 | +             cast(StorageNameSpace, storage_inst).index_done_callback()
| 1226 | +             for storage_inst in [ # type: ignore
| 1227 | +                 self.entities_vdb,
| 1228 | +                 self.relationships_vdb,
| 1229 | +                 self.chunk_entity_relation_graph,
| 1230 | +             ]
| 1231 | +         ]
| 1232 | +     )
| 1233 |
| 1234 |   def _get_content_summary(self, content: str, max_length: int = 100) -> str:
| 1235 |       """Get summary of document content

| 1254 |       """
| 1255 |       return await self.doc_status.get_status_counts()
| 1256 |
| 1257 | + async def adelete_by_doc_id(self, doc_id: str) -> None:
| 1258 |       """Delete a document and all its related data
| 1259 |
| 1260 |       Args:

| 1271 |
| 1272 |       # 2. Get all related chunks
| 1273 |       chunks = await self.text_chunks.get_by_id(doc_id)
| 1274 | +     if not chunks:
| 1275 | +         return
| 1276 | +
| 1277 |       chunk_ids = list(chunks.keys())
| 1278 |       logger.debug(f"Found {len(chunk_ids)} chunks to delete")
| 1279 |

| 1444 |       except Exception as e:
| 1445 |           logger.error(f"Error while deleting document {doc_id}: {e}")
| 1446 |
| 1447 |   async def get_entity_info(
| 1448 |       self, entity_name: str, include_vector_data: bool = False
| 1449 | + ) -> dict[str, str | None | dict[str, str]]:
| 1450 |       """Get detailed information of an entity
| 1451 |
| 1452 |       Args:

| 1466 |       node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
| 1467 |       source_id = node_data.get("source_id") if node_data else None
| 1468 |
| 1469 | +     result: dict[str, str | None | dict[str, str]] = {
| 1470 |           "entity_name": entity_name,
| 1471 |           "source_id": source_id,
| 1472 |           "graph_data": node_data,

| 1480 |
| 1481 |       return result
| 1482 |

| 1483 |   async def get_relation_info(
| 1484 |       self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
| 1485 |   ):

| 1507 |       )
| 1508 |       source_id = edge_data.get("source_id") if edge_data else None
| 1509 |
| 1510 | +     result: dict[str, str | None | dict[str, str]] = {
| 1511 |           "src_entity": src_entity,
| 1512 |           "tgt_entity": tgt_entity,
| 1513 |           "source_id": source_id,

| 1521 |       result["vector_data"] = vector_data[0] if vector_data else None
| 1522 |
| 1523 |       return result
lightrag/llm.py
CHANGED

@@ -1,4 +1,6 @@
| 1 | - from
| 2 |   from pydantic import BaseModel, Field
| 3 |
| 4 |

@@ -23,7 +25,7 @@ class Model(BaseModel):
| 23 |       ...,
| 24 |       description="A function that generates the response from the llm. The response must be a string",
| 25 |   )
| 26 | - kwargs:
| 27 |       ...,
| 28 |       description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
| 29 |   )

@@ -57,7 +59,7 @@ class MultiModel:
| 57 |   ```
| 58 |   """
| 59 |
| 60 | - def __init__(self, models:
| 61 |       self._models = models
| 62 |       self._current_model = 0
| 63 |

@@ -66,7 +68,11 @@ class MultiModel:
| 66 |       return self._models[self._current_model]
| 67 |
| 68 |   async def llm_model_func(
| 69 | -     self,
| 70 |   ) -> str:
| 71 |       kwargs.pop("model", None) # stop from overwriting the custom model name
| 72 |       kwargs.pop("keyword_extraction", None)

| 1  | + from __future__ import annotations
| 2  | +
| 3  | + from typing import Callable, Any
| 4  |   from pydantic import BaseModel, Field
| 5  |
| 6  |

| 25 |       ...,
| 26 |       description="A function that generates the response from the llm. The response must be a string",
| 27 |   )
| 28 | + kwargs: dict[str, Any] = Field(
| 29 |       ...,
| 30 |       description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
| 31 |   )

| 59 |   ```
| 60 |   """
| 61 |
| 62 | + def __init__(self, models: list[Model]):
| 63 |       self._models = models
| 64 |       self._current_model = 0
| 65 |

| 68 |       return self._models[self._current_model]
| 69 |
| 70 |   async def llm_model_func(
| 71 | +     self,
| 72 | +     prompt: str,
| 73 | +     system_prompt: str | None = None,
| 74 | +     history_messages: list[dict[str, Any]] = [],
| 75 | +     **kwargs: Any,
| 76 |   ) -> str:
| 77 |       kwargs.pop("model", None) # stop from overwriting the custom model name
| 78 |       kwargs.pop("keyword_extraction", None)
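For context, a sketch of wiring several API keys through `MultiModel` under the updated annotations. The generation-function field name (`gen_func`) and the helper `gpt_4o_mini_complete` are assumptions here, since only the field description is visible in this hunk; check lightrag/llm.py for the exact names:

```python
# Hypothetical sketch -- field and helper names are assumptions, not confirmed by this diff.
from lightrag.llm import Model, MultiModel, gpt_4o_mini_complete

models = [
    Model(gen_func=gpt_4o_mini_complete, kwargs={"api_key": "sk-key-1"}),
    Model(gen_func=gpt_4o_mini_complete, kwargs={"api_key": "sk-key-2"}),
]
multi = MultiModel(models)

# multi.llm_model_func can then be passed to LightRAG as llm_model_func,
# rotating between the configured models on each call.
```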
lightrag/namespace.py
CHANGED

@@ -1,3 +1,5 @@
| 1 |   from typing import Iterable
| 2 |
| 3 |

| 1 | + from __future__ import annotations
| 2 | +
| 3 |   from typing import Iterable
| 4 |
| 5 |
lightrag/operate.py
CHANGED

@@ -1,8 +1,10 @@
| 1 |   import asyncio
| 2 |   import json
| 3 |   import re
| 4 |   from tqdm.asyncio import tqdm as tqdm_async
| 5 | - from typing import Any,
| 6 |   from collections import Counter, defaultdict
| 7 |   from .utils import (
| 8 |       logger,

@@ -36,7 +38,7 @@ import time
| 36 |
| 37 |   def chunking_by_token_size(
| 38 |       content: str,
| 39 | -     split_by_character:
| 40 |       split_by_character_only: bool = False,
| 41 |       overlap_token_size: int = 128,
| 42 |       max_token_size: int = 1024,

@@ -237,25 +239,65 @@ async def _merge_edges_then_upsert(
| 237 |
| 238 |   if await knowledge_graph_inst.has_edge(src_id, tgt_id):
| 239 |       already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
| 240 | -
| 241 | -
| 242 | -
| 243 | -
| 244 | -
| 245 | -
| 246 | -
| 247 | -
| 248 |
| 249 |   weight = sum([dp["weight"] for dp in edges_data] + already_weights)
| 250 |   description = GRAPH_FIELD_SEP.join(
| 251 | -     sorted(
| 252 |   )
| 253 |   keywords = GRAPH_FIELD_SEP.join(
| 254 | -     sorted(
| 255 |   )
| 256 |   source_id = GRAPH_FIELD_SEP.join(
| 257 | -     set(
| 258 |   )
| 259 |   for need_insert_id in [src_id, tgt_id]:
| 260 |       if not (await knowledge_graph_inst.has_node(need_insert_id)):
| 261 |           await knowledge_graph_inst.upsert_node(

@@ -295,9 +337,9 @@ async def extract_entities(
| 295 |   knowledge_graph_inst: BaseGraphStorage,
| 296 |   entity_vdb: BaseVectorStorage,
| 297 |   relationships_vdb: BaseVectorStorage,
| 298 | - global_config: dict,
| 299 | - llm_response_cache: BaseKVStorage = None,
| 300 | - ) ->
| 301 |   use_llm_func: callable = global_config["llm_model_func"]
| 302 |   entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
| 303 |   enable_llm_cache_for_entity_extract: bool = global_config[

@@ -563,15 +605,15 @@
| 563 |
| 564 |
| 565 |   async def kg_query(
| 566 | -     query,
| 567 |       knowledge_graph_inst: BaseGraphStorage,
| 568 |       entities_vdb: BaseVectorStorage,
| 569 |       relationships_vdb: BaseVectorStorage,
| 570 |       text_chunks_db: BaseKVStorage,
| 571 |       query_param: QueryParam,
| 572 | -     global_config: dict,
| 573 | -     hashing_kv: BaseKVStorage = None,
| 574 | -     prompt: str =
| 575 |   ) -> str:
| 576 |       # Handle cache
| 577 |       use_model_func = global_config["llm_model_func"]

@@ -684,8 +726,8 @@
| 684 |   async def extract_keywords_only(
| 685 |       text: str,
| 686 |       param: QueryParam,
| 687 | -     global_config: dict,
| 688 | -     hashing_kv: BaseKVStorage = None,
| 689 |   ) -> tuple[list[str], list[str]]:
| 690 |       """
| 691 |       Extract high-level and low-level keywords from the given 'text' using the LLM.

@@ -784,9 +826,9 @@ async def mix_kg_vector_query(
| 784 |       chunks_vdb: BaseVectorStorage,
| 785 |       text_chunks_db: BaseKVStorage,
| 786 |       query_param: QueryParam,
| 787 | -     global_config: dict,
| 788 | -     hashing_kv: BaseKVStorage = None,
| 789 | - ) -> str:
| 790 |       """
| 791 |       Hybrid retrieval implementation combining knowledge graph and vector search.
| 792 |

@@ -1551,13 +1593,13 @@ def combine_contexts(entities, relationships, sources):
| 1551 |
| 1552 |
| 1553 |   async def naive_query(
| 1554 | -     query,
| 1555 |       chunks_vdb: BaseVectorStorage,
| 1556 |       text_chunks_db: BaseKVStorage,
| 1557 |       query_param: QueryParam,
| 1558 | -     global_config: dict,
| 1559 | -     hashing_kv: BaseKVStorage = None,
| 1560 | - ):
| 1561 |       # Handle cache
| 1562 |       use_model_func = global_config["llm_model_func"]
| 1563 |       args_hash = compute_args_hash(query_param.mode, query, cache_type="query")

@@ -1664,9 +1706,9 @@ async def kg_query_with_keywords(
| 1664 |       relationships_vdb: BaseVectorStorage,
| 1665 |       text_chunks_db: BaseKVStorage,
| 1666 |       query_param: QueryParam,
| 1667 | -     global_config: dict,
| 1668 | -     hashing_kv: BaseKVStorage = None,
| 1669 | - ) -> str:
| 1670 |       """
| 1671 |       Refactored kg_query that does NOT extract keywords by itself.
| 1672 |       It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
| 1  | + from __future__ import annotations
| 2  | +
| 3  |   import asyncio
| 4  |   import json
| 5  |   import re
| 6  |   from tqdm.asyncio import tqdm as tqdm_async
| 7  | + from typing import Any, AsyncIterator
| 8  |   from collections import Counter, defaultdict
| 9  |   from .utils import (
| 10 |       logger,

| 38 |
| 39 |   def chunking_by_token_size(
| 40 |       content: str,
| 41 | +     split_by_character: str | None = None,
| 42 |       split_by_character_only: bool = False,
| 43 |       overlap_token_size: int = 128,
| 44 |       max_token_size: int = 1024,

| 239 |
| 240 |   if await knowledge_graph_inst.has_edge(src_id, tgt_id):
| 241 |       already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
| 242 | +     # Handle the case where get_edge returns None or missing fields
| 243 | +     if already_edge:
| 244 | +         # Get weight with default 0.0 if missing
| 245 | +         if "weight" in already_edge:
| 246 | +             already_weights.append(already_edge["weight"])
| 247 | +         else:
| 248 | +             logger.warning(
| 249 | +                 f"Edge between {src_id} and {tgt_id} missing weight field"
| 250 | +             )
| 251 | +             already_weights.append(0.0)
| 252 | +
| 253 | +         # Get source_id with empty string default if missing or None
| 254 | +         if "source_id" in already_edge and already_edge["source_id"] is not None:
| 255 | +             already_source_ids.extend(
| 256 | +                 split_string_by_multi_markers(
| 257 | +                     already_edge["source_id"], [GRAPH_FIELD_SEP]
| 258 | +                 )
| 259 | +             )
| 260 |
| 261 | +         # Get description with empty string default if missing or None
| 262 | +         if (
| 263 | +             "description" in already_edge
| 264 | +             and already_edge["description"] is not None
| 265 | +         ):
| 266 | +             already_description.append(already_edge["description"])
| 267 | +
| 268 | +         # Get keywords with empty string default if missing or None
| 269 | +         if "keywords" in already_edge and already_edge["keywords"] is not None:
| 270 | +             already_keywords.extend(
| 271 | +                 split_string_by_multi_markers(
| 272 | +                     already_edge["keywords"], [GRAPH_FIELD_SEP]
| 273 | +                 )
| 274 | +             )
| 275 | +
| 276 | +     # Process edges_data with None checks
| 277 |   weight = sum([dp["weight"] for dp in edges_data] + already_weights)
| 278 |   description = GRAPH_FIELD_SEP.join(
| 279 | +     sorted(
| 280 | +         set(
| 281 | +             [dp["description"] for dp in edges_data if dp.get("description")]
| 282 | +             + already_description
| 283 | +         )
| 284 | +     )
| 285 |   )
| 286 |   keywords = GRAPH_FIELD_SEP.join(
| 287 | +     sorted(
| 288 | +         set(
| 289 | +             [dp["keywords"] for dp in edges_data if dp.get("keywords")]
| 290 | +             + already_keywords
| 291 | +         )
| 292 | +     )
| 293 |   )
| 294 |   source_id = GRAPH_FIELD_SEP.join(
| 295 | +     set(
| 296 | +         [dp["source_id"] for dp in edges_data if dp.get("source_id")]
| 297 | +         + already_source_ids
| 298 | +     )
| 299 |   )
| 300 | +
| 301 |   for need_insert_id in [src_id, tgt_id]:
| 302 |       if not (await knowledge_graph_inst.has_node(need_insert_id)):
| 303 |           await knowledge_graph_inst.upsert_node(

| 337 |   knowledge_graph_inst: BaseGraphStorage,
| 338 |   entity_vdb: BaseVectorStorage,
| 339 |   relationships_vdb: BaseVectorStorage,
| 340 | + global_config: dict[str, str],
| 341 | + llm_response_cache: BaseKVStorage | None = None,
| 342 | + ) -> BaseGraphStorage | None:
| 343 |   use_llm_func: callable = global_config["llm_model_func"]
| 344 |   entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
| 345 |   enable_llm_cache_for_entity_extract: bool = global_config[

| 605 |
| 606 |
| 607 |   async def kg_query(
| 608 | +     query: str,
| 609 |       knowledge_graph_inst: BaseGraphStorage,
| 610 |       entities_vdb: BaseVectorStorage,
| 611 |       relationships_vdb: BaseVectorStorage,
| 612 |       text_chunks_db: BaseKVStorage,
| 613 |       query_param: QueryParam,
| 614 | +     global_config: dict[str, str],
| 615 | +     hashing_kv: BaseKVStorage | None = None,
| 616 | +     prompt: str | None = None,
| 617 |   ) -> str:
| 618 |       # Handle cache
| 619 |       use_model_func = global_config["llm_model_func"]

| 726 |   async def extract_keywords_only(
| 727 |       text: str,
| 728 |       param: QueryParam,
| 729 | +     global_config: dict[str, str],
| 730 | +     hashing_kv: BaseKVStorage | None = None,
| 731 |   ) -> tuple[list[str], list[str]]:
| 732 |       """
| 733 |       Extract high-level and low-level keywords from the given 'text' using the LLM.

| 826 |       chunks_vdb: BaseVectorStorage,
| 827 |       text_chunks_db: BaseKVStorage,
| 828 |       query_param: QueryParam,
| 829 | +     global_config: dict[str, str],
| 830 | +     hashing_kv: BaseKVStorage | None = None,
| 831 | + ) -> str | AsyncIterator[str]:
| 832 |   """
| 833 |   Hybrid retrieval implementation combining knowledge graph and vector search.
| 834 |

| 1593 |
| 1594 |
| 1595 |   async def naive_query(
| 1596 | +     query: str,
| 1597 |       chunks_vdb: BaseVectorStorage,
| 1598 |       text_chunks_db: BaseKVStorage,
| 1599 |       query_param: QueryParam,
| 1600 | +     global_config: dict[str, str],
| 1601 | +     hashing_kv: BaseKVStorage | None = None,
| 1602 | + ) -> str | AsyncIterator[str]:
| 1603 |       # Handle cache
| 1604 |       use_model_func = global_config["llm_model_func"]
| 1605 |       args_hash = compute_args_hash(query_param.mode, query, cache_type="query")

| 1706 |   relationships_vdb: BaseVectorStorage,
| 1707 |   text_chunks_db: BaseKVStorage,
| 1708 |   query_param: QueryParam,
| 1709 | + global_config: dict[str, str],
| 1710 | + hashing_kv: BaseKVStorage | None = None,
| 1711 | + ) -> str | AsyncIterator[str]:
| 1712 |   """
| 1713 |   Refactored kg_query that does NOT extract keywords by itself.
| 1714 |   It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
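As a quick illustration of the revised `chunking_by_token_size` signature earlier in this file (notably the now-optional `split_by_character`), a hedged sketch; the import path, the remaining defaulted parameters, and the exact shape of the returned chunk dicts are assumptions, not confirmed by this hunk:

```python
# Illustrative call only; further parameters (e.g. the tiktoken model name) keep their defaults.
from lightrag.operate import chunking_by_token_size

text = open("book.txt", encoding="utf-8").read()
chunks = chunking_by_token_size(
    text,
    split_by_character="\n\n",      # split on blank lines first
    split_by_character_only=False,  # still re-chunk oversized pieces by token count
    overlap_token_size=64,
    max_token_size=512,
)
print(len(chunks), "chunks produced")
```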
lightrag/prompt.py
CHANGED

@@ -1,3 +1,5 @@
| 1 |   GRAPH_FIELD_SEP = "<SEP>"
| 2 |
| 3 |   PROMPTS = {}

| 1 | + from __future__ import annotations
| 2 | +
| 3 |   GRAPH_FIELD_SEP = "<SEP>"
| 4 |
| 5 |   PROMPTS = {}
lightrag/types.py
CHANGED

@@ -1,26 +1,28 @@
| 1  |   from pydantic import BaseModel
| 2  | - from typing import
| 3  |
| 4  |
| 5  |   class GPTKeywordExtractionFormat(BaseModel):
| 6  | -     high_level_keywords:
| 7  | -     low_level_keywords:
| 8  |
| 9  |
| 10 |   class KnowledgeGraphNode(BaseModel):
| 11 |       id: str
| 12 | -     labels:
| 13 | -     properties:
| 14 |
| 15 |
| 16 |   class KnowledgeGraphEdge(BaseModel):
| 17 |       id: str
| 18 | -     type: str
| 19 |       source: str  # id of source node
| 20 |       target: str  # id of target node
| 21 | -     properties:
| 22 |
| 23 |
| 24 |   class KnowledgeGraph(BaseModel):
| 25 | -     nodes:
| 26 | -     edges:

| 1  | + from __future__ import annotations
| 2  | +
| 3  |   from pydantic import BaseModel
| 4  | + from typing import Any, Optional
| 5  |
| 6  |
| 7  |   class GPTKeywordExtractionFormat(BaseModel):
| 8  | +     high_level_keywords: list[str]
| 9  | +     low_level_keywords: list[str]
| 10 |
| 11 |
| 12 |   class KnowledgeGraphNode(BaseModel):
| 13 |       id: str
| 14 | +     labels: list[str]
| 15 | +     properties: dict[str, Any]  # anything else goes here
| 16 |
| 17 |
| 18 |   class KnowledgeGraphEdge(BaseModel):
| 19 |       id: str
| 20 | +     type: Optional[str]
| 21 |       source: str  # id of source node
| 22 |       target: str  # id of target node
| 23 | +     properties: dict[str, Any]  # anything else goes here
| 24 |
| 25 |
| 26 |   class KnowledgeGraph(BaseModel):
| 27 | +     nodes: list[KnowledgeGraphNode] = []
| 28 | +     edges: list[KnowledgeGraphEdge] = []
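A small instance of the graph models under the new annotations; a minimal sketch assuming the classes are importable from `lightrag.types` exactly as defined above:

```python
from lightrag.types import KnowledgeGraph, KnowledgeGraphEdge, KnowledgeGraphNode

# Build a two-node graph; `type` may now be None, and properties hold arbitrary extras.
alice = KnowledgeGraphNode(id="n1", labels=["PERSON"], properties={"description": "Engineer"})
acme = KnowledgeGraphNode(id="n2", labels=["ORGANIZATION"], properties={})
works_at = KnowledgeGraphEdge(
    id="e1", type="WORKS_AT", source="n1", target="n2", properties={"weight": 1.0}
)

kg = KnowledgeGraph(nodes=[alice, acme], edges=[works_at])
print(kg.model_dump())  # pydantic v2; use .dict() on pydantic v1
```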
lightrag/utils.py
CHANGED

@@ -1,3 +1,5 @@
| 1 |   import asyncio
| 2 |   import html
| 3 |   import io

@@ -9,7 +11,7 @@ import re
| 9  |   from dataclasses import dataclass
| 10 |   from functools import wraps
| 11 |   from hashlib import md5
| 12 | - from typing import Any,
| 13 |   import xml.etree.ElementTree as ET
| 14 |   import bs4
| 15 |

@@ -67,12 +69,12 @@ class EmbeddingFunc:
| 67 |
| 68 |   @dataclass
| 69 |   class ReasoningResponse:
| 70 | -     reasoning_content: str
| 71 |       response_content: str
| 72 |       tag: str
| 73 |
| 74 |
| 75 | - def locate_json_string_body_from_string(content: str) ->
| 76 |   """Locate the JSON string body from a string"""
| 77 |   try:
| 78 |       maybe_json_str = re.search(r"{.*}", content, re.DOTALL)

@@ -109,7 +111,7 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
| 109 |   raise e from None
| 110 |
| 111 |
| 112 | - def compute_args_hash(*args, cache_type: str = None) -> str:
| 113 |   """Compute a hash for the given arguments.
| 114 |   Args:
| 115 |       *args: Arguments to hash

@@ -128,7 +130,12 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
| 128 |   return hashlib.md5(args_str.encode()).hexdigest()
| 129 |
| 130 |
| 131 | - def compute_mdhash_id(content, prefix: str = ""):
| 132 |   return prefix + md5(content.encode()).hexdigest()
| 133 |
| 134 |

@@ -215,11 +222,13 @@ def clean_str(input: Any) -> str:
| 215 |   return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
| 216 |
| 217 |
| 218 | - def is_float_regex(value):
| 219 |   return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
| 220 |
| 221 |
| 222 | - def truncate_list_by_token_size(
| 223 |   """Truncate a list of data by token size"""
| 224 |   if max_token_size <= 0:
| 225 |       return []

@@ -231,7 +240,7 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
| 231 |   return list_data
| 232 |
| 233 |
| 234 | - def list_of_list_to_csv(data:
| 235 |   output = io.StringIO()
| 236 |   writer = csv.writer(
| 237 |       output,

@@ -244,7 +253,7 @@ def list_of_list_to_csv(data: List[List[str]]) -> str:
| 244 |   return output.getvalue()
| 245 |
| 246 |
| 247 | - def csv_string_to_list(csv_string: str) ->
| 248 |   # Clean the string by removing NUL characters
| 249 |   cleaned_string = csv_string.replace("\0", "")
| 250 |

@@ -329,7 +338,7 @@ def xml_to_json(xml_file):
| 329 |   return None
| 330 |
| 331 |
| 332 | - def process_combine_contexts(hl, ll):
| 333 |   header = None
| 334 |   list_hl = csv_string_to_list(hl.strip())
| 335 |   list_ll = csv_string_to_list(ll.strip())

@@ -375,7 +384,7 @@ async def get_best_cached_response(
| 375 |   llm_func=None,
| 376 |   original_prompt=None,
| 377 |   cache_type=None,
| 378 | - ) ->
| 379 |   logger.debug(
| 380 |       f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
| 381 |   )

@@ -479,7 +488,7 @@ def cosine_similarity(v1, v2):
| 479 |   return dot_product / (norm1 * norm2)
| 480 |
| 481 |
| 482 | - def quantize_embedding(embedding:
| 483 |   """Quantize embedding to specified bits"""
| 484 |   # Convert list to numpy array if needed
| 485 |   if isinstance(embedding, list):

@@ -570,9 +579,9 @@ class CacheData:
| 570 |   args_hash: str
| 571 |   content: str
| 572 |   prompt: str
| 573 | - quantized:
| 574 | - min_val:
| 575 | - max_val:
| 576 |   mode: str = "default"
| 577 |   cache_type: str = "query"
| 578 |

@@ -635,7 +644,9 @@ def exists_func(obj, func_name: str) -> bool:
| 635 |   return False
| 636 |
| 637 |
| 638 | - def get_conversation_turns(
| 639 |   """
| 640 |   Process conversation history to get the specified number of complete turns.
| 641 |

@@ -647,8 +658,8 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
| 647 |       Formatted string of the conversation history
| 648 |   """
| 649 |   # Group messages into turns
| 650 | - turns = []
| 651 | - messages = []
| 652 |
| 653 |   # First, filter out keyword extraction messages
| 654 |   for msg in conversation_history:

@@ -682,7 +693,7 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
| 682 |   turns = turns[-num_turns:]
| 683 |
| 684 |   # Format the turns into a string
| 685 | - formatted_turns = []
| 686 |   for turn in turns:
| 687 |       formatted_turns.extend(
| 688 |           [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
| 1  | + from __future__ import annotations
| 2  | +
| 3  |   import asyncio
| 4  |   import html
| 5  |   import io

| 11 |   from dataclasses import dataclass
| 12 |   from functools import wraps
| 13 |   from hashlib import md5
| 14 | + from typing import Any, Callable
| 15 |   import xml.etree.ElementTree as ET
| 16 |   import bs4
| 17 |

| 69 |
| 70 |   @dataclass
| 71 |   class ReasoningResponse:
| 72 | +     reasoning_content: str | None
| 73 |       response_content: str
| 74 |       tag: str
| 75 |
| 76 |
| 77 | + def locate_json_string_body_from_string(content: str) -> str | None:
| 78 |       """Locate the JSON string body from a string"""
| 79 |       try:
| 80 |           maybe_json_str = re.search(r"{.*}", content, re.DOTALL)

| 111 |   raise e from None
| 112 |
| 113 |
| 114 | + def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
| 115 |       """Compute a hash for the given arguments.
| 116 |       Args:
| 117 |           *args: Arguments to hash

| 130 |   return hashlib.md5(args_str.encode()).hexdigest()
| 131 |
| 132 |
| 133 | + def compute_mdhash_id(content: str, prefix: str = "") -> str:
| 134 | +     """
| 135 | +     Compute a unique ID for a given content string.
| 136 | +
| 137 | +     The ID is a combination of the given prefix and the MD5 hash of the content string.
| 138 | +     """
| 139 |       return prefix + md5(content.encode()).hexdigest()
| 140 |
| 141 |

| 222 |   return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
| 223 |
| 224 |
| 225 | + def is_float_regex(value: str) -> bool:
| 226 |       return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
| 227 |
| 228 |
| 229 | + def truncate_list_by_token_size(
| 230 | +     list_data: list[Any], key: Callable[[Any], str], max_token_size: int
| 231 | + ) -> list[int]:
| 232 |       """Truncate a list of data by token size"""
| 233 |       if max_token_size <= 0:
| 234 |           return []

| 240 |   return list_data
| 241 |
| 242 |
| 243 | + def list_of_list_to_csv(data: list[list[str]]) -> str:
| 244 |       output = io.StringIO()
| 245 |       writer = csv.writer(
| 246 |           output,

| 253 |   return output.getvalue()
| 254 |
| 255 |
| 256 | + def csv_string_to_list(csv_string: str) -> list[list[str]]:
| 257 |       # Clean the string by removing NUL characters
| 258 |       cleaned_string = csv_string.replace("\0", "")
| 259 |

| 338 |   return None
| 339 |
| 340 |
| 341 | + def process_combine_contexts(hl: str, ll: str):
| 342 |       header = None
| 343 |       list_hl = csv_string_to_list(hl.strip())
| 344 |       list_ll = csv_string_to_list(ll.strip())

| 384 |   llm_func=None,
| 385 |   original_prompt=None,
| 386 |   cache_type=None,
| 387 | + ) -> str | None:
| 388 |   logger.debug(
| 389 |       f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
| 390 |   )

| 488 |   return dot_product / (norm1 * norm2)
| 489 |
| 490 |
| 491 | + def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
| 492 |       """Quantize embedding to specified bits"""
| 493 |       # Convert list to numpy array if needed
| 494 |       if isinstance(embedding, list):

| 579 |   args_hash: str
| 580 |   content: str
| 581 |   prompt: str
| 582 | + quantized: np.ndarray | None = None
| 583 | + min_val: float | None = None
| 584 | + max_val: float | None = None
| 585 |   mode: str = "default"
| 586 |   cache_type: str = "query"
| 587 |

| 644 |   return False
| 645 |
| 646 |
| 647 | + def get_conversation_turns(
| 648 | +     conversation_history: list[dict[str, Any]], num_turns: int
| 649 | + ) -> str:
| 650 |       """
| 651 |       Process conversation history to get the specified number of complete turns.
| 652 |

| 658 |       Formatted string of the conversation history
| 659 |   """
| 660 |   # Group messages into turns
| 661 | + turns: list[list[dict[str, Any]]] = []
| 662 | + messages: list[dict[str, Any]] = []
| 663 |
| 664 |   # First, filter out keyword extraction messages
| 665 |   for msg in conversation_history:

| 693 |   turns = turns[-num_turns:]
| 694 |
| 695 |   # Format the turns into a string
| 696 | + formatted_turns: list[str] = []
| 697 |   for turn in turns:
| 698 |       formatted_turns.extend(
| 699 |           [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
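To make the newly documented hashing helper concrete, a short standalone sketch that mirrors exactly what `compute_mdhash_id` does (prefix plus MD5 hex digest of the content), taken directly from the function body above:

```python
from hashlib import md5

def mdhash_id(content: str, prefix: str = "") -> str:
    # Same construction as compute_mdhash_id: prefix + MD5 hex digest of the content.
    return prefix + md5(content.encode()).hexdigest()

chunk_id = mdhash_id("Acme Corp acquired Widget Labs in 2021.", prefix="chunk-")
print(chunk_id)  # deterministic for the same content, e.g. "chunk-<32 hex chars>"
```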
lightrag_webui/src/components/PropertiesView.tsx
CHANGED

@@ -200,7 +200,7 @@ const EdgePropertiesView = ({ edge }: { edge: EdgeType }) => {
| 200 |   <label className="text-md pl-1 font-bold tracking-wide text-teal-600">Relationship</label>
| 201 |   <div className="bg-primary/5 max-h-96 overflow-auto rounded p-1">
| 202 |     <PropertyRow name={'Id'} value={edge.id} />
| 203 | -   <PropertyRow name={'Type'} value={edge.type} />
| 204 |     <PropertyRow
| 205 |       name={'Source'}
| 206 |       value={edge.sourceNode ? edge.sourceNode.labels.join(', ') : edge.source}

| 200 |   <label className="text-md pl-1 font-bold tracking-wide text-teal-600">Relationship</label>
| 201 |   <div className="bg-primary/5 max-h-96 overflow-auto rounded p-1">
| 202 |     <PropertyRow name={'Id'} value={edge.id} />
| 203 | +   {edge.type && <PropertyRow name={'Type'} value={edge.type} />}
| 204 |     <PropertyRow
| 205 |       name={'Source'}
| 206 |       value={edge.sourceNode ? edge.sourceNode.labels.join(', ') : edge.source}
lightrag_webui/src/hooks/useLightragGraph.tsx
CHANGED

@@ -24,7 +24,7 @@ const validateGraph = (graph: RawGraph) => {
| 24 |   }
| 25 |
| 26 |   for (const edge of graph.edges) {
| 27 | -   if (!edge.id || !edge.source || !edge.target
| 28 |       return false
| 29 |     }
| 30 |   }

@@ -88,6 +88,14 @@ const fetchGraph = async (label: string) => {
| 88 |   if (source !== undefined && source !== undefined) {
| 89 |     const sourceNode = rawData.nodes[source]
| 90 |     const targetNode = rawData.nodes[target]
| 91 |     sourceNode.degree += 1
| 92 |     targetNode.degree += 1
| 93 |   }

@@ -146,7 +154,7 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => {
| 146 |
| 147 |   for (const rawEdge of rawGraph?.edges ?? []) {
| 148 |     rawEdge.dynamicId = graph.addDirectedEdge(rawEdge.source, rawEdge.target, {
| 149 | -     label: rawEdge.type
| 150 |     })
| 151 |   }
| 152 |

| 24 |   }
| 25 |
| 26 |   for (const edge of graph.edges) {
| 27 | +   if (!edge.id || !edge.source || !edge.target) {
| 28 |       return false
| 29 |     }
| 30 |   }

| 88 |   if (source !== undefined && source !== undefined) {
| 89 |     const sourceNode = rawData.nodes[source]
| 90 |     const targetNode = rawData.nodes[target]
| 91 | +   if (!sourceNode) {
| 92 | +     console.error(`Source node ${edge.source} is undefined`)
| 93 | +     continue
| 94 | +   }
| 95 | +   if (!targetNode) {
| 96 | +     console.error(`Target node ${edge.target} is undefined`)
| 97 | +     continue
| 98 | +   }
| 99 |     sourceNode.degree += 1
| 100 |    targetNode.degree += 1
| 101 |  }

| 154 |
| 155 |   for (const rawEdge of rawGraph?.edges ?? []) {
| 156 |     rawEdge.dynamicId = graph.addDirectedEdge(rawEdge.source, rawEdge.target, {
| 157 | +     label: rawEdge.type || undefined
| 158 |     })
| 159 |   }
| 160 |
lightrag_webui/src/stores/graph.ts
CHANGED

@@ -19,7 +19,7 @@ export type RawEdgeType = {
| 19 |   id: string
| 20 |   source: string
| 21 |   target: string
| 22 | - type
| 23 |   properties: Record<string, any>
| 24 |
| 25 |   dynamicId: string

| 19 |   id: string
| 20 |   source: string
| 21 |   target: string
| 22 | + type?: string
| 23 |   properties: Record<string, any>
| 24 |
| 25 |   dynamicId: string
|