yangdx committed on
Commit 55aa606 · 2 Parent(s): 0564adf 0a5a0dc

Merge branch 'main' into add-env-settings

.dockerignore CHANGED
@@ -1 +1,63 @@
1
- .env
1
+ # Python-related files and directories
2
+ __pycache__
3
+ .cache
4
+
5
+ # Virtual environment directories
6
+ *.venv
7
+
8
+ # Env
9
+ env/
10
+ *.env*
11
+ .env_example
12
+
13
+ # Distribution / build files
14
+ site
15
+ dist/
16
+ build/
17
+ .eggs/
18
+ *.egg-info/
19
+ *.tgz
20
+ *.tar.gz
21
+
22
+ # Exclude files and folders
23
+ *.yml
24
+ .dockerignore
25
+ Dockerfile
26
+ Makefile
27
+
28
+ # Exclude other projects
29
+ /tests
30
+ /scripts
31
+
32
+ # Python version manager file
33
+ .python-version
34
+
35
+ # Reports
36
+ *.coverage/
37
+ *.log
38
+ log/
39
+ *.logfire
40
+
41
+ # Cache
42
+ .cache/
43
+ .mypy_cache
44
+ .pytest_cache
45
+ .ruff_cache
46
+ .gradio
47
+ .logfire
48
+ temp/
49
+
50
+ # MacOS-related files
51
+ .DS_Store
52
+
53
+ # VS Code settings (local configuration files)
54
+ .vscode
55
+
56
+ # file
57
+ TODO.md
58
+
59
+ # Exclude Git-related files
60
+ .git
61
+ .github
62
+ .gitignore
63
+ .pre-commit-config.yaml
.gitignore CHANGED
@@ -35,23 +35,27 @@ temp/
35
 
36
  # IDE / Editor Files
37
  .idea/
38
- dist/
39
- env/
 
 
40
  local_neo4jWorkDir/
41
  neo4jWorkDir/
42
- ignore_this.txt
43
- .venv/
44
- *.ignore.*
45
- .ruff_cache/
46
- gui/
47
- *.log
48
- .vscode
49
- inputs
50
- rag_storage
51
- .env
52
- venv/
53
  examples/input/
54
  examples/output/
 
 
55
  .DS_Store
56
- #Remove config.ini from repo
57
- *.ini
 
 
 
 
 
 
 
 
35
 
36
  # IDE / Editor Files
37
  .idea/
38
+ .vscode/
39
+ .vscode/settings.json
40
+
41
+ # Framework-specific files
42
  local_neo4jWorkDir/
43
  neo4jWorkDir/
44
+
45
+ # Data & Storage
46
+ inputs/
47
+ rag_storage/
 
 
 
 
 
 
 
48
  examples/input/
49
  examples/output/
50
+
51
+ # Miscellaneous
52
  .DS_Store
53
+ TODO.md
54
+ ignore_this.txt
55
+ *.ignore.*
56
+
57
+ # Project-specific files
58
+ dickens/
59
+ book.txt
60
+ lightrag-dev/
61
+ gui/
README.md CHANGED
@@ -237,7 +237,7 @@ rag = LightRAG(
237
 
238
  * If you want to use Hugging Face models, you only need to set LightRAG as follows:
239
  ```python
240
- from lightrag.llm import hf_model_complete, hf_embedding
241
  from transformers import AutoModel, AutoTokenizer
242
  from lightrag.utils import EmbeddingFunc
243
 
@@ -250,7 +250,7 @@ rag = LightRAG(
250
  embedding_func=EmbeddingFunc(
251
  embedding_dim=384,
252
  max_token_size=5000,
253
- func=lambda texts: hf_embedding(
254
  texts,
255
  tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
256
  embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
@@ -428,9 +428,9 @@ And using a routine to process news documents.
428
 
429
  ```python
430
  rag = LightRAG(..)
431
- await rag.apipeline_enqueue_documents(string_or_strings)
432
  # Your routine in loop
433
- await rag.apipeline_process_enqueue_documents(string_or_strings)
434
  ```
435
 
436
  ### Separate Keyword Extraction
 
237
 
238
  * If you want to use Hugging Face models, you only need to set LightRAG as follows:
239
  ```python
240
+ from lightrag.llm import hf_model_complete, hf_embed
241
  from transformers import AutoModel, AutoTokenizer
242
  from lightrag.utils import EmbeddingFunc
243
 
 
250
  embedding_func=EmbeddingFunc(
251
  embedding_dim=384,
252
  max_token_size=5000,
253
+ func=lambda texts: hf_embed(
254
  texts,
255
  tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
256
  embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
 
428
 
429
  ```python
430
  rag = LightRAG(..)
431
+ await rag.apipeline_enqueue_documents(input)
432
  # Your routine in loop
433
+ await rag.apipeline_process_enqueue_documents(input)
434
  ```
435
 
436
  ### Separate Keyword Extraction
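A note on the pipeline API shown above: the commit only renames the example argument (`string_or_strings` → `input`); the two-step enqueue/process flow is unchanged. A rough usage sketch follows (the `ingest_news` wrapper and the sample strings are invented for illustration, not part of the repository):

```python
import asyncio
from lightrag import LightRAG

async def ingest_news(rag: LightRAG, docs: list[str]) -> None:
    # Step 1: enqueue the raw documents for the pipeline
    await rag.apipeline_enqueue_documents(docs)
    # Step 2: process whatever is currently enqueued (run this inside your routine's loop)
    await rag.apipeline_process_enqueue_documents(docs)

# Example call, assuming `rag` was built as shown earlier in the README:
# asyncio.run(ingest_news(rag, ["news article 1", "news article 2"]))
```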
examples/lightrag_oracle_demo.py CHANGED
@@ -113,7 +113,24 @@ async def main():
113
  )
114
 
115
  # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
116
- rag.set_storage_client(db_client=oracle_db)
117
 
118
  # Extract and Insert into LightRAG storage
119
  with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
 
113
  )
114
 
115
  # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
116
+
117
+ for storage in [
118
+ rag.vector_db_storage_cls,
119
+ rag.graph_storage_cls,
120
+ rag.doc_status,
121
+ rag.full_docs,
122
+ rag.text_chunks,
123
+ rag.llm_response_cache,
124
+ rag.key_string_value_json_storage_cls,
125
+ rag.chunks_vdb,
126
+ rag.relationships_vdb,
127
+ rag.entities_vdb,
128
+ rag.graph_storage_cls,
129
+ rag.chunk_entity_relation_graph,
130
+ rag.llm_response_cache,
131
+ ]:
132
+ # set client
133
+ storage.db = oracle_db
134
 
135
  # Extract and Insert into LightRAG storage
136
  with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
examples/test_chromadb.py CHANGED
@@ -15,6 +15,12 @@ if not os.path.exists(WORKING_DIR):
15
  os.mkdir(WORKING_DIR)
16
 
17
  # ChromaDB Configuration
 
 
 
 
 
 
18
  CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
19
  CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
20
  CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
@@ -60,30 +66,50 @@ async def create_embedding_function_instance():
60
 
61
  async def initialize_rag():
62
  embedding_func_instance = await create_embedding_function_instance()
63
-
64
- return LightRAG(
65
- working_dir=WORKING_DIR,
66
- llm_model_func=gpt_4o_mini_complete,
67
- embedding_func=embedding_func_instance,
68
- vector_storage="ChromaVectorDBStorage",
69
- log_level="DEBUG",
70
- embedding_batch_num=32,
71
- vector_db_storage_cls_kwargs={
72
- "host": CHROMADB_HOST,
73
- "port": CHROMADB_PORT,
74
- "auth_token": CHROMADB_AUTH_TOKEN,
75
- "auth_provider": CHROMADB_AUTH_PROVIDER,
76
- "auth_header_name": CHROMADB_AUTH_HEADER,
77
- "collection_settings": {
78
- "hnsw:space": "cosine",
79
- "hnsw:construction_ef": 128,
80
- "hnsw:search_ef": 128,
81
- "hnsw:M": 16,
82
- "hnsw:batch_size": 100,
83
- "hnsw:sync_threshold": 1000,
84
  },
85
- },
86
- )
87
 
88
 
89
  # Run the initialization
 
15
  os.mkdir(WORKING_DIR)
16
 
17
  # ChromaDB Configuration
18
+ CHROMADB_USE_LOCAL_PERSISTENT = False
19
+ # Local PersistentClient Configuration
20
+ CHROMADB_LOCAL_PATH = os.environ.get(
21
+ "CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data")
22
+ )
23
+ # Remote HttpClient Configuration
24
  CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
25
  CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
26
  CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
 
66
 
67
  async def initialize_rag():
68
  embedding_func_instance = await create_embedding_function_instance()
69
+ if CHROMADB_USE_LOCAL_PERSISTENT:
70
+ return LightRAG(
71
+ working_dir=WORKING_DIR,
72
+ llm_model_func=gpt_4o_mini_complete,
73
+ embedding_func=embedding_func_instance,
74
+ vector_storage="ChromaVectorDBStorage",
75
+ log_level="DEBUG",
76
+ embedding_batch_num=32,
77
+ vector_db_storage_cls_kwargs={
78
+ "local_path": CHROMADB_LOCAL_PATH,
79
+ "collection_settings": {
80
+ "hnsw:space": "cosine",
81
+ "hnsw:construction_ef": 128,
82
+ "hnsw:search_ef": 128,
83
+ "hnsw:M": 16,
84
+ "hnsw:batch_size": 100,
85
+ "hnsw:sync_threshold": 1000,
86
+ },
 
 
 
87
  },
88
+ )
89
+ else:
90
+ return LightRAG(
91
+ working_dir=WORKING_DIR,
92
+ llm_model_func=gpt_4o_mini_complete,
93
+ embedding_func=embedding_func_instance,
94
+ vector_storage="ChromaVectorDBStorage",
95
+ log_level="DEBUG",
96
+ embedding_batch_num=32,
97
+ vector_db_storage_cls_kwargs={
98
+ "host": CHROMADB_HOST,
99
+ "port": CHROMADB_PORT,
100
+ "auth_token": CHROMADB_AUTH_TOKEN,
101
+ "auth_provider": CHROMADB_AUTH_PROVIDER,
102
+ "auth_header_name": CHROMADB_AUTH_HEADER,
103
+ "collection_settings": {
104
+ "hnsw:space": "cosine",
105
+ "hnsw:construction_ef": 128,
106
+ "hnsw:search_ef": 128,
107
+ "hnsw:M": 16,
108
+ "hnsw:batch_size": 100,
109
+ "hnsw:sync_threshold": 1000,
110
+ },
111
+ },
112
+ )
113
 
114
 
115
  # Run the initialization
external_bindings/OpenWebuiTool/openwebui_tool.py DELETED
@@ -1,358 +0,0 @@
1
- """
2
- OpenWebui Lightrag Integration Tool
3
- ==================================
4
-
5
- This tool enables the integration and use of Lightrag within the OpenWebui environment,
6
- providing a seamless interface for RAG (Retrieval-Augmented Generation) operations.
7
-
8
- Author: ParisNeo (parisneoai@gmail.com)
9
- Social:
10
- - Twitter: @ParisNeo_AI
11
- - Reddit: r/lollms
12
- - Instagram: https://www.instagram.com/parisneo_ai/
13
-
14
- License: Apache 2.0
15
- Copyright (c) 2024-2025 ParisNeo
16
-
17
- This tool is part of the LoLLMs project (Lord of Large Language and Multimodal Systems).
18
- For more information, visit: https://github.com/ParisNeo/lollms
19
-
20
- Requirements:
21
- - Python 3.8+
22
- - OpenWebui
23
- - Lightrag
24
- """
25
-
26
- # Tool version
27
- __version__ = "1.0.0"
28
- __author__ = "ParisNeo"
29
- __author_email__ = "parisneoai@gmail.com"
30
- __description__ = "Lightrag integration for OpenWebui"
31
-
32
-
33
- import requests
34
- import json
35
- from pydantic import BaseModel, Field
36
- from typing import Callable, Any, Literal, Union, List, Tuple
37
-
38
-
39
- class StatusEventEmitter:
40
- def __init__(self, event_emitter: Callable[[dict], Any] = None):
41
- self.event_emitter = event_emitter
42
-
43
- async def emit(self, description="Unknown State", status="in_progress", done=False):
44
- if self.event_emitter:
45
- await self.event_emitter(
46
- {
47
- "type": "status",
48
- "data": {
49
- "status": status,
50
- "description": description,
51
- "done": done,
52
- },
53
- }
54
- )
55
-
56
-
57
- class MessageEventEmitter:
58
- def __init__(self, event_emitter: Callable[[dict], Any] = None):
59
- self.event_emitter = event_emitter
60
-
61
- async def emit(self, content="Some message"):
62
- if self.event_emitter:
63
- await self.event_emitter(
64
- {
65
- "type": "message",
66
- "data": {
67
- "content": content,
68
- },
69
- }
70
- )
71
-
72
-
73
- class Tools:
74
- class Valves(BaseModel):
75
- LIGHTRAG_SERVER_URL: str = Field(
76
- default="http://localhost:9621/query",
77
- description="The base URL for the LightRag server",
78
- )
79
- MODE: Literal["naive", "local", "global", "hybrid"] = Field(
80
- default="hybrid",
81
- description="The mode to use for the LightRag query. Options: naive, local, global, hybrid",
82
- )
83
- ONLY_NEED_CONTEXT: bool = Field(
84
- default=False,
85
- description="If True, only the context is needed from the LightRag response",
86
- )
87
- DEBUG_MODE: bool = Field(
88
- default=False,
89
- description="If True, debugging information will be emitted",
90
- )
91
- KEY: str = Field(
92
- default="",
93
- description="Optional Bearer Key for authentication",
94
- )
95
- MAX_ENTITIES: int = Field(
96
- default=5,
97
- description="Maximum number of entities to keep",
98
- )
99
- MAX_RELATIONSHIPS: int = Field(
100
- default=5,
101
- description="Maximum number of relationships to keep",
102
- )
103
- MAX_SOURCES: int = Field(
104
- default=3,
105
- description="Maximum number of sources to keep",
106
- )
107
-
108
- def __init__(self):
109
- self.valves = self.Valves()
110
- self.headers = {
111
- "Content-Type": "application/json",
112
- "User-Agent": "LightRag-Tool/1.0",
113
- }
114
-
115
- async def query_lightrag(
116
- self,
117
- query: str,
118
- __event_emitter__: Callable[[dict], Any] = None,
119
- ) -> str:
120
- """
121
- Query the LightRag server and retrieve information.
122
- This function must be called before answering the user question
123
- :params query: The query string to send to the LightRag server.
124
- :return: The response from the LightRag server in Markdown format or raw response.
125
- """
126
- self.status_emitter = StatusEventEmitter(__event_emitter__)
127
- self.message_emitter = MessageEventEmitter(__event_emitter__)
128
-
129
- lightrag_url = self.valves.LIGHTRAG_SERVER_URL
130
- payload = {
131
- "query": query,
132
- "mode": str(self.valves.MODE),
133
- "stream": False,
134
- "only_need_context": self.valves.ONLY_NEED_CONTEXT,
135
- }
136
- await self.status_emitter.emit("Initializing Lightrag query..")
137
-
138
- if self.valves.DEBUG_MODE:
139
- await self.message_emitter.emit(
140
- "### Debug Mode Active\n\nDebugging information will be displayed.\n"
141
- )
142
- await self.message_emitter.emit(
143
- "#### Payload Sent to LightRag Server\n```json\n"
144
- + json.dumps(payload, indent=4)
145
- + "\n```\n"
146
- )
147
-
148
- # Add Bearer Key to headers if provided
149
- if self.valves.KEY:
150
- self.headers["Authorization"] = f"Bearer {self.valves.KEY}"
151
-
152
- try:
153
- await self.status_emitter.emit("Sending request to LightRag server")
154
-
155
- response = requests.post(
156
- lightrag_url, json=payload, headers=self.headers, timeout=120
157
- )
158
- response.raise_for_status()
159
- data = response.json()
160
- await self.status_emitter.emit(
161
- status="complete",
162
- description="LightRag query Succeeded",
163
- done=True,
164
- )
165
-
166
- # Return parsed Markdown if ONLY_NEED_CONTEXT is True, otherwise return raw response
167
- if self.valves.ONLY_NEED_CONTEXT:
168
- try:
169
- if self.valves.DEBUG_MODE:
170
- await self.message_emitter.emit(
171
- "#### LightRag Server Response\n```json\n"
172
- + data["response"]
173
- + "\n```\n"
174
- )
175
- except Exception as ex:
176
- if self.valves.DEBUG_MODE:
177
- await self.message_emitter.emit(
178
- "#### Exception\n" + str(ex) + "\n"
179
- )
180
- return f"Exception: {ex}"
181
- return data["response"]
182
- else:
183
- if self.valves.DEBUG_MODE:
184
- await self.message_emitter.emit(
185
- "#### LightRag Server Response\n```json\n"
186
- + data["response"]
187
- + "\n```\n"
188
- )
189
- await self.status_emitter.emit("Lightrag query success")
190
- return data["response"]
191
-
192
- except requests.exceptions.RequestException as e:
193
- await self.status_emitter.emit(
194
- status="error",
195
- description=f"Error during LightRag query: {str(e)}",
196
- done=True,
197
- )
198
- return json.dumps({"error": str(e)})
199
-
200
- def extract_code_blocks(
201
- self, text: str, return_remaining_text: bool = False
202
- ) -> Union[List[dict], Tuple[List[dict], str]]:
203
- """
204
- This function extracts code blocks from a given text and optionally returns the text without code blocks.
205
-
206
- Parameters:
207
- text (str): The text from which to extract code blocks. Code blocks are identified by triple backticks (```).
208
- return_remaining_text (bool): If True, also returns the text with code blocks removed.
209
-
210
- Returns:
211
- Union[List[dict], Tuple[List[dict], str]]:
212
- - If return_remaining_text is False: Returns only the list of code block dictionaries
213
- - If return_remaining_text is True: Returns a tuple containing:
214
- * List of code block dictionaries
215
- * String containing the text with all code blocks removed
216
-
217
- Each code block dictionary contains:
218
- - 'index' (int): The index of the code block in the text
219
- - 'file_name' (str): The name of the file extracted from the preceding line, if available
220
- - 'content' (str): The content of the code block
221
- - 'type' (str): The type of the code block
222
- - 'is_complete' (bool): True if the block has a closing tag, False otherwise
223
- """
224
- remaining = text
225
- bloc_index = 0
226
- first_index = 0
227
- indices = []
228
- text_without_blocks = text
229
-
230
- # Find all code block delimiters
231
- while len(remaining) > 0:
232
- try:
233
- index = remaining.index("```")
234
- indices.append(index + first_index)
235
- remaining = remaining[index + 3 :]
236
- first_index += index + 3
237
- bloc_index += 1
238
- except Exception:
239
- if bloc_index % 2 == 1:
240
- index = len(remaining)
241
- indices.append(index)
242
- remaining = ""
243
-
244
- code_blocks = []
245
- is_start = True
246
-
247
- # Process code blocks and build text without blocks if requested
248
- if return_remaining_text:
249
- text_parts = []
250
- last_end = 0
251
-
252
- for index, code_delimiter_position in enumerate(indices):
253
- if is_start:
254
- block_infos = {
255
- "index": len(code_blocks),
256
- "file_name": "",
257
- "section": "",
258
- "content": "",
259
- "type": "",
260
- "is_complete": False,
261
- }
262
-
263
- # Store text before code block if returning remaining text
264
- if return_remaining_text:
265
- text_parts.append(text[last_end:code_delimiter_position].strip())
266
-
267
- # Check the preceding line for file name
268
- preceding_text = text[:code_delimiter_position].strip().splitlines()
269
- if preceding_text:
270
- last_line = preceding_text[-1].strip()
271
- if last_line.startswith("<file_name>") and last_line.endswith(
272
- "</file_name>"
273
- ):
274
- file_name = last_line[
275
- len("<file_name>") : -len("</file_name>")
276
- ].strip()
277
- block_infos["file_name"] = file_name
278
- elif last_line.startswith("## filename:"):
279
- file_name = last_line[len("## filename:") :].strip()
280
- block_infos["file_name"] = file_name
281
- if last_line.startswith("<section>") and last_line.endswith(
282
- "</section>"
283
- ):
284
- section = last_line[
285
- len("<section>") : -len("</section>")
286
- ].strip()
287
- block_infos["section"] = section
288
-
289
- sub_text = text[code_delimiter_position + 3 :]
290
- if len(sub_text) > 0:
291
- try:
292
- find_space = sub_text.index(" ")
293
- except Exception:
294
- find_space = int(1e10)
295
- try:
296
- find_return = sub_text.index("\n")
297
- except Exception:
298
- find_return = int(1e10)
299
- next_index = min(find_return, find_space)
300
- if "{" in sub_text[:next_index]:
301
- next_index = 0
302
- start_pos = next_index
303
-
304
- if code_delimiter_position + 3 < len(text) and text[
305
- code_delimiter_position + 3
306
- ] in ["\n", " ", "\t"]:
307
- block_infos["type"] = "language-specific"
308
- else:
309
- block_infos["type"] = sub_text[:next_index]
310
-
311
- if index + 1 < len(indices):
312
- next_pos = indices[index + 1] - code_delimiter_position
313
- if (
314
- next_pos - 3 < len(sub_text)
315
- and sub_text[next_pos - 3] == "`"
316
- ):
317
- block_infos["content"] = sub_text[
318
- start_pos : next_pos - 3
319
- ].strip()
320
- block_infos["is_complete"] = True
321
- else:
322
- block_infos["content"] = sub_text[
323
- start_pos:next_pos
324
- ].strip()
325
- block_infos["is_complete"] = False
326
-
327
- if return_remaining_text:
328
- last_end = indices[index + 1] + 3
329
- else:
330
- block_infos["content"] = sub_text[start_pos:].strip()
331
- block_infos["is_complete"] = False
332
-
333
- if return_remaining_text:
334
- last_end = len(text)
335
-
336
- code_blocks.append(block_infos)
337
- is_start = False
338
- else:
339
- is_start = True
340
-
341
- if return_remaining_text:
342
- # Add any remaining text after the last code block
343
- if last_end < len(text):
344
- text_parts.append(text[last_end:].strip())
345
- # Join all non-code parts with newlines
346
- text_without_blocks = "\n".join(filter(None, text_parts))
347
- return code_blocks, text_without_blocks
348
-
349
- return code_blocks
350
-
351
- def clean(self, csv_content: str):
352
- lines = csv_content.splitlines()
353
- if lines:
354
- # Remove spaces around headers and ensure no spaces between commas
355
- header = ",".join([col.strip() for col in lines[0].split(",")])
356
- lines[0] = header # Replace the first line with the cleaned header
357
- csv_content = "\n".join(lines)
358
- return csv_content
lightrag/api/README.md CHANGED
@@ -185,7 +185,8 @@ TiDBVectorDBStorage TiDB
185
  PGVectorStorage Postgres
186
  FaissVectorDBStorage Faiss
187
  QdrantVectorDBStorage Qdrant
188
- OracleVectorDBStorag Oracle
 
189
  ```
190
 
191
  * DOC_STATUS_STORAGE:supported implement-name
 
185
  PGVectorStorage Postgres
186
  FaissVectorDBStorage Faiss
187
  QdrantVectorDBStorage Qdrant
188
+ OracleVectorDBStorage Oracle
189
+ MongoVectorDBStorage MongoDB
190
  ```
191
 
192
  * DOC_STATUS_STORAGE:supported implement-name
lightrag/base.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  from dotenv import load_dotenv
3
  from dataclasses import dataclass, field
@@ -5,10 +7,8 @@ from enum import Enum
5
  from typing import (
6
  Any,
7
  Literal,
8
- Optional,
9
  TypedDict,
10
  TypeVar,
11
- Union,
12
  )
13
  import numpy as np
14
  from .utils import EmbeddingFunc
@@ -72,7 +72,7 @@ class QueryParam:
72
  ll_keywords: list[str] = field(default_factory=list)
73
  """List of low-level keywords to refine retrieval focus."""
74
 
75
- conversation_history: list[dict[str, Any]] = field(default_factory=list)
76
  """Stores past conversation history to maintain context.
77
  Format: [{"role": "user/assistant", "content": "message"}].
78
  """
@@ -86,19 +86,15 @@ class StorageNameSpace:
86
  namespace: str
87
  global_config: dict[str, Any]
88
 
89
- async def index_done_callback(self):
90
  """Commit the storage operations after indexing"""
91
  pass
92
 
93
- async def query_done_callback(self):
94
- """Commit the storage operations after querying"""
95
- pass
96
-
97
 
98
  @dataclass
99
  class BaseVectorStorage(StorageNameSpace):
100
  embedding_func: EmbeddingFunc
101
- meta_fields: set = field(default_factory=set)
102
 
103
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
104
  raise NotImplementedError
@@ -109,12 +105,20 @@ class BaseVectorStorage(StorageNameSpace):
109
  """
110
  raise NotImplementedError
111
 
 
 
 
 
 
 
 
 
112
 
113
  @dataclass
114
  class BaseKVStorage(StorageNameSpace):
115
- embedding_func: EmbeddingFunc
116
 
117
- async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
118
  raise NotImplementedError
119
 
120
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -133,50 +137,75 @@ class BaseKVStorage(StorageNameSpace):
133
 
134
  @dataclass
135
  class BaseGraphStorage(StorageNameSpace):
136
- embedding_func: EmbeddingFunc = None
 
137
 
138
  async def has_node(self, node_id: str) -> bool:
139
  raise NotImplementedError
140
 
 
 
141
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
142
  raise NotImplementedError
143
 
 
 
144
  async def node_degree(self, node_id: str) -> int:
145
  raise NotImplementedError
146
 
 
 
147
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
148
  raise NotImplementedError
149
 
150
- async def get_node(self, node_id: str) -> Union[dict, None]:
 
 
151
  raise NotImplementedError
152
 
 
 
153
  async def get_edge(
154
  self, source_node_id: str, target_node_id: str
155
- ) -> Union[dict, None]:
156
  raise NotImplementedError
157
 
158
- async def get_node_edges(
159
- self, source_node_id: str
160
- ) -> Union[list[tuple[str, str]], None]:
161
  raise NotImplementedError
162
 
163
- async def upsert_node(self, node_id: str, node_data: dict[str, str]):
 
 
164
  raise NotImplementedError
165
 
 
 
166
  async def upsert_edge(
167
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
168
- ):
169
  raise NotImplementedError
170
 
171
- async def delete_node(self, node_id: str):
 
 
172
  raise NotImplementedError
173
 
174
- async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
 
 
 
 
175
  raise NotImplementedError("Node embedding is not used in lightrag.")
176
 
 
 
177
  async def get_all_labels(self) -> list[str]:
178
  raise NotImplementedError
179
 
 
 
180
  async def get_knowledge_graph(
181
  self, node_label: str, max_depth: int = 5
182
  ) -> KnowledgeGraph:
@@ -208,9 +237,9 @@ class DocProcessingStatus:
208
  """ISO format timestamp when document was created"""
209
  updated_at: str
210
  """ISO format timestamp when document was last updated"""
211
- chunks_count: Optional[int] = None
212
  """Number of chunks after splitting, used for processing"""
213
- error: Optional[str] = None
214
  """Error message if failed"""
215
  metadata: dict[str, Any] = field(default_factory=dict)
216
  """Additional metadata"""
 
1
+ from __future__ import annotations
2
+
3
  import os
4
  from dotenv import load_dotenv
5
  from dataclasses import dataclass, field
 
7
  from typing import (
8
  Any,
9
  Literal,
 
10
  TypedDict,
11
  TypeVar,
 
12
  )
13
  import numpy as np
14
  from .utils import EmbeddingFunc
 
72
  ll_keywords: list[str] = field(default_factory=list)
73
  """List of low-level keywords to refine retrieval focus."""
74
 
75
+ conversation_history: list[dict[str, str]] = field(default_factory=list)
76
  """Stores past conversation history to maintain context.
77
  Format: [{"role": "user/assistant", "content": "message"}].
78
  """
 
86
  namespace: str
87
  global_config: dict[str, Any]
88
 
89
+ async def index_done_callback(self) -> None:
90
  """Commit the storage operations after indexing"""
91
  pass
92
 
 
 
 
 
93
 
94
  @dataclass
95
  class BaseVectorStorage(StorageNameSpace):
96
  embedding_func: EmbeddingFunc
97
+ meta_fields: set[str] = field(default_factory=set)
98
 
99
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
100
  raise NotImplementedError
 
105
  """
106
  raise NotImplementedError
107
 
108
+ async def delete_entity(self, entity_name: str) -> None:
109
+ """Delete a single entity by its name"""
110
+ raise NotImplementedError
111
+
112
+ async def delete_entity_relation(self, entity_name: str) -> None:
113
+ """Delete relations for a given entity by scanning metadata"""
114
+ raise NotImplementedError
115
+
116
 
117
  @dataclass
118
  class BaseKVStorage(StorageNameSpace):
119
+ embedding_func: EmbeddingFunc | None = None
120
 
121
+ async def get_by_id(self, id: str) -> dict[str, Any] | None:
122
  raise NotImplementedError
123
 
124
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
 
137
 
138
  @dataclass
139
  class BaseGraphStorage(StorageNameSpace):
140
+ embedding_func: EmbeddingFunc | None = None
141
+ """Check if a node exists in the graph."""
142
 
143
  async def has_node(self, node_id: str) -> bool:
144
  raise NotImplementedError
145
 
146
+ """Check if an edge exists in the graph."""
147
+
148
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
149
  raise NotImplementedError
150
 
151
+ """Get the degree of a node."""
152
+
153
  async def node_degree(self, node_id: str) -> int:
154
  raise NotImplementedError
155
 
156
+ """Get the degree of an edge."""
157
+
158
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
159
  raise NotImplementedError
160
 
161
+ """Get a node by its id."""
162
+
163
+ async def get_node(self, node_id: str) -> dict[str, str] | None:
164
  raise NotImplementedError
165
 
166
+ """Get an edge by its source and target node ids."""
167
+
168
  async def get_edge(
169
  self, source_node_id: str, target_node_id: str
170
+ ) -> dict[str, str] | None:
171
  raise NotImplementedError
172
 
173
+ """Get all edges connected to a node."""
174
+
175
+ async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
176
  raise NotImplementedError
177
 
178
+ """Upsert a node into the graph."""
179
+
180
+ async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
181
  raise NotImplementedError
182
 
183
+ """Upsert an edge into the graph."""
184
+
185
  async def upsert_edge(
186
  self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
187
+ ) -> None:
188
  raise NotImplementedError
189
 
190
+ """Delete a node from the graph."""
191
+
192
+ async def delete_node(self, node_id: str) -> None:
193
  raise NotImplementedError
194
 
195
+ """Embed nodes using an algorithm."""
196
+
197
+ async def embed_nodes(
198
+ self, algorithm: str
199
+ ) -> tuple[np.ndarray[Any, Any], list[str]]:
200
  raise NotImplementedError("Node embedding is not used in lightrag.")
201
 
202
+ """Get all labels in the graph."""
203
+
204
  async def get_all_labels(self) -> list[str]:
205
  raise NotImplementedError
206
 
207
+ """Get a knowledge graph of a node."""
208
+
209
  async def get_knowledge_graph(
210
  self, node_label: str, max_depth: int = 5
211
  ) -> KnowledgeGraph:
 
237
  """ISO format timestamp when document was created"""
238
  updated_at: str
239
  """ISO format timestamp when document was last updated"""
240
+ chunks_count: int | None = None
241
  """Number of chunks after splitting, used for processing"""
242
+ error: str | None = None
243
  """Error message if failed"""
244
  metadata: dict[str, Any] = field(default_factory=dict)
245
  """Additional metadata"""
lightrag/exceptions.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import httpx
2
  from typing import Literal
3
 
 
1
+ from __future__ import annotations
2
+
3
  import httpx
4
  from typing import Literal
5
 
lightrag/kg/chroma_impl.py CHANGED
@@ -2,7 +2,7 @@ import asyncio
2
  from dataclasses import dataclass
3
  from typing import Union
4
  import numpy as np
5
- from chromadb import HttpClient
6
  from chromadb.config import Settings
7
  from lightrag.base import BaseVectorStorage
8
  from lightrag.utils import logger
@@ -49,31 +49,43 @@ class ChromaVectorDBStorage(BaseVectorStorage):
49
  **user_collection_settings,
50
  }
51
 
52
- auth_provider = config.get(
53
- "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
54
- )
55
- auth_credentials = config.get("auth_token", "secret-token")
56
- headers = {}
57
-
58
- if "token_authn" in auth_provider:
59
- headers = {
60
- config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
61
- }
62
- elif "basic_authn" in auth_provider:
63
- auth_credentials = config.get("auth_credentials", "admin:admin")
64
-
65
- self._client = HttpClient(
66
- host=config.get("host", "localhost"),
67
- port=config.get("port", 8000),
68
- headers=headers,
69
- settings=Settings(
70
- chroma_api_impl="rest",
71
- chroma_client_auth_provider=auth_provider,
72
- chroma_client_auth_credentials=auth_credentials,
73
- allow_reset=True,
74
- anonymized_telemetry=False,
75
- ),
76
- )
77
 
78
  self._collection = self._client.get_or_create_collection(
79
  name=self.namespace,
@@ -144,7 +156,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
144
  embedding = await self.embedding_func([query])
145
 
146
  results = self._collection.query(
147
- query_embeddings=embedding.tolist(),
 
 
148
  n_results=top_k * 2, # Request more results to allow for filtering
149
  include=["metadatas", "distances", "documents"],
150
  )
 
2
  from dataclasses import dataclass
3
  from typing import Union
4
  import numpy as np
5
+ from chromadb import HttpClient, PersistentClient
6
  from chromadb.config import Settings
7
  from lightrag.base import BaseVectorStorage
8
  from lightrag.utils import logger
 
49
  **user_collection_settings,
50
  }
51
 
52
+ local_path = config.get("local_path", None)
53
+ if local_path:
54
+ self._client = PersistentClient(
55
+ path=local_path,
56
+ settings=Settings(
57
+ allow_reset=True,
58
+ anonymized_telemetry=False,
59
+ ),
60
+ )
61
+ else:
62
+ auth_provider = config.get(
63
+ "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
64
+ )
65
+ auth_credentials = config.get("auth_token", "secret-token")
66
+ headers = {}
67
+
68
+ if "token_authn" in auth_provider:
69
+ headers = {
70
+ config.get(
71
+ "auth_header_name", "X-Chroma-Token"
72
+ ): auth_credentials
73
+ }
74
+ elif "basic_authn" in auth_provider:
75
+ auth_credentials = config.get("auth_credentials", "admin:admin")
76
+
77
+ self._client = HttpClient(
78
+ host=config.get("host", "localhost"),
79
+ port=config.get("port", 8000),
80
+ headers=headers,
81
+ settings=Settings(
82
+ chroma_api_impl="rest",
83
+ chroma_client_auth_provider=auth_provider,
84
+ chroma_client_auth_credentials=auth_credentials,
85
+ allow_reset=True,
86
+ anonymized_telemetry=False,
87
+ ),
88
+ )
89
 
90
  self._collection = self._client.get_or_create_collection(
91
  name=self.namespace,
 
156
  embedding = await self.embedding_func([query])
157
 
158
  results = self._collection.query(
159
+ query_embeddings=embedding.tolist()
160
+ if not isinstance(embedding, list)
161
+ else embedding,
162
  n_results=top_k * 2, # Request more results to allow for filtering
163
  include=["metadatas", "distances", "documents"],
164
  )
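For orientation, the new `local_path` option above switches the storage class between ChromaDB's embedded and client/server modes. A hedged sketch of the two underlying client constructions (path, host, port, and collection name below are placeholders, not values from this repository):

```python
import chromadb
from chromadb.config import Settings

# Embedded, on-disk client -- no separate Chroma server process needed.
local_client = chromadb.PersistentClient(
    path="./chroma_data",
    settings=Settings(allow_reset=True, anonymized_telemetry=False),
)

# Remote client that talks to a running Chroma server over HTTP.
remote_client = chromadb.HttpClient(host="localhost", port=8000)

# Either client exposes the same collection API the storage class uses.
collection = local_client.get_or_create_collection(
    name="demo_chunks",
    metadata={"hnsw:space": "cosine"},
)
```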
lightrag/kg/faiss_impl.py CHANGED
@@ -219,7 +219,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
219
  logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
220
  await self.delete([entity_id])
221
 
222
- async def delete_entity_relation(self, entity_name: str):
223
  """
224
  Delete relations for a given entity by scanning metadata.
225
  """
 
219
  logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
220
  await self.delete([entity_id])
221
 
222
+ async def delete_entity_relation(self, entity_name: str) -> None:
223
  """
224
  Delete relations for a given entity by scanning metadata.
225
  """
lightrag/kg/json_kv_impl.py CHANGED
@@ -47,3 +47,8 @@ class JsonKVStorage(BaseKVStorage):
47
 
48
  async def drop(self) -> None:
49
  self._data = {}
 
 
 
 
 
 
47
 
48
  async def drop(self) -> None:
49
  self._data = {}
50
+
51
+ async def delete(self, ids: list[str]) -> None:
52
+ for doc_id in ids:
53
+ self._data.pop(doc_id, None)
54
+ await self.index_done_callback()
lightrag/kg/mongo_impl.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  import pipmaster as pm
5
  import configparser
6
  from tqdm.asyncio import tqdm as tqdm_async
 
7
 
8
  if not pm.is_installed("pymongo"):
9
  pm.install("pymongo")
@@ -14,16 +15,20 @@ if not pm.is_installed("motor"):
14
  from typing import Any, List, Tuple, Union
15
  from motor.motor_asyncio import AsyncIOMotorClient
16
  from pymongo import MongoClient
 
 
17
 
18
  from ..base import (
19
  BaseGraphStorage,
20
  BaseKVStorage,
 
21
  DocProcessingStatus,
22
  DocStatus,
23
  DocStatusStorage,
24
  )
25
  from ..namespace import NameSpace, is_namespace
26
  from ..utils import logger
 
27
 
28
 
29
  config = configparser.ConfigParser()
@@ -33,56 +38,66 @@ config.read("config.ini", "utf-8")
33
  @dataclass
34
  class MongoKVStorage(BaseKVStorage):
35
  def __post_init__(self):
36
- client = MongoClient(
37
- os.environ.get(
38
- "MONGO_URI",
39
- config.get(
40
- "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
41
- ),
42
- )
43
  )
 
44
  database = client.get_database(
45
  os.environ.get(
46
  "MONGO_DATABASE",
47
  config.get("mongodb", "database", fallback="LightRAG"),
48
  )
49
  )
50
- self._data = database.get_collection(self.namespace)
51
- logger.info(f"Use MongoDB as KV {self.namespace}")
 
 
 
 
 
 
52
 
53
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
54
- return self._data.find_one({"_id": id})
55
 
56
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
57
- return list(self._data.find({"_id": {"$in": ids}}))
 
58
 
59
  async def filter_keys(self, data: set[str]) -> set[str]:
60
- existing_ids = [
61
- str(x["_id"])
62
- for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
63
- ]
64
- return set([s for s in data if s not in existing_ids])
65
 
66
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
67
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
 
68
  for mode, items in data.items():
69
- for k, v in tqdm_async(items.items(), desc="Upserting"):
70
  key = f"{mode}_{k}"
71
- result = self._data.update_one(
72
- {"_id": key}, {"$setOnInsert": v}, upsert=True
 
 
 
73
  )
74
- if result.upserted_id:
75
- logger.debug(f"\nInserted new document with key: {key}")
76
- data[mode][k]["_id"] = key
77
  else:
78
- for k, v in tqdm_async(data.items(), desc="Upserting"):
79
- self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
80
  data[k]["_id"] = k
 
 
 
 
81
 
82
  async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
83
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
84
  res = {}
85
- v = self._data.find_one({"_id": mode + "_" + id})
86
  if v:
87
  res[id] = v
88
  logger.debug(f"llm_response_cache find one by:{id}")
@@ -100,30 +115,48 @@ class MongoKVStorage(BaseKVStorage):
100
  @dataclass
101
  class MongoDocStatusStorage(DocStatusStorage):
102
  def __post_init__(self):
103
- client = MongoClient(
104
- os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
  )
106
- database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
107
- self._data = database.get_collection(self.namespace)
108
- logger.info(f"Use MongoDB as doc status {self.namespace}")
 
 
 
 
 
109
 
110
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
111
- return self._data.find_one({"_id": id})
112
 
113
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
114
- return list(self._data.find({"_id": {"$in": ids}}))
 
115
 
116
  async def filter_keys(self, data: set[str]) -> set[str]:
117
- existing_ids = [
118
- str(x["_id"])
119
- for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
120
- ]
121
- return set([s for s in data if s not in existing_ids])
122
 
123
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
 
124
  for k, v in data.items():
125
- self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
126
  data[k]["_id"] = k
 
 
 
 
127
 
128
  async def drop(self) -> None:
129
  """Drop the collection"""
@@ -132,7 +165,8 @@ class MongoDocStatusStorage(DocStatusStorage):
132
  async def get_status_counts(self) -> dict[str, int]:
133
  """Get counts of documents in each status"""
134
  pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
135
- result = list(self._data.aggregate(pipeline))
 
136
  counts = {}
137
  for doc in result:
138
  counts[doc["_id"]] = doc["count"]
@@ -142,7 +176,8 @@ class MongoDocStatusStorage(DocStatusStorage):
142
  self, status: DocStatus
143
  ) -> dict[str, DocProcessingStatus]:
144
  """Get all documents by status"""
145
- result = list(self._data.find({"status": status.value}))
 
146
  return {
147
  doc["_id"]: DocProcessingStatus(
148
  content=doc["content"],
@@ -185,26 +220,27 @@ class MongoGraphStorage(BaseGraphStorage):
185
  global_config=global_config,
186
  embedding_func=embedding_func,
187
  )
188
- self.client = AsyncIOMotorClient(
189
- os.environ.get(
190
- "MONGO_URI",
191
- config.get(
192
- "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
193
- ),
194
- )
195
  )
196
- self.db = self.client[
 
197
  os.environ.get(
198
  "MONGO_DATABASE",
199
- mongo_database=config.get("mongodb", "database", fallback="LightRAG"),
200
- )
201
- ]
202
- self.collection = self.db[
203
- os.environ.get(
204
- "MONGO_KG_COLLECTION",
205
- config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"),
206
  )
207
- ]
 
 
 
 
 
 
 
 
208
 
209
  #
210
  # -------------------------------------------------------------------------
@@ -451,7 +487,7 @@ class MongoGraphStorage(BaseGraphStorage):
451
  self, source_node_id: str
452
  ) -> Union[List[Tuple[str, str]], None]:
453
  """
454
- Return a list of (target_id, relation) for direct edges from source_node_id.
455
  Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
456
  """
457
  pipeline = [
@@ -475,7 +511,7 @@ class MongoGraphStorage(BaseGraphStorage):
475
  return None
476
 
477
  edges = result[0].get("edges", [])
478
- return [(e["target"], e["relation"]) for e in edges]
479
 
480
  #
481
  # -------------------------------------------------------------------------
@@ -522,7 +558,7 @@ class MongoGraphStorage(BaseGraphStorage):
522
 
523
  async def delete_node(self, node_id: str):
524
  """
525
- 1) Remove nodes doc entirely.
526
  2) Remove inbound edges from any doc that references node_id.
527
  """
528
  # Remove inbound edges from all other docs
@@ -542,3 +578,359 @@ class MongoGraphStorage(BaseGraphStorage):
542
  Placeholder for demonstration, raises NotImplementedError.
543
  """
544
  raise NotImplementedError("Node embedding is not used in lightrag.")
4
  import pipmaster as pm
5
  import configparser
6
  from tqdm.asyncio import tqdm as tqdm_async
7
+ import asyncio
8
 
9
  if not pm.is_installed("pymongo"):
10
  pm.install("pymongo")
 
15
  from typing import Any, List, Tuple, Union
16
  from motor.motor_asyncio import AsyncIOMotorClient
17
  from pymongo import MongoClient
18
+ from pymongo.operations import SearchIndexModel
19
+ from pymongo.errors import PyMongoError
20
 
21
  from ..base import (
22
  BaseGraphStorage,
23
  BaseKVStorage,
24
+ BaseVectorStorage,
25
  DocProcessingStatus,
26
  DocStatus,
27
  DocStatusStorage,
28
  )
29
  from ..namespace import NameSpace, is_namespace
30
  from ..utils import logger
31
+ from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
32
 
33
 
34
  config = configparser.ConfigParser()
 
38
  @dataclass
39
  class MongoKVStorage(BaseKVStorage):
40
  def __post_init__(self):
41
+ uri = os.environ.get(
42
+ "MONGO_URI",
43
+ config.get(
44
+ "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
45
+ ),
 
 
46
  )
47
+ client = AsyncIOMotorClient(uri)
48
  database = client.get_database(
49
  os.environ.get(
50
  "MONGO_DATABASE",
51
  config.get("mongodb", "database", fallback="LightRAG"),
52
  )
53
  )
54
+
55
+ self._collection_name = self.namespace
56
+
57
+ self._data = database.get_collection(self._collection_name)
58
+ logger.debug(f"Use MongoDB as KV {self._collection_name}")
59
+
60
+ # Ensure collection exists
61
+ create_collection_if_not_exists(uri, database.name, self._collection_name)
62
 
63
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
64
+ return await self._data.find_one({"_id": id})
65
 
66
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
67
+ cursor = self._data.find({"_id": {"$in": ids}})
68
+ return await cursor.to_list()
69
 
70
  async def filter_keys(self, data: set[str]) -> set[str]:
71
+ cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
72
+ existing_ids = {str(x["_id"]) async for x in cursor}
73
+ return data - existing_ids
 
 
74
 
75
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
76
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
77
+ update_tasks = []
78
  for mode, items in data.items():
79
+ for k, v in items.items():
80
  key = f"{mode}_{k}"
81
+ data[mode][k]["_id"] = f"{mode}_{k}"
82
+ update_tasks.append(
83
+ self._data.update_one(
84
+ {"_id": key}, {"$setOnInsert": v}, upsert=True
85
+ )
86
  )
87
+ await asyncio.gather(*update_tasks)
 
 
88
  else:
89
+ update_tasks = []
90
+ for k, v in data.items():
91
  data[k]["_id"] = k
92
+ update_tasks.append(
93
+ self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
94
+ )
95
+ await asyncio.gather(*update_tasks)
96
 
97
  async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
98
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
99
  res = {}
100
+ v = await self._data.find_one({"_id": mode + "_" + id})
101
  if v:
102
  res[id] = v
103
  logger.debug(f"llm_response_cache find one by:{id}")
 
115
  @dataclass
116
  class MongoDocStatusStorage(DocStatusStorage):
117
  def __post_init__(self):
118
+ uri = os.environ.get(
119
+ "MONGO_URI",
120
+ config.get(
121
+ "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
122
+ ),
123
+ )
124
+ client = AsyncIOMotorClient(uri)
125
+ database = client.get_database(
126
+ os.environ.get(
127
+ "MONGO_DATABASE",
128
+ config.get("mongodb", "database", fallback="LightRAG"),
129
+ )
130
  )
131
+
132
+ self._collection_name = self.namespace
133
+ self._data = database.get_collection(self._collection_name)
134
+
135
+ logger.debug(f"Use MongoDB as doc status {self._collection_name}")
136
+
137
+ # Ensure collection exists
138
+ create_collection_if_not_exists(uri, database.name, self._collection_name)
139
 
140
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
141
+ return await self._data.find_one({"_id": id})
142
 
143
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
144
+ cursor = self._data.find({"_id": {"$in": ids}})
145
+ return await cursor.to_list()
146
 
147
  async def filter_keys(self, data: set[str]) -> set[str]:
148
+ cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
149
+ existing_ids = {str(x["_id"]) async for x in cursor}
150
+ return data - existing_ids
 
 
151
 
152
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
153
+ update_tasks = []
154
  for k, v in data.items():
 
155
  data[k]["_id"] = k
156
+ update_tasks.append(
157
+ self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
158
+ )
159
+ await asyncio.gather(*update_tasks)
160
 
161
  async def drop(self) -> None:
162
  """Drop the collection"""
 
165
  async def get_status_counts(self) -> dict[str, int]:
166
  """Get counts of documents in each status"""
167
  pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
168
+ cursor = self._data.aggregate(pipeline)
169
+ result = await cursor.to_list()
170
  counts = {}
171
  for doc in result:
172
  counts[doc["_id"]] = doc["count"]
 
176
  self, status: DocStatus
177
  ) -> dict[str, DocProcessingStatus]:
178
  """Get all documents by status"""
179
+ cursor = self._data.find({"status": status.value})
180
+ result = await cursor.to_list()
181
  return {
182
  doc["_id"]: DocProcessingStatus(
183
  content=doc["content"],
 
220
  global_config=global_config,
221
  embedding_func=embedding_func,
222
  )
223
+ uri = os.environ.get(
224
+ "MONGO_URI",
225
+ config.get(
226
+ "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
227
+ ),
 
 
228
  )
229
+ client = AsyncIOMotorClient(uri)
230
+ database = client.get_database(
231
  os.environ.get(
232
  "MONGO_DATABASE",
233
+ config.get("mongodb", "database", fallback="LightRAG"),
 
 
 
 
 
 
234
  )
235
+ )
236
+
237
+ self._collection_name = self.namespace
238
+ self.collection = database.get_collection(self._collection_name)
239
+
240
+ logger.debug(f"Use MongoDB as KG {self._collection_name}")
241
+
242
+ # Ensure collection exists
243
+ create_collection_if_not_exists(uri, database.name, self._collection_name)
244
 
245
  #
246
  # -------------------------------------------------------------------------
 
487
  self, source_node_id: str
488
  ) -> Union[List[Tuple[str, str]], None]:
489
  """
490
+ Return a list of (source_id, target_id) for direct edges from source_node_id.
491
  Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
492
  """
493
  pipeline = [
 
511
  return None
512
 
513
  edges = result[0].get("edges", [])
514
+ return [(source_node_id, e["target"]) for e in edges]
515
 
516
  #
517
  # -------------------------------------------------------------------------
 
558
 
559
  async def delete_node(self, node_id: str):
560
  """
561
+ 1) Remove node's doc entirely.
562
  2) Remove inbound edges from any doc that references node_id.
563
  """
564
  # Remove inbound edges from all other docs
 
578
  Placeholder for demonstration, raises NotImplementedError.
579
  """
580
  raise NotImplementedError("Node embedding is not used in lightrag.")
581
+
582
+ #
583
+ # -------------------------------------------------------------------------
584
+ # QUERY
585
+ # -------------------------------------------------------------------------
586
+ #
587
+
588
+ async def get_all_labels(self) -> list[str]:
589
+ """
590
+ Get all existing node _ids in the database
591
+ Returns:
592
+ [id1, id2, ...] # Alphabetically sorted id list
593
+ """
594
+ # Use MongoDB's distinct and aggregation to get all unique labels
595
+ pipeline = [
596
+ {"$group": {"_id": "$_id"}}, # Group by _id
597
+ {"$sort": {"_id": 1}}, # Sort alphabetically
598
+ ]
599
+
600
+ cursor = self.collection.aggregate(pipeline)
601
+ labels = []
602
+ async for doc in cursor:
603
+ labels.append(doc["_id"])
604
+ return labels
605
+
606
+ async def get_knowledge_graph(
607
+ self, node_label: str, max_depth: int = 5
608
+ ) -> KnowledgeGraph:
609
+ """
610
+ Get complete connected subgraph for specified node (including the starting node itself)
611
+
612
+ Args:
613
+ node_label: Label of the nodes to start from
614
+ max_depth: Maximum depth of traversal (default: 5)
615
+
616
+ Returns:
617
+ KnowledgeGraph object containing nodes and edges of the subgraph
618
+ """
619
+ label = node_label
620
+ result = KnowledgeGraph()
621
+ seen_nodes = set()
622
+ seen_edges = set()
623
+
624
+ try:
625
+ if label == "*":
626
+ # Get all nodes and edges
627
+ async for node_doc in self.collection.find({}):
628
+ node_id = str(node_doc["_id"])
629
+ if node_id not in seen_nodes:
630
+ result.nodes.append(
631
+ KnowledgeGraphNode(
632
+ id=node_id,
633
+ labels=[node_doc.get("_id")],
634
+ properties={
635
+ k: v
636
+ for k, v in node_doc.items()
637
+ if k not in ["_id", "edges"]
638
+ },
639
+ )
640
+ )
641
+ seen_nodes.add(node_id)
642
+
643
+ # Process edges
644
+ for edge in node_doc.get("edges", []):
645
+ edge_id = f"{node_id}-{edge['target']}"
646
+ if edge_id not in seen_edges:
647
+ result.edges.append(
648
+ KnowledgeGraphEdge(
649
+ id=edge_id,
650
+ type=edge.get("relation", ""),
651
+ source=node_id,
652
+ target=edge["target"],
653
+ properties={
654
+ k: v
655
+ for k, v in edge.items()
656
+ if k not in ["target", "relation"]
657
+ },
658
+ )
659
+ )
660
+ seen_edges.add(edge_id)
661
+ else:
662
+ # Verify if starting node exists
663
+ start_nodes = self.collection.find({"_id": label})
664
+ start_nodes_exist = await start_nodes.to_list(length=1)
665
+ if not start_nodes_exist:
666
+ logger.warning(f"Starting node with label {label} does not exist!")
667
+ return result
668
+
669
+ # Use $graphLookup for traversal
670
+ pipeline = [
671
+ {
672
+ "$match": {"_id": label}
673
+ }, # Start with nodes having the specified label
674
+ {
675
+ "$graphLookup": {
676
+ "from": self._collection_name,
677
+ "startWith": "$edges.target",
678
+ "connectFromField": "edges.target",
679
+ "connectToField": "_id",
680
+ "maxDepth": max_depth,
681
+ "depthField": "depth",
682
+ "as": "connected_nodes",
683
+ }
684
+ },
685
+ ]
686
+
687
+ async for doc in self.collection.aggregate(pipeline):
688
+ # Add the start node
689
+ node_id = str(doc["_id"])
690
+ if node_id not in seen_nodes:
691
+ result.nodes.append(
692
+ KnowledgeGraphNode(
693
+ id=node_id,
694
+ labels=[
695
+ doc.get(
696
+ "_id",
697
+ )
698
+ ],
699
+ properties={
700
+ k: v
701
+ for k, v in doc.items()
702
+ if k
703
+ not in [
704
+ "_id",
705
+ "edges",
706
+ "connected_nodes",
707
+ "depth",
708
+ ]
709
+ },
710
+ )
711
+ )
712
+ seen_nodes.add(node_id)
713
+
714
+ # Add edges from start node
715
+ for edge in doc.get("edges", []):
716
+ edge_id = f"{node_id}-{edge['target']}"
717
+ if edge_id not in seen_edges:
718
+ result.edges.append(
719
+ KnowledgeGraphEdge(
720
+ id=edge_id,
721
+ type=edge.get("relation", ""),
722
+ source=node_id,
723
+ target=edge["target"],
724
+ properties={
725
+ k: v
726
+ for k, v in edge.items()
727
+ if k not in ["target", "relation"]
728
+ },
729
+ )
730
+ )
731
+ seen_edges.add(edge_id)
732
+
733
+ # Add connected nodes and their edges
734
+ for connected in doc.get("connected_nodes", []):
735
+ node_id = str(connected["_id"])
736
+ if node_id not in seen_nodes:
737
+ result.nodes.append(
738
+ KnowledgeGraphNode(
739
+ id=node_id,
740
+ labels=[connected.get("_id")],
741
+ properties={
742
+ k: v
743
+ for k, v in connected.items()
744
+ if k not in ["_id", "edges", "depth"]
745
+ },
746
+ )
747
+ )
748
+ seen_nodes.add(node_id)
749
+
750
+ # Add edges from connected nodes
751
+ for edge in connected.get("edges", []):
752
+ edge_id = f"{node_id}-{edge['target']}"
753
+ if edge_id not in seen_edges:
754
+ result.edges.append(
755
+ KnowledgeGraphEdge(
756
+ id=edge_id,
757
+ type=edge.get("relation", ""),
758
+ source=node_id,
759
+ target=edge["target"],
760
+ properties={
761
+ k: v
762
+ for k, v in edge.items()
763
+ if k not in ["target", "relation"]
764
+ },
765
+ )
766
+ )
767
+ seen_edges.add(edge_id)
768
+
769
+ logger.info(
770
+ f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
771
+ )
772
+
773
+ except PyMongoError as e:
774
+ logger.error(f"MongoDB query failed: {str(e)}")
775
+
776
+ return result
777
+
778
+
779
+ @dataclass
780
+ class MongoVectorDBStorage(BaseVectorStorage):
781
+ cosine_better_than_threshold: float = None
782
+
783
+ def __post_init__(self):
784
+ kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
785
+ cosine_threshold = kwargs.get("cosine_better_than_threshold")
786
+ if cosine_threshold is None:
787
+ raise ValueError(
788
+ "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
789
+ )
790
+ self.cosine_better_than_threshold = cosine_threshold
791
+
792
+ uri = os.environ.get(
793
+ "MONGO_URI",
794
+ config.get(
795
+ "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
796
+ ),
797
+ )
798
+ client = AsyncIOMotorClient(uri)
799
+ database = client.get_database(
800
+ os.environ.get(
801
+ "MONGO_DATABASE",
802
+ config.get("mongodb", "database", fallback="LightRAG"),
803
+ )
804
+ )
805
+
806
+ self._collection_name = self.namespace
807
+ self._data = database.get_collection(self._collection_name)
808
+ self._max_batch_size = self.global_config["embedding_batch_num"]
809
+
810
+ logger.debug(f"Use MongoDB as VDB {self._collection_name}")
811
+
812
+ # Ensure collection exists
813
+ create_collection_if_not_exists(uri, database.name, self._collection_name)
814
+
815
+ # Ensure vector index exists
816
+ self.create_vector_index(uri, database.name, self._collection_name)
817
+
818
+ def create_vector_index(self, uri: str, database_name: str, collection_name: str):
819
+ """Creates an Atlas Vector Search index."""
820
+ client = MongoClient(uri)
821
+ collection = client.get_database(database_name).get_collection(
822
+ self._collection_name
823
+ )
824
+
825
+ try:
826
+ search_index_model = SearchIndexModel(
827
+ definition={
828
+ "fields": [
829
+ {
830
+ "type": "vector",
831
+ "numDimensions": self.embedding_func.embedding_dim, # Ensure correct dimensions
832
+ "path": "vector",
833
+ "similarity": "cosine", # Options: euclidean, cosine, dotProduct
834
+ }
835
+ ]
836
+ },
837
+ name="vector_knn_index",
838
+ type="vectorSearch",
839
+ )
840
+
841
+ collection.create_search_index(search_index_model)
842
+ logger.info("Vector index created successfully.")
843
+
844
+ except PyMongoError as _:
845
+ logger.debug("vector index already exist")
846
+
847
+ async def upsert(self, data: dict[str, dict]):
848
+ logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
849
+ if not data:
850
+ logger.warning("You are inserting an empty data set to vector DB")
851
+ return []
852
+
853
+ list_data = [
854
+ {
855
+ "_id": k,
856
+ **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
857
+ }
858
+ for k, v in data.items()
859
+ ]
860
+ contents = [v["content"] for v in data.values()]
861
+ batches = [
862
+ contents[i : i + self._max_batch_size]
863
+ for i in range(0, len(contents), self._max_batch_size)
864
+ ]
865
+
866
+ async def wrapped_task(batch):
867
+ result = await self.embedding_func(batch)
868
+ pbar.update(1)
869
+ return result
870
+
871
+ embedding_tasks = [wrapped_task(batch) for batch in batches]
872
+ pbar = tqdm_async(
873
+ total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
874
+ )
875
+ embeddings_list = await asyncio.gather(*embedding_tasks)
876
+
877
+ embeddings = np.concatenate(embeddings_list)
878
+ for i, d in enumerate(list_data):
879
+ d["vector"] = np.array(embeddings[i], dtype=np.float32).tolist()
880
+
881
+ update_tasks = []
882
+ for doc in list_data:
883
+ update_tasks.append(
884
+ self._data.update_one({"_id": doc["_id"]}, {"$set": doc}, upsert=True)
885
+ )
886
+ await asyncio.gather(*update_tasks)
887
+
888
+ return list_data
889
+
890
+ async def query(self, query, top_k=5):
891
+ """Queries the vector database using Atlas Vector Search."""
892
+ # Generate the embedding
893
+ embedding = await self.embedding_func([query])
894
+
895
+ # Convert numpy array to a list to ensure compatibility with MongoDB
896
+ query_vector = embedding[0].tolist()
897
+
898
+ # Define the aggregation pipeline with the converted query vector
899
+ pipeline = [
900
+ {
901
+ "$vectorSearch": {
902
+ "index": "vector_knn_index", # Ensure this matches the created index name
903
+ "path": "vector",
904
+ "queryVector": query_vector,
905
+ "numCandidates": 100, # Adjust for performance
906
+ "limit": top_k,
907
+ }
908
+ },
909
+ {"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
910
+ {"$match": {"score": {"$gte": self.cosine_better_than_threshold}}},
911
+ {"$project": {"vector": 0}},
912
+ ]
913
+
914
+ # Execute the aggregation pipeline
915
+ cursor = self._data.aggregate(pipeline)
916
+ results = await cursor.to_list()
917
+
918
+ # Format and return the results
919
+ return [
920
+ {**doc, "id": doc["_id"], "distance": doc.get("score", None)}
921
+ for doc in results
922
+ ]
923
+
924
+
925
+ def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
926
+ """Check if the collection exists. if not, create it."""
927
+ client = MongoClient(uri)
928
+ database = client.get_database(database_name)
929
+
930
+ collection_names = database.list_collection_names()
931
+
932
+ if collection_name not in collection_names:
933
+ database.create_collection(collection_name)
934
+ logger.info(f"Created collection: {collection_name}")
935
+ else:
936
+ logger.debug(f"Collection '{collection_name}' already exists.")
lightrag/kg/nano_vector_db_impl.py CHANGED
@@ -191,7 +191,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
191
  except Exception as e:
192
  logger.error(f"Error deleting entity {entity_name}: {e}")
193
 
194
- async def delete_entity_relation(self, entity_name: str):
195
  try:
196
  relations = [
197
  dp
 
191
  except Exception as e:
192
  logger.error(f"Error deleting entity {entity_name}: {e}")
193
 
194
+ async def delete_entity_relation(self, entity_name: str) -> None:
195
  try:
196
  relations = [
197
  dp
lightrag/kg/neo4j_impl.py CHANGED
@@ -143,9 +143,27 @@ class Neo4JStorage(BaseGraphStorage):
143
  async def index_done_callback(self):
144
  print("KG successfully indexed.")
145
 
146
- async def has_node(self, node_id: str) -> bool:
147
- entity_name_label = node_id.strip('"')
149
  async with self._driver.session(database=self._DATABASE) as session:
150
  query = (
151
  f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
@@ -174,8 +192,17 @@ class Neo4JStorage(BaseGraphStorage):
174
  return single_result["edgeExists"]
175
 
176
  async def get_node(self, node_id: str) -> Union[dict, None]:
177
  async with self._driver.session(database=self._DATABASE) as session:
178
- entity_name_label = node_id.strip('"')
179
  query = f"MATCH (n:`{entity_name_label}`) RETURN n"
180
  result = await session.run(query)
181
  record = await result.single()
@@ -226,38 +253,73 @@ class Neo4JStorage(BaseGraphStorage):
226
  async def get_edge(
227
  self, source_node_id: str, target_node_id: str
228
  ) -> Union[dict, None]:
229
- entity_name_label_source = source_node_id.strip('"')
230
- entity_name_label_target = target_node_id.strip('"')
231
- """
232
- Find all edges between nodes of two given labels
233
 
234
  Args:
235
- source_node_label (str): Label of the source nodes
236
- target_node_label (str): Label of the target nodes
237
 
238
  Returns:
239
- list: List of all relationships/edges found
 
240
  """
241
- async with self._driver.session(database=self._DATABASE) as session:
242
- query = f"""
243
- MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
244
- RETURN properties(r) as edge_properties
245
- LIMIT 1
246
- """.format(
247
- entity_name_label_source=entity_name_label_source,
248
- entity_name_label_target=entity_name_label_target,
249
- )
250
 
251
- result = await session.run(query)
252
- record = await result.single()
253
- if record:
254
- result = dict(record["edge_properties"])
255
  logger.debug(
256
- f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
257
  )
258
- return result
259
- else:
260
- return None
261
 
262
  async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
263
  node_label = source_node_id.strip('"')
@@ -310,7 +372,7 @@ class Neo4JStorage(BaseGraphStorage):
310
  node_id: The unique identifier for the node (used as label)
311
  node_data: Dictionary of node properties
312
  """
313
- label = node_id.strip('"')
314
  properties = node_data
315
 
316
  async def _do_upsert(tx: AsyncManagedTransaction):
@@ -338,6 +400,7 @@ class Neo4JStorage(BaseGraphStorage):
338
  neo4jExceptions.ServiceUnavailable,
339
  neo4jExceptions.TransientError,
340
  neo4jExceptions.WriteServiceUnavailable,
 
341
  )
342
  ),
343
  )
@@ -352,22 +415,23 @@ class Neo4JStorage(BaseGraphStorage):
352
  target_node_id (str): Label of the target node (used as identifier)
353
  edge_data (dict): Dictionary of properties to set on the edge
354
  """
355
- source_node_label = source_node_id.strip('"')
356
- target_node_label = target_node_id.strip('"')
357
  edge_properties = edge_data
358
 
359
  async def _do_upsert_edge(tx: AsyncManagedTransaction):
360
  query = f"""
361
- MATCH (source:`{source_node_label}`)
362
  WITH source
363
- MATCH (target:`{target_node_label}`)
364
  MERGE (source)-[r:DIRECTED]->(target)
365
  SET r += $properties
366
  RETURN r
367
  """
368
- await tx.run(query, properties=edge_properties)
 
369
  logger.debug(
370
- f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
371
  )
372
 
373
  try:
 
143
  async def index_done_callback(self):
144
  print("KG successfully indexed.")
145
 
146
+ async def _label_exists(self, label: str) -> bool:
147
+ """Check if a label exists in the Neo4j database."""
148
+ query = "CALL db.labels() YIELD label RETURN label"
149
+ try:
150
+ async with self._driver.session(database=self._DATABASE) as session:
151
+ result = await session.run(query)
152
+ labels = [record["label"] for record in await result.data()]
153
+ return label in labels
154
+ except Exception as e:
155
+ logger.error(f"Error checking label existence: {e}")
156
+ return False
157
 
158
+ async def _ensure_label(self, label: str) -> str:
159
+ """Ensure a label exists by validating it."""
160
+ clean_label = label.strip('"')
161
+ if not await self._label_exists(clean_label):
162
+ logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
163
+ return clean_label
164
+
165
+ async def has_node(self, node_id: str) -> bool:
166
+ entity_name_label = await self._ensure_label(node_id)
167
  async with self._driver.session(database=self._DATABASE) as session:
168
  query = (
169
  f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
 
192
  return single_result["edgeExists"]
193
 
194
  async def get_node(self, node_id: str) -> Union[dict, None]:
195
+ """Get node by its label identifier.
196
+
197
+ Args:
198
+ node_id: The node label to look up
199
+
200
+ Returns:
201
+ dict: Node properties if found
202
+ None: If node not found
203
+ """
204
  async with self._driver.session(database=self._DATABASE) as session:
205
+ entity_name_label = await self._ensure_label(node_id)
206
  query = f"MATCH (n:`{entity_name_label}`) RETURN n"
207
  result = await session.run(query)
208
  record = await result.single()
 
253
  async def get_edge(
254
  self, source_node_id: str, target_node_id: str
255
  ) -> Union[dict, None]:
256
+ """Find edge between two nodes identified by their labels.
 
 
 
257
 
258
  Args:
259
+ source_node_id (str): Label of the source node
260
+ target_node_id (str): Label of the target node
261
 
262
  Returns:
263
+ dict: Edge properties if found, with at least {"weight": 0.0}
264
+ None: If error occurs
265
  """
266
+ try:
267
+ entity_name_label_source = source_node_id.strip('"')
268
+ entity_name_label_target = target_node_id.strip('"')
269
+
270
+ async with self._driver.session(database=self._DATABASE) as session:
271
+ query = f"""
272
+ MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
273
+ RETURN properties(r) as edge_properties
274
+ LIMIT 1
275
+ """.format(
276
+ entity_name_label_source=entity_name_label_source,
277
+ entity_name_label_target=entity_name_label_target,
278
+ )
279
+
280
+ result = await session.run(query)
281
+ record = await result.single()
282
+ if record and "edge_properties" in record:
283
+ try:
284
+ result = dict(record["edge_properties"])
285
+ # Ensure required keys exist with defaults
286
+ required_keys = {
287
+ "weight": 0.0,
288
+ "source_id": None,
289
+ "target_id": None,
290
+ }
291
+ for key, default_value in required_keys.items():
292
+ if key not in result:
293
+ result[key] = default_value
294
+ logger.warning(
295
+ f"Edge between {entity_name_label_source} and {entity_name_label_target} "
296
+ f"missing {key}, using default: {default_value}"
297
+ )
298
+
299
+ logger.debug(
300
+ f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
301
+ )
302
+ return result
303
+ except (KeyError, TypeError, ValueError) as e:
304
+ logger.error(
305
+ f"Error processing edge properties between {entity_name_label_source} "
306
+ f"and {entity_name_label_target}: {str(e)}"
307
+ )
308
+ # Return default edge properties on error
309
+ return {"weight": 0.0, "source_id": None, "target_id": None}
310
 
 
 
 
 
311
  logger.debug(
312
+ f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
313
  )
314
+ # Return default edge properties when no edge found
315
+ return {"weight": 0.0, "source_id": None, "target_id": None}
316
+
317
+ except Exception as e:
318
+ logger.error(
319
+ f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
320
+ )
321
+ # Return default edge properties on error
322
+ return {"weight": 0.0, "source_id": None, "target_id": None}
323
 
324
  async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
325
  node_label = source_node_id.strip('"')
 
372
  node_id: The unique identifier for the node (used as label)
373
  node_data: Dictionary of node properties
374
  """
375
+ label = await self._ensure_label(node_id)
376
  properties = node_data
377
 
378
  async def _do_upsert(tx: AsyncManagedTransaction):
 
400
  neo4jExceptions.ServiceUnavailable,
401
  neo4jExceptions.TransientError,
402
  neo4jExceptions.WriteServiceUnavailable,
403
+ neo4jExceptions.ClientError,
404
  )
405
  ),
406
  )
 
415
  target_node_id (str): Label of the target node (used as identifier)
416
  edge_data (dict): Dictionary of properties to set on the edge
417
  """
418
+ source_label = await self._ensure_label(source_node_id)
419
+ target_label = await self._ensure_label(target_node_id)
420
  edge_properties = edge_data
421
 
422
  async def _do_upsert_edge(tx: AsyncManagedTransaction):
423
  query = f"""
424
+ MATCH (source:`{source_label}`)
425
  WITH source
426
+ MATCH (target:`{target_label}`)
427
  MERGE (source)-[r:DIRECTED]->(target)
428
  SET r += $properties
429
  RETURN r
430
  """
431
+ result = await tx.run(query, properties=edge_properties)
432
+ record = await result.single()
433
  logger.debug(
434
+ f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
435
  )
436
 
437
  try:
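The net effect of the `get_edge` rewrite is a change of contract: instead of returning `None` when a relationship is absent or malformed, the method now always yields a dict carrying at least `weight`, `source_id` and `target_id` defaults, and logs a warning when it has to fill them in. A hedged sketch of the caller-side view, assuming an already initialised `Neo4JStorage` instance:

```python
# Illustration only: the new return contract of Neo4JStorage.get_edge.
async def edge_weight(storage, src_label: str, tgt_label: str) -> float:
    edge = await storage.get_edge(src_label, tgt_label)
    # After this change the result is always a dict, so no None check is needed;
    # a missing or broken edge simply reports the default weight of 0.0.
    return edge["weight"]
```

This is the behaviour the defensive merging added to `operate.py` further down depends on.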
lightrag/lightrag.py CHANGED
@@ -1,10 +1,12 @@
 
 
1
  import asyncio
2
  import os
3
  import configparser
4
  from dataclasses import asdict, dataclass, field
5
  from datetime import datetime
6
  from functools import partial
7
- from typing import Any, Callable, Optional, Type, Union, cast
8
 
9
  from .base import (
10
  BaseGraphStorage,
@@ -76,6 +78,7 @@ STORAGE_IMPLEMENTATIONS = {
76
  "FaissVectorDBStorage",
77
  "QdrantVectorDBStorage",
78
  "OracleVectorDBStorage",
 
79
  ],
80
  "required_methods": ["query", "upsert"],
81
  },
@@ -91,7 +94,7 @@ STORAGE_IMPLEMENTATIONS = {
91
  }
92
 
93
  # Storage implementation environment variable without default value
94
- STORAGE_ENV_REQUIREMENTS = {
95
  # KV Storage Implementations
96
  "JsonKVStorage": [],
97
  "MongoKVStorage": [],
@@ -140,6 +143,7 @@ STORAGE_ENV_REQUIREMENTS = {
140
  "ORACLE_PASSWORD",
141
  "ORACLE_CONFIG_DIR",
142
  ],
 
143
  # Document Status Storage Implementations
144
  "JsonDocStatusStorage": [],
145
  "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
@@ -160,6 +164,7 @@ STORAGES = {
160
  "MongoKVStorage": ".kg.mongo_impl",
161
  "MongoDocStatusStorage": ".kg.mongo_impl",
162
  "MongoGraphStorage": ".kg.mongo_impl",
 
163
  "RedisKVStorage": ".kg.redis_impl",
164
  "ChromaVectorDBStorage": ".kg.chroma_impl",
165
  "TiDBKVStorage": ".kg.tidb_impl",
@@ -176,7 +181,7 @@ STORAGES = {
176
  }
177
 
178
 
179
- def lazy_external_import(module_name: str, class_name: str):
180
  """Lazily import a class from an external module based on the package of the caller."""
181
  # Get the caller's module and package
182
  import inspect
@@ -185,7 +190,7 @@ def lazy_external_import(module_name: str, class_name: str):
185
  module = inspect.getmodule(caller_frame)
186
  package = module.__package__ if module else None
187
 
188
- def import_class(*args, **kwargs):
189
  import importlib
190
 
191
  module = importlib.import_module(module_name, package=package)
@@ -302,7 +307,7 @@ class LightRAG:
302
  - random_seed: Seed value for reproducibility.
303
  """
304
 
305
- embedding_func: EmbeddingFunc = None
306
  """Function for computing text embeddings. Must be set before use."""
307
 
308
  embedding_batch_num: int = 32
@@ -312,7 +317,7 @@ class LightRAG:
312
  """Maximum number of concurrent embedding function calls."""
313
 
314
  # LLM Configuration
315
- llm_model_func: callable = None
316
  """Function for interacting with the large language model (LLM). Must be set before use."""
317
 
318
  llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
@@ -342,10 +347,8 @@ class LightRAG:
342
 
343
  # Extensions
344
  addon_params: dict[str, Any] = field(default_factory=dict)
345
- """Dictionary for additional parameters and extensions."""
346
 
347
- # extension
348
- addon_params: dict[str, Any] = field(default_factory=dict)
349
  convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
350
  convert_response_to_json
351
  )
@@ -354,7 +357,7 @@ class LightRAG:
354
  chunking_func: Callable[
355
  [
356
  str,
357
- Optional[str],
358
  bool,
359
  int,
360
  int,
@@ -443,77 +446,74 @@ class LightRAG:
443
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
444
 
445
  # Init LLM
446
- self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
447
  self.embedding_func
448
  )
449
 
450
  # Initialize all storages
451
- self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
452
  self._get_storage_class(self.kv_storage)
453
- )
454
- self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
455
  self.vector_storage
456
- )
457
- self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
458
  self.graph_storage
459
- )
460
-
461
- self.key_string_value_json_storage_cls = partial(
462
  self.key_string_value_json_storage_cls, global_config=global_config
463
  )
464
-
465
- self.vector_db_storage_cls = partial(
466
  self.vector_db_storage_cls, global_config=global_config
467
  )
468
-
469
- self.graph_storage_cls = partial(
470
  self.graph_storage_cls, global_config=global_config
471
  )
472
 
473
  # Initialize document status storage
474
  self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
475
 
476
- self.llm_response_cache = self.key_string_value_json_storage_cls(
477
  namespace=make_namespace(
478
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
479
  ),
480
  embedding_func=self.embedding_func,
481
  )
482
 
483
- self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls(
484
  namespace=make_namespace(
485
  self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
486
  ),
487
  embedding_func=self.embedding_func,
488
  )
489
- self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls(
490
  namespace=make_namespace(
491
  self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
492
  ),
493
  embedding_func=self.embedding_func,
494
  )
495
- self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls(
496
  namespace=make_namespace(
497
  self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
498
  ),
499
  embedding_func=self.embedding_func,
500
  )
501
 
502
- self.entities_vdb = self.vector_db_storage_cls(
503
  namespace=make_namespace(
504
  self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
505
  ),
506
  embedding_func=self.embedding_func,
507
  meta_fields={"entity_name"},
508
  )
509
- self.relationships_vdb = self.vector_db_storage_cls(
510
  namespace=make_namespace(
511
  self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
512
  ),
513
  embedding_func=self.embedding_func,
514
  meta_fields={"src_id", "tgt_id"},
515
  )
516
- self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls(
517
  namespace=make_namespace(
518
  self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
519
  ),
@@ -527,13 +527,12 @@ class LightRAG:
527
  embedding_func=None,
528
  )
529
 
530
- # What's for, Is this nessisary ?
531
  if self.llm_response_cache and hasattr(
532
  self.llm_response_cache, "global_config"
533
  ):
534
  hashing_kv = self.llm_response_cache
535
  else:
536
- hashing_kv = self.key_string_value_json_storage_cls(
537
  namespace=make_namespace(
538
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
539
  ),
@@ -542,7 +541,7 @@ class LightRAG:
542
 
543
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
544
  partial(
545
- self.llm_model_func,
546
  hashing_kv=hashing_kv,
547
  **self.llm_model_kwargs,
548
  )
@@ -559,68 +558,45 @@ class LightRAG:
559
  node_label=nodel_label, max_depth=max_depth
560
  )
561
 
562
- def _get_storage_class(self, storage_name: str) -> dict:
563
  import_path = STORAGES[storage_name]
564
  storage_class = lazy_external_import(import_path, storage_name)
565
  return storage_class
566
 
567
- def set_storage_client(self, db_client):
568
- # Deprecated, seting correct value to *_storage of LightRAG insteaded
569
- # Inject db to storage implementation (only tested on Oracle Database)
570
- for storage in [
571
- self.vector_db_storage_cls,
572
- self.graph_storage_cls,
573
- self.doc_status,
574
- self.full_docs,
575
- self.text_chunks,
576
- self.llm_response_cache,
577
- self.key_string_value_json_storage_cls,
578
- self.chunks_vdb,
579
- self.relationships_vdb,
580
- self.entities_vdb,
581
- self.graph_storage_cls,
582
- self.chunk_entity_relation_graph,
583
- self.llm_response_cache,
584
- ]:
585
- # set client
586
- storage.db = db_client
587
-
588
  def insert(
589
  self,
590
- string_or_strings: Union[str, list[str]],
591
  split_by_character: str | None = None,
592
  split_by_character_only: bool = False,
593
  ):
594
  """Sync Insert documents with checkpoint support
595
 
596
  Args:
597
- string_or_strings: Single document string or list of document strings
598
  split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
599
- chunk_size, split the sub chunk by token size.
600
  split_by_character_only: if split_by_character_only is True, split the string by character only, when
601
  split_by_character is None, this parameter is ignored.
602
  """
603
  loop = always_get_an_event_loop()
604
  return loop.run_until_complete(
605
- self.ainsert(string_or_strings, split_by_character, split_by_character_only)
606
  )
607
 
608
  async def ainsert(
609
  self,
610
- string_or_strings: Union[str, list[str]],
611
  split_by_character: str | None = None,
612
  split_by_character_only: bool = False,
613
  ):
614
  """Async Insert documents with checkpoint support
615
 
616
  Args:
617
- string_or_strings: Single document string or list of document strings
618
  split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
619
- chunk_size, split the sub chunk by token size.
620
  split_by_character_only: if split_by_character_only is True, split the string by character only, when
621
  split_by_character is None, this parameter is ignored.
622
  """
623
- await self.apipeline_enqueue_documents(string_or_strings)
624
  await self.apipeline_process_enqueue_documents(
625
  split_by_character, split_by_character_only
626
  )
@@ -677,7 +653,7 @@ class LightRAG:
677
  if update_storage:
678
  await self._insert_done()
679
 
680
- async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]):
681
  """
682
  Pipeline for Processing Documents
683
 
@@ -686,11 +662,11 @@ class LightRAG:
686
  3. Filter out already processed documents
687
  4. Enqueue document in status
688
  """
689
- if isinstance(string_or_strings, str):
690
- string_or_strings = [string_or_strings]
691
 
692
  # 1. Remove duplicate contents from the list
693
- unique_contents = list(set(doc.strip() for doc in string_or_strings))
694
 
695
  # 2. Generate document IDs and initial status
696
  new_docs: dict[str, Any] = {
@@ -857,32 +833,32 @@ class LightRAG:
857
  raise e
858
 
859
  async def _insert_done(self):
860
- tasks = []
861
- for storage_inst in [
862
- self.full_docs,
863
- self.text_chunks,
864
- self.llm_response_cache,
865
- self.entities_vdb,
866
- self.relationships_vdb,
867
- self.chunks_vdb,
868
- self.chunk_entity_relation_graph,
869
- ]:
870
- if storage_inst is None:
871
- continue
872
- tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
873
  await asyncio.gather(*tasks)
874
 
875
- def insert_custom_kg(self, custom_kg: dict):
876
  loop = always_get_an_event_loop()
877
  return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
878
 
879
- async def ainsert_custom_kg(self, custom_kg: dict):
880
  update_storage = False
881
  try:
882
  # Insert chunks into vector storage
883
- all_chunks_data = {}
884
- chunk_to_source_map = {}
885
- for chunk_data in custom_kg.get("chunks", []):
886
  chunk_content = chunk_data["content"]
887
  source_id = chunk_data["source_id"]
888
  chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
@@ -892,13 +868,13 @@ class LightRAG:
892
  chunk_to_source_map[source_id] = chunk_id
893
  update_storage = True
894
 
895
- if self.chunks_vdb is not None and all_chunks_data:
896
  await self.chunks_vdb.upsert(all_chunks_data)
897
- if self.text_chunks is not None and all_chunks_data:
898
  await self.text_chunks.upsert(all_chunks_data)
899
 
900
  # Insert entities into knowledge graph
901
- all_entities_data = []
902
  for entity_data in custom_kg.get("entities", []):
903
  entity_name = f'"{entity_data["entity_name"].upper()}"'
904
  entity_type = entity_data.get("entity_type", "UNKNOWN")
@@ -914,7 +890,7 @@ class LightRAG:
914
  )
915
 
916
  # Prepare node data
917
- node_data = {
918
  "entity_type": entity_type,
919
  "description": description,
920
  "source_id": source_id,
@@ -928,7 +904,7 @@ class LightRAG:
928
  update_storage = True
929
 
930
  # Insert relationships into knowledge graph
931
- all_relationships_data = []
932
  for relationship_data in custom_kg.get("relationships", []):
933
  src_id = f'"{relationship_data["src_id"].upper()}"'
934
  tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
@@ -970,7 +946,7 @@ class LightRAG:
970
  "source_id": source_id,
971
  },
972
  )
973
- edge_data = {
974
  "src_id": src_id,
975
  "tgt_id": tgt_id,
976
  "description": description,
@@ -980,41 +956,68 @@ class LightRAG:
980
  update_storage = True
981
 
982
  # Insert entities into vector storage if needed
983
- if self.entities_vdb is not None:
984
- data_for_vdb = {
985
- compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
986
- "content": dp["entity_name"] + dp["description"],
987
- "entity_name": dp["entity_name"],
988
- }
989
- for dp in all_entities_data
990
  }
991
- await self.entities_vdb.upsert(data_for_vdb)
 
 
992
 
993
  # Insert relationships into vector storage if needed
994
- if self.relationships_vdb is not None:
995
- data_for_vdb = {
996
- compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
997
- "src_id": dp["src_id"],
998
- "tgt_id": dp["tgt_id"],
999
- "content": dp["keywords"]
1000
- + dp["src_id"]
1001
- + dp["tgt_id"]
1002
- + dp["description"],
1003
- }
1004
- for dp in all_relationships_data
1005
  }
1006
- await self.relationships_vdb.upsert(data_for_vdb)
 
 
 
1007
  finally:
1008
  if update_storage:
1009
  await self._insert_done()
1010
 
1011
- def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()):
1012
  loop = always_get_an_event_loop()
1013
- return loop.run_until_complete(self.aquery(query, prompt, param))
 
1014
 
1015
  async def aquery(
1016
- self, query: str, prompt: str = "", param: QueryParam = QueryParam()
1017
- ):
1018
  if param.mode in ["local", "global", "hybrid"]:
1019
  response = await kg_query(
1020
  query,
@@ -1094,7 +1097,7 @@ class LightRAG:
1094
 
1095
  async def aquery_with_separate_keyword_extraction(
1096
  self, query: str, prompt: str, param: QueryParam = QueryParam()
1097
- ):
1098
  """
1099
  1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
1100
  2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
@@ -1117,8 +1120,8 @@ class LightRAG:
1117
  ),
1118
  )
1119
 
1120
- param.hl_keywords = (hl_keywords,)
1121
- param.ll_keywords = (ll_keywords,)
1122
 
1123
  # ---------------------
1124
  # STEP 2: Final Query Logic
@@ -1146,7 +1149,7 @@ class LightRAG:
1146
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
1147
  ),
1148
  global_config=asdict(self),
1149
- embedding_func=self.embedding_funcne,
1150
  ),
1151
  )
1152
  elif param.mode == "naive":
@@ -1195,12 +1198,7 @@ class LightRAG:
1195
  return response
1196
 
1197
  async def _query_done(self):
1198
- tasks = []
1199
- for storage_inst in [self.llm_response_cache]:
1200
- if storage_inst is None:
1201
- continue
1202
- tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
1203
- await asyncio.gather(*tasks)
1204
 
1205
  def delete_by_entity(self, entity_name: str):
1206
  loop = always_get_an_event_loop()
@@ -1222,16 +1220,16 @@ class LightRAG:
1222
  logger.error(f"Error while deleting entity '{entity_name}': {e}")
1223
 
1224
  async def _delete_by_entity_done(self):
1225
- tasks = []
1226
- for storage_inst in [
1227
- self.entities_vdb,
1228
- self.relationships_vdb,
1229
- self.chunk_entity_relation_graph,
1230
- ]:
1231
- if storage_inst is None:
1232
- continue
1233
- tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
1234
- await asyncio.gather(*tasks)
1235
 
1236
  def _get_content_summary(self, content: str, max_length: int = 100) -> str:
1237
  """Get summary of document content
@@ -1256,7 +1254,7 @@ class LightRAG:
1256
  """
1257
  return await self.doc_status.get_status_counts()
1258
 
1259
- async def adelete_by_doc_id(self, doc_id: str):
1260
  """Delete a document and all its related data
1261
 
1262
  Args:
@@ -1273,6 +1271,9 @@ class LightRAG:
1273
 
1274
  # 2. Get all related chunks
1275
  chunks = await self.text_chunks.get_by_id(doc_id)
 
 
 
1276
  chunk_ids = list(chunks.keys())
1277
  logger.debug(f"Found {len(chunk_ids)} chunks to delete")
1278
 
@@ -1443,13 +1444,9 @@ class LightRAG:
1443
  except Exception as e:
1444
  logger.error(f"Error while deleting document {doc_id}: {e}")
1445
 
1446
- def delete_by_doc_id(self, doc_id: str):
1447
- """Synchronous version of adelete"""
1448
- return asyncio.run(self.adelete_by_doc_id(doc_id))
1449
-
1450
  async def get_entity_info(
1451
  self, entity_name: str, include_vector_data: bool = False
1452
- ):
1453
  """Get detailed information of an entity
1454
 
1455
  Args:
@@ -1469,7 +1466,7 @@ class LightRAG:
1469
  node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
1470
  source_id = node_data.get("source_id") if node_data else None
1471
 
1472
- result = {
1473
  "entity_name": entity_name,
1474
  "source_id": source_id,
1475
  "graph_data": node_data,
@@ -1483,21 +1480,6 @@ class LightRAG:
1483
 
1484
  return result
1485
 
1486
- def get_entity_info_sync(self, entity_name: str, include_vector_data: bool = False):
1487
- """Synchronous version of getting entity information
1488
-
1489
- Args:
1490
- entity_name: Entity name (no need for quotes)
1491
- include_vector_data: Whether to include data from the vector database
1492
- """
1493
- try:
1494
- import tracemalloc
1495
-
1496
- tracemalloc.start()
1497
- return asyncio.run(self.get_entity_info(entity_name, include_vector_data))
1498
- finally:
1499
- tracemalloc.stop()
1500
-
1501
  async def get_relation_info(
1502
  self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
1503
  ):
@@ -1525,7 +1507,7 @@ class LightRAG:
1525
  )
1526
  source_id = edge_data.get("source_id") if edge_data else None
1527
 
1528
- result = {
1529
  "src_entity": src_entity,
1530
  "tgt_entity": tgt_entity,
1531
  "source_id": source_id,
@@ -1539,23 +1521,3 @@ class LightRAG:
1539
  result["vector_data"] = vector_data[0] if vector_data else None
1540
 
1541
  return result
1542
-
1543
- def get_relation_info_sync(
1544
- self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
1545
- ):
1546
- """Synchronous version of getting relationship information
1547
-
1548
- Args:
1549
- src_entity: Source entity name (no need for quotes)
1550
- tgt_entity: Target entity name (no need for quotes)
1551
- include_vector_data: Whether to include data from the vector database
1552
- """
1553
- try:
1554
- import tracemalloc
1555
-
1556
- tracemalloc.start()
1557
- return asyncio.run(
1558
- self.get_relation_info(src_entity, tgt_entity, include_vector_data)
1559
- )
1560
- finally:
1561
- tracemalloc.stop()
 
1
+ from __future__ import annotations
2
+
3
  import asyncio
4
  import os
5
  import configparser
6
  from dataclasses import asdict, dataclass, field
7
  from datetime import datetime
8
  from functools import partial
9
+ from typing import Any, AsyncIterator, Callable, Iterator, cast
10
 
11
  from .base import (
12
  BaseGraphStorage,
 
78
  "FaissVectorDBStorage",
79
  "QdrantVectorDBStorage",
80
  "OracleVectorDBStorage",
81
+ "MongoVectorDBStorage",
82
  ],
83
  "required_methods": ["query", "upsert"],
84
  },
 
94
  }
95
 
96
  # Storage implementation environment variable without default value
97
+ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
98
  # KV Storage Implementations
99
  "JsonKVStorage": [],
100
  "MongoKVStorage": [],
 
143
  "ORACLE_PASSWORD",
144
  "ORACLE_CONFIG_DIR",
145
  ],
146
+ "MongoVectorDBStorage": [],
147
  # Document Status Storage Implementations
148
  "JsonDocStatusStorage": [],
149
  "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
 
164
  "MongoKVStorage": ".kg.mongo_impl",
165
  "MongoDocStatusStorage": ".kg.mongo_impl",
166
  "MongoGraphStorage": ".kg.mongo_impl",
167
+ "MongoVectorDBStorage": ".kg.mongo_impl",
168
  "RedisKVStorage": ".kg.redis_impl",
169
  "ChromaVectorDBStorage": ".kg.chroma_impl",
170
  "TiDBKVStorage": ".kg.tidb_impl",
 
181
  }
182
 
183
 
184
+ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
185
  """Lazily import a class from an external module based on the package of the caller."""
186
  # Get the caller's module and package
187
  import inspect
 
190
  module = inspect.getmodule(caller_frame)
191
  package = module.__package__ if module else None
192
 
193
+ def import_class(*args: Any, **kwargs: Any):
194
  import importlib
195
 
196
  module = importlib.import_module(module_name, package=package)
 
307
  - random_seed: Seed value for reproducibility.
308
  """
309
 
310
+ embedding_func: EmbeddingFunc | None = None
311
  """Function for computing text embeddings. Must be set before use."""
312
 
313
  embedding_batch_num: int = 32
 
317
  """Maximum number of concurrent embedding function calls."""
318
 
319
  # LLM Configuration
320
+ llm_model_func: Callable[..., object] | None = None
321
  """Function for interacting with the large language model (LLM). Must be set before use."""
322
 
323
  llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
 
347
 
348
  # Extensions
349
  addon_params: dict[str, Any] = field(default_factory=dict)
 
350
 
351
+ """Dictionary for additional parameters and extensions."""
 
352
  convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
353
  convert_response_to_json
354
  )
 
357
  chunking_func: Callable[
358
  [
359
  str,
360
+ str | None,
361
  bool,
362
  int,
363
  int,
 
446
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
447
 
448
  # Init LLM
449
+ self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
450
  self.embedding_func
451
  )
452
 
453
  # Initialize all storages
454
+ self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
455
  self._get_storage_class(self.kv_storage)
456
+ ) # type: ignore
457
+ self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
458
  self.vector_storage
459
+ ) # type: ignore
460
+ self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
461
  self.graph_storage
462
+ ) # type: ignore
463
+ self.key_string_value_json_storage_cls = partial( # type: ignore
 
464
  self.key_string_value_json_storage_cls, global_config=global_config
465
  )
466
+ self.vector_db_storage_cls = partial( # type: ignore
 
467
  self.vector_db_storage_cls, global_config=global_config
468
  )
469
+ self.graph_storage_cls = partial( # type: ignore
 
470
  self.graph_storage_cls, global_config=global_config
471
  )
472
 
473
  # Initialize document status storage
474
  self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
475
 
476
+ self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
477
  namespace=make_namespace(
478
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
479
  ),
480
  embedding_func=self.embedding_func,
481
  )
482
 
483
+ self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
484
  namespace=make_namespace(
485
  self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
486
  ),
487
  embedding_func=self.embedding_func,
488
  )
489
+ self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
490
  namespace=make_namespace(
491
  self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
492
  ),
493
  embedding_func=self.embedding_func,
494
  )
495
+ self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
496
  namespace=make_namespace(
497
  self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
498
  ),
499
  embedding_func=self.embedding_func,
500
  )
501
 
502
+ self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
503
  namespace=make_namespace(
504
  self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
505
  ),
506
  embedding_func=self.embedding_func,
507
  meta_fields={"entity_name"},
508
  )
509
+ self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
510
  namespace=make_namespace(
511
  self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
512
  ),
513
  embedding_func=self.embedding_func,
514
  meta_fields={"src_id", "tgt_id"},
515
  )
516
+ self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
517
  namespace=make_namespace(
518
  self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
519
  ),
 
527
  embedding_func=None,
528
  )
529
 
 
530
  if self.llm_response_cache and hasattr(
531
  self.llm_response_cache, "global_config"
532
  ):
533
  hashing_kv = self.llm_response_cache
534
  else:
535
+ hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
536
  namespace=make_namespace(
537
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
538
  ),
 
541
 
542
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
543
  partial(
544
+ self.llm_model_func, # type: ignore
545
  hashing_kv=hashing_kv,
546
  **self.llm_model_kwargs,
547
  )
 
558
  node_label=nodel_label, max_depth=max_depth
559
  )
560
 
561
+ def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
562
  import_path = STORAGES[storage_name]
563
  storage_class = lazy_external_import(import_path, storage_name)
564
  return storage_class
565
566
  def insert(
567
  self,
568
+ input: str | list[str],
569
  split_by_character: str | None = None,
570
  split_by_character_only: bool = False,
571
  ):
572
  """Sync Insert documents with checkpoint support
573
 
574
  Args:
575
+ input: Single document string or list of document strings
576
  split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
 
577
  split_by_character_only: if split_by_character_only is True, split the string by character only, when
578
  split_by_character is None, this parameter is ignored.
579
  """
580
  loop = always_get_an_event_loop()
581
  return loop.run_until_complete(
582
+ self.ainsert(input, split_by_character, split_by_character_only)
583
  )
584
 
585
  async def ainsert(
586
  self,
587
+ input: str | list[str],
588
  split_by_character: str | None = None,
589
  split_by_character_only: bool = False,
590
  ):
591
  """Async Insert documents with checkpoint support
592
 
593
  Args:
594
+ input: Single document string or list of document strings
595
  split_by_character: if split_by_character is not None, split the string by character, if chunk longer than
 
596
  split_by_character_only: if split_by_character_only is True, split the string by character only, when
597
  split_by_character is None, this parameter is ignored.
598
  """
599
+ await self.apipeline_enqueue_documents(input)
600
  await self.apipeline_process_enqueue_documents(
601
  split_by_character, split_by_character_only
602
  )
 
653
  if update_storage:
654
  await self._insert_done()
655
 
656
+ async def apipeline_enqueue_documents(self, input: str | list[str]):
657
  """
658
  Pipeline for Processing Documents
659
 
 
662
  3. Filter out already processed documents
663
  4. Enqueue document in status
664
  """
665
+ if isinstance(input, str):
666
+ input = [input]
667
 
668
  # 1. Remove duplicate contents from the list
669
+ unique_contents = list(set(doc.strip() for doc in input))
670
 
671
  # 2. Generate document IDs and initial status
672
  new_docs: dict[str, Any] = {
 
833
  raise e
834
 
835
  async def _insert_done(self):
836
+ tasks = [
837
+ cast(StorageNameSpace, storage_inst).index_done_callback()
838
+ for storage_inst in [ # type: ignore
839
+ self.full_docs,
840
+ self.text_chunks,
841
+ self.llm_response_cache,
842
+ self.entities_vdb,
843
+ self.relationships_vdb,
844
+ self.chunks_vdb,
845
+ self.chunk_entity_relation_graph,
846
+ ]
847
+ if storage_inst is not None
848
+ ]
849
  await asyncio.gather(*tasks)
850
 
851
+ def insert_custom_kg(self, custom_kg: dict[str, Any]):
852
  loop = always_get_an_event_loop()
853
  return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
854
 
855
+ async def ainsert_custom_kg(self, custom_kg: dict[str, Any]):
856
  update_storage = False
857
  try:
858
  # Insert chunks into vector storage
859
+ all_chunks_data: dict[str, dict[str, str]] = {}
860
+ chunk_to_source_map: dict[str, str] = {}
861
+ for chunk_data in custom_kg.get("chunks", {}):
862
  chunk_content = chunk_data["content"]
863
  source_id = chunk_data["source_id"]
864
  chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
 
868
  chunk_to_source_map[source_id] = chunk_id
869
  update_storage = True
870
 
871
+ if all_chunks_data:
872
  await self.chunks_vdb.upsert(all_chunks_data)
873
+ if all_chunks_data:
874
  await self.text_chunks.upsert(all_chunks_data)
875
 
876
  # Insert entities into knowledge graph
877
+ all_entities_data: list[dict[str, str]] = []
878
  for entity_data in custom_kg.get("entities", []):
879
  entity_name = f'"{entity_data["entity_name"].upper()}"'
880
  entity_type = entity_data.get("entity_type", "UNKNOWN")
 
890
  )
891
 
892
  # Prepare node data
893
+ node_data: dict[str, str] = {
894
  "entity_type": entity_type,
895
  "description": description,
896
  "source_id": source_id,
 
904
  update_storage = True
905
 
906
  # Insert relationships into knowledge graph
907
+ all_relationships_data: list[dict[str, str]] = []
908
  for relationship_data in custom_kg.get("relationships", []):
909
  src_id = f'"{relationship_data["src_id"].upper()}"'
910
  tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
 
946
  "source_id": source_id,
947
  },
948
  )
949
+ edge_data: dict[str, str] = {
950
  "src_id": src_id,
951
  "tgt_id": tgt_id,
952
  "description": description,
 
956
  update_storage = True
957
 
958
  # Insert entities into vector storage if needed
959
+ data_for_vdb = {
960
+ compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
961
+ "content": dp["entity_name"] + dp["description"],
962
+ "entity_name": dp["entity_name"],
 
 
 
963
  }
964
+ for dp in all_entities_data
965
+ }
966
+ await self.entities_vdb.upsert(data_for_vdb)
967
 
968
  # Insert relationships into vector storage if needed
969
+ data_for_vdb = {
970
+ compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
971
+ "src_id": dp["src_id"],
972
+ "tgt_id": dp["tgt_id"],
973
+ "content": dp["keywords"]
974
+ + dp["src_id"]
975
+ + dp["tgt_id"]
976
+ + dp["description"],
 
 
 
977
  }
978
+ for dp in all_relationships_data
979
+ }
980
+ await self.relationships_vdb.upsert(data_for_vdb)
981
+
982
  finally:
983
  if update_storage:
984
  await self._insert_done()
985
 
986
+ def query(
987
+ self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
988
+ ) -> str | Iterator[str]:
989
+ """
990
+ Perform a sync query.
991
+
992
+ Args:
993
+ query (str): The query to be executed.
994
+ param (QueryParam): Configuration parameters for query execution.
995
+ prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
996
+
997
+ Returns:
998
+ str: The result of the query execution.
999
+ """
1000
  loop = always_get_an_event_loop()
1001
+
1002
+ return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore
1003
 
1004
  async def aquery(
1005
+ self,
1006
+ query: str,
1007
+ param: QueryParam = QueryParam(),
1008
+ prompt: str | None = None,
1009
+ ) -> str | AsyncIterator[str]:
1010
+ """
1011
+ Perform an async query.
1012
+
1013
+ Args:
1014
+ query (str): The query to be executed.
1015
+ param (QueryParam): Configuration parameters for query execution.
1016
+ prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
1017
+
1018
+ Returns:
1019
+ str: The result of the query execution.
1020
+ """
1021
  if param.mode in ["local", "global", "hybrid"]:
1022
  response = await kg_query(
1023
  query,
 
1097
 
1098
  async def aquery_with_separate_keyword_extraction(
1099
  self, query: str, prompt: str, param: QueryParam = QueryParam()
1100
+ ) -> str | AsyncIterator[str]:
1101
  """
1102
  1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
1103
  2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
 
1120
  ),
1121
  )
1122
 
1123
+ param.hl_keywords = hl_keywords
1124
+ param.ll_keywords = ll_keywords
1125
 
1126
  # ---------------------
1127
  # STEP 2: Final Query Logic
 
1149
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
1150
  ),
1151
  global_config=asdict(self),
1152
+ embedding_func=self.embedding_func,
1153
  ),
1154
  )
1155
  elif param.mode == "naive":
 
1198
  return response
1199
 
1200
  async def _query_done(self):
1201
+ await self.llm_response_cache.index_done_callback()
 
 
 
 
 
1202
 
1203
  def delete_by_entity(self, entity_name: str):
1204
  loop = always_get_an_event_loop()
 
1220
  logger.error(f"Error while deleting entity '{entity_name}': {e}")
1221
 
1222
  async def _delete_by_entity_done(self):
1223
+ await asyncio.gather(
1224
+ *[
1225
+ cast(StorageNameSpace, storage_inst).index_done_callback()
1226
+ for storage_inst in [ # type: ignore
1227
+ self.entities_vdb,
1228
+ self.relationships_vdb,
1229
+ self.chunk_entity_relation_graph,
1230
+ ]
1231
+ ]
1232
+ )
1233
 
1234
  def _get_content_summary(self, content: str, max_length: int = 100) -> str:
1235
  """Get summary of document content
 
1254
  """
1255
  return await self.doc_status.get_status_counts()
1256
 
1257
+ async def adelete_by_doc_id(self, doc_id: str) -> None:
1258
  """Delete a document and all its related data
1259
 
1260
  Args:
 
1271
 
1272
  # 2. Get all related chunks
1273
  chunks = await self.text_chunks.get_by_id(doc_id)
1274
+ if not chunks:
1275
+ return
1276
+
1277
  chunk_ids = list(chunks.keys())
1278
  logger.debug(f"Found {len(chunk_ids)} chunks to delete")
1279
 
 
1444
  except Exception as e:
1445
  logger.error(f"Error while deleting document {doc_id}: {e}")
1446
 
 
 
 
 
1447
  async def get_entity_info(
1448
  self, entity_name: str, include_vector_data: bool = False
1449
+ ) -> dict[str, str | None | dict[str, str]]:
1450
  """Get detailed information of an entity
1451
 
1452
  Args:
 
1466
  node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
1467
  source_id = node_data.get("source_id") if node_data else None
1468
 
1469
+ result: dict[str, str | None | dict[str, str]] = {
1470
  "entity_name": entity_name,
1471
  "source_id": source_id,
1472
  "graph_data": node_data,
 
1480
 
1481
  return result
1482
1483
  async def get_relation_info(
1484
  self, src_entity: str, tgt_entity: str, include_vector_data: bool = False
1485
  ):
 
1507
  )
1508
  source_id = edge_data.get("source_id") if edge_data else None
1509
 
1510
+ result: dict[str, str | None | dict[str, str]] = {
1511
  "src_entity": src_entity,
1512
  "tgt_entity": tgt_entity,
1513
  "source_id": source_id,
 
1521
  result["vector_data"] = vector_data[0] if vector_data else None
1522
 
1523
  return result
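Taken together, the `lightrag.py` changes rename the ingestion parameter from `string_or_strings` to `input`, move the optional custom prompt behind the `QueryParam` in `query()`/`aquery()`, and drop the synchronous `delete_by_doc_id`, `get_entity_info_sync` and `get_relation_info_sync` wrappers. A hedged sketch of the updated calls, assuming a `rag` instance configured in the usual way:

```python
from lightrag import QueryParam

# Ingestion: the keyword is now `input` rather than `string_or_strings`.
rag.insert(input=["First document ...", "Second document ..."])

# Querying: QueryParam comes first, the optional custom prompt last.
answer = rag.query(
    "What are the top themes in these documents?",
    param=QueryParam(mode="hybrid"),
    prompt=None,  # None falls back to the default PROMPTS["rag_response"] template
)
```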
lightrag/llm.py CHANGED
@@ -1,4 +1,6 @@
1
- from typing import List, Dict, Callable, Any
 
 
2
  from pydantic import BaseModel, Field
3
 
4
 
@@ -23,7 +25,7 @@ class Model(BaseModel):
23
  ...,
24
  description="A function that generates the response from the llm. The response must be a string",
25
  )
26
- kwargs: Dict[str, Any] = Field(
27
  ...,
28
  description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
29
  )
@@ -57,7 +59,7 @@ class MultiModel:
57
  ```
58
  """
59
 
60
- def __init__(self, models: List[Model]):
61
  self._models = models
62
  self._current_model = 0
63
 
@@ -66,7 +68,11 @@ class MultiModel:
66
  return self._models[self._current_model]
67
 
68
  async def llm_model_func(
69
- self, prompt, system_prompt=None, history_messages=[], **kwargs
70
  ) -> str:
71
  kwargs.pop("model", None) # stop from overwriting the custom model name
72
  kwargs.pop("keyword_extraction", None)
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Callable, Any
4
  from pydantic import BaseModel, Field
5
 
6
 
 
25
  ...,
26
  description="A function that generates the response from the llm. The response must be a string",
27
  )
28
+ kwargs: dict[str, Any] = Field(
29
  ...,
30
  description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
31
  )
 
59
  ```
60
  """
61
 
62
+ def __init__(self, models: list[Model]):
63
  self._models = models
64
  self._current_model = 0
65
 
 
68
  return self._models[self._current_model]
69
 
70
  async def llm_model_func(
71
+ self,
72
+ prompt: str,
73
+ system_prompt: str | None = None,
74
+ history_messages: list[dict[str, Any]] = [],
75
+ **kwargs: Any,
76
  ) -> str:
77
  kwargs.pop("model", None) # stop from overwriting the custom model name
78
  kwargs.pop("keyword_extraction", None)
lightrag/namespace.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from typing import Iterable
2
 
3
 
 
1
+ from __future__ import annotations
2
+
3
  from typing import Iterable
4
 
5
 
lightrag/operate.py CHANGED
@@ -1,8 +1,10 @@
 
 
1
  import asyncio
2
  import json
3
  import re
4
  from tqdm.asyncio import tqdm as tqdm_async
5
- from typing import Any, Union
6
  from collections import Counter, defaultdict
7
  from .utils import (
8
  logger,
@@ -36,7 +38,7 @@ import time
36
 
37
  def chunking_by_token_size(
38
  content: str,
39
- split_by_character: Union[str, None] = None,
40
  split_by_character_only: bool = False,
41
  overlap_token_size: int = 128,
42
  max_token_size: int = 1024,
@@ -237,25 +239,65 @@ async def _merge_edges_then_upsert(
237
 
238
  if await knowledge_graph_inst.has_edge(src_id, tgt_id):
239
  already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
240
- already_weights.append(already_edge["weight"])
241
- already_source_ids.extend(
242
- split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
243
- )
244
- already_description.append(already_edge["description"])
245
- already_keywords.extend(
246
- split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
247
- )
249
  weight = sum([dp["weight"] for dp in edges_data] + already_weights)
250
  description = GRAPH_FIELD_SEP.join(
251
- sorted(set([dp["description"] for dp in edges_data] + already_description))
 
 
 
 
 
252
  )
253
  keywords = GRAPH_FIELD_SEP.join(
254
- sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
 
 
 
 
 
255
  )
256
  source_id = GRAPH_FIELD_SEP.join(
257
- set([dp["source_id"] for dp in edges_data] + already_source_ids)
 
 
 
258
  )
 
259
  for need_insert_id in [src_id, tgt_id]:
260
  if not (await knowledge_graph_inst.has_node(need_insert_id)):
261
  await knowledge_graph_inst.upsert_node(
@@ -295,9 +337,9 @@ async def extract_entities(
295
  knowledge_graph_inst: BaseGraphStorage,
296
  entity_vdb: BaseVectorStorage,
297
  relationships_vdb: BaseVectorStorage,
298
- global_config: dict,
299
- llm_response_cache: BaseKVStorage = None,
300
- ) -> Union[BaseGraphStorage, None]:
301
  use_llm_func: callable = global_config["llm_model_func"]
302
  entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
303
  enable_llm_cache_for_entity_extract: bool = global_config[
@@ -563,15 +605,15 @@ async def extract_entities(
563
 
564
 
565
  async def kg_query(
566
- query,
567
  knowledge_graph_inst: BaseGraphStorage,
568
  entities_vdb: BaseVectorStorage,
569
  relationships_vdb: BaseVectorStorage,
570
  text_chunks_db: BaseKVStorage,
571
  query_param: QueryParam,
572
- global_config: dict,
573
- hashing_kv: BaseKVStorage = None,
574
- prompt: str = "",
575
  ) -> str:
576
  # Handle cache
577
  use_model_func = global_config["llm_model_func"]
@@ -684,8 +726,8 @@ async def kg_query(
684
  async def extract_keywords_only(
685
  text: str,
686
  param: QueryParam,
687
- global_config: dict,
688
- hashing_kv: BaseKVStorage = None,
689
  ) -> tuple[list[str], list[str]]:
690
  """
691
  Extract high-level and low-level keywords from the given 'text' using the LLM.
@@ -784,9 +826,9 @@ async def mix_kg_vector_query(
784
  chunks_vdb: BaseVectorStorage,
785
  text_chunks_db: BaseKVStorage,
786
  query_param: QueryParam,
787
- global_config: dict,
788
- hashing_kv: BaseKVStorage = None,
789
- ) -> str:
790
  """
791
  Hybrid retrieval implementation combining knowledge graph and vector search.
792
 
@@ -1551,13 +1593,13 @@ def combine_contexts(entities, relationships, sources):
1551
 
1552
 
1553
  async def naive_query(
1554
- query,
1555
  chunks_vdb: BaseVectorStorage,
1556
  text_chunks_db: BaseKVStorage,
1557
  query_param: QueryParam,
1558
- global_config: dict,
1559
- hashing_kv: BaseKVStorage = None,
1560
- ):
1561
  # Handle cache
1562
  use_model_func = global_config["llm_model_func"]
1563
  args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
@@ -1664,9 +1706,9 @@ async def kg_query_with_keywords(
1664
  relationships_vdb: BaseVectorStorage,
1665
  text_chunks_db: BaseKVStorage,
1666
  query_param: QueryParam,
1667
- global_config: dict,
1668
- hashing_kv: BaseKVStorage = None,
1669
- ) -> str:
1670
  """
1671
  Refactored kg_query that does NOT extract keywords by itself.
1672
  It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
 
1
+ from __future__ import annotations
2
+
3
  import asyncio
4
  import json
5
  import re
6
  from tqdm.asyncio import tqdm as tqdm_async
7
+ from typing import Any, AsyncIterator
8
  from collections import Counter, defaultdict
9
  from .utils import (
10
  logger,
 
38
 
39
  def chunking_by_token_size(
40
  content: str,
41
+ split_by_character: str | None = None,
42
  split_by_character_only: bool = False,
43
  overlap_token_size: int = 128,
44
  max_token_size: int = 1024,
 
239
 
240
  if await knowledge_graph_inst.has_edge(src_id, tgt_id):
241
  already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
242
+ # Handle the case where get_edge returns None or missing fields
243
+ if already_edge:
244
+ # Get weight with default 0.0 if missing
245
+ if "weight" in already_edge:
246
+ already_weights.append(already_edge["weight"])
247
+ else:
248
+ logger.warning(
249
+ f"Edge between {src_id} and {tgt_id} missing weight field"
250
+ )
251
+ already_weights.append(0.0)
252
+
253
+ # Get source_id with empty string default if missing or None
254
+ if "source_id" in already_edge and already_edge["source_id"] is not None:
255
+ already_source_ids.extend(
256
+ split_string_by_multi_markers(
257
+ already_edge["source_id"], [GRAPH_FIELD_SEP]
258
+ )
259
+ )
260
 
261
+ # Get description with empty string default if missing or None
262
+ if (
263
+ "description" in already_edge
264
+ and already_edge["description"] is not None
265
+ ):
266
+ already_description.append(already_edge["description"])
267
+
268
+ # Get keywords with empty string default if missing or None
269
+ if "keywords" in already_edge and already_edge["keywords"] is not None:
270
+ already_keywords.extend(
271
+ split_string_by_multi_markers(
272
+ already_edge["keywords"], [GRAPH_FIELD_SEP]
273
+ )
274
+ )
275
+
276
+ # Process edges_data with None checks
277
  weight = sum([dp["weight"] for dp in edges_data] + already_weights)
278
  description = GRAPH_FIELD_SEP.join(
279
+ sorted(
280
+ set(
281
+ [dp["description"] for dp in edges_data if dp.get("description")]
282
+ + already_description
283
+ )
284
+ )
285
  )
286
  keywords = GRAPH_FIELD_SEP.join(
287
+ sorted(
288
+ set(
289
+ [dp["keywords"] for dp in edges_data if dp.get("keywords")]
290
+ + already_keywords
291
+ )
292
+ )
293
  )
294
  source_id = GRAPH_FIELD_SEP.join(
295
+ set(
296
+ [dp["source_id"] for dp in edges_data if dp.get("source_id")]
297
+ + already_source_ids
298
+ )
299
  )
300
+
301
  for need_insert_id in [src_id, tgt_id]:
302
  if not (await knowledge_graph_inst.has_node(need_insert_id)):
303
  await knowledge_graph_inst.upsert_node(
 
337
  knowledge_graph_inst: BaseGraphStorage,
338
  entity_vdb: BaseVectorStorage,
339
  relationships_vdb: BaseVectorStorage,
340
+ global_config: dict[str, str],
341
+ llm_response_cache: BaseKVStorage | None = None,
342
+ ) -> BaseGraphStorage | None:
343
  use_llm_func: callable = global_config["llm_model_func"]
344
  entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
345
  enable_llm_cache_for_entity_extract: bool = global_config[
 
605
 
606
 
607
  async def kg_query(
608
+ query: str,
609
  knowledge_graph_inst: BaseGraphStorage,
610
  entities_vdb: BaseVectorStorage,
611
  relationships_vdb: BaseVectorStorage,
612
  text_chunks_db: BaseKVStorage,
613
  query_param: QueryParam,
614
+ global_config: dict[str, str],
615
+ hashing_kv: BaseKVStorage | None = None,
616
+ prompt: str | None = None,
617
  ) -> str:
618
  # Handle cache
619
  use_model_func = global_config["llm_model_func"]
 
726
  async def extract_keywords_only(
727
  text: str,
728
  param: QueryParam,
729
+ global_config: dict[str, str],
730
+ hashing_kv: BaseKVStorage | None = None,
731
  ) -> tuple[list[str], list[str]]:
732
  """
733
  Extract high-level and low-level keywords from the given 'text' using the LLM.
 
826
  chunks_vdb: BaseVectorStorage,
827
  text_chunks_db: BaseKVStorage,
828
  query_param: QueryParam,
829
+ global_config: dict[str, str],
830
+ hashing_kv: BaseKVStorage | None = None,
831
+ ) -> str | AsyncIterator[str]:
832
  """
833
  Hybrid retrieval implementation combining knowledge graph and vector search.
834
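
With the query entry points typed as `str | AsyncIterator[str]`, a caller has to branch on whether it received a finished answer or a stream. A minimal, hypothetical consumer (not part of the repository):

```python
from __future__ import annotations

from typing import AsyncIterator

async def print_answer(result: str | AsyncIterator[str]) -> None:
    if isinstance(result, str):        # non-streaming: the full answer is already here
        print(result)
        return
    async for chunk in result:         # streaming: drain the async iterator chunk by chunk
        print(chunk, end="", flush=True)
    print()
```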
 
 
1593
 
1594
 
1595
  async def naive_query(
1596
+ query: str,
1597
  chunks_vdb: BaseVectorStorage,
1598
  text_chunks_db: BaseKVStorage,
1599
  query_param: QueryParam,
1600
+ global_config: dict[str, str],
1601
+ hashing_kv: BaseKVStorage | None = None,
1602
+ ) -> str | AsyncIterator[str]:
1603
  # Handle cache
1604
  use_model_func = global_config["llm_model_func"]
1605
  args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
 
1706
  relationships_vdb: BaseVectorStorage,
1707
  text_chunks_db: BaseKVStorage,
1708
  query_param: QueryParam,
1709
+ global_config: dict[str, str],
1710
+ hashing_kv: BaseKVStorage | None = None,
1711
+ ) -> str | AsyncIterator[str]:
1712
  """
1713
  Refactored kg_query that does NOT extract keywords by itself.
1714
  It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
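
Keyword extraction and retrieval are therefore two separate calls: `extract_keywords_only` produces the high-level and low-level keyword lists (with its own cache entry), and `kg_query_with_keywords` consumes them via `query_param`. A hedged usage sketch, assuming both functions live in `lightrag.operate`, that `QueryParam` is importable from `lightrag`, and that the leading parameters of `kg_query_with_keywords` (the query plus the graph/vector storages) mirror those shown for `kg_query` above; the `storages` object is a placeholder:

```python
from lightrag import QueryParam  # assumed import path
from lightrag.operate import extract_keywords_only, kg_query_with_keywords

async def query_with_precomputed_keywords(query: str, storages, global_config: dict):
    param = QueryParam(mode="hybrid")
    # 1) One LLM round-trip for keywords, cached under its own args hash.
    hl_keywords, ll_keywords = await extract_keywords_only(
        query, param, global_config, hashing_kv=storages.llm_response_cache
    )
    # 2) Hand the keywords to the query path, which no longer extracts them itself.
    param.hl_keywords, param.ll_keywords = hl_keywords, ll_keywords
    return await kg_query_with_keywords(
        query,
        storages.knowledge_graph_inst,
        storages.entities_vdb,
        storages.relationships_vdb,
        storages.text_chunks_db,
        param,
        global_config,
        hashing_kv=storages.llm_response_cache,
    )
```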
lightrag/prompt.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  GRAPH_FIELD_SEP = "<SEP>"
2
 
3
  PROMPTS = {}
 
1
+ from __future__ import annotations
2
+
3
  GRAPH_FIELD_SEP = "<SEP>"
4
 
5
  PROMPTS = {}
lightrag/types.py CHANGED
@@ -1,26 +1,28 @@
 
 
1
  from pydantic import BaseModel
2
- from typing import List, Dict, Any
3
 
4
 
5
  class GPTKeywordExtractionFormat(BaseModel):
6
- high_level_keywords: List[str]
7
- low_level_keywords: List[str]
8
 
9
 
10
  class KnowledgeGraphNode(BaseModel):
11
  id: str
12
- labels: List[str]
13
- properties: Dict[str, Any] # anything else goes here
14
 
15
 
16
  class KnowledgeGraphEdge(BaseModel):
17
  id: str
18
- type: str
19
  source: str # id of source node
20
  target: str # id of target node
21
- properties: Dict[str, Any] # anything else goes here
22
 
23
 
24
  class KnowledgeGraph(BaseModel):
25
- nodes: List[KnowledgeGraphNode] = []
26
- edges: List[KnowledgeGraphEdge] = []
 
1
+ from __future__ import annotations
2
+
3
  from pydantic import BaseModel
4
+ from typing import Any, Optional
5
 
6
 
7
  class GPTKeywordExtractionFormat(BaseModel):
8
+ high_level_keywords: list[str]
9
+ low_level_keywords: list[str]
10
 
11
 
12
  class KnowledgeGraphNode(BaseModel):
13
  id: str
14
+ labels: list[str]
15
+ properties: dict[str, Any] # anything else goes here
16
 
17
 
18
  class KnowledgeGraphEdge(BaseModel):
19
  id: str
20
+ type: Optional[str]
21
  source: str # id of source node
22
  target: str # id of target node
23
+ properties: dict[str, Any] # anything else goes here
24
 
25
 
26
  class KnowledgeGraph(BaseModel):
27
+ nodes: list[KnowledgeGraphNode] = []
28
+ edges: list[KnowledgeGraphEdge] = []
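
Because `KnowledgeGraphEdge.type` is now `Optional[str]` and the containers use built-in generics, an untyped relationship validates without error. A small usage sketch (values are made up):

```python
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge

nodes = [
    KnowledgeGraphNode(id="n1", labels=["Person"], properties={"name": "Ada"}),
    KnowledgeGraphNode(id="n2", labels=["Company"], properties={"name": "Acme"}),
]
edges = [
    # `type` is optional now, so a relationship without a type validates cleanly.
    KnowledgeGraphEdge(id="e1", type=None, source="n1", target="n2", properties={}),
]
kg = KnowledgeGraph(nodes=nodes, edges=edges)
print(kg.model_dump())  # pydantic v2; use .dict() on pydantic v1
```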
lightrag/utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import asyncio
2
  import html
3
  import io
@@ -9,7 +11,7 @@ import re
9
  from dataclasses import dataclass
10
  from functools import wraps
11
  from hashlib import md5
12
- from typing import Any, Union, List, Optional
13
  import xml.etree.ElementTree as ET
14
  import bs4
15
 
@@ -67,12 +69,12 @@ class EmbeddingFunc:
67
 
68
  @dataclass
69
  class ReasoningResponse:
70
- reasoning_content: str
71
  response_content: str
72
  tag: str
73
 
74
 
75
- def locate_json_string_body_from_string(content: str) -> Union[str, None]:
76
  """Locate the JSON string body from a string"""
77
  try:
78
  maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -109,7 +111,7 @@ def convert_response_to_json(response: str) -> dict[str, Any]:
109
  raise e from None
110
 
111
 
112
- def compute_args_hash(*args, cache_type: str = None) -> str:
113
  """Compute a hash for the given arguments.
114
  Args:
115
  *args: Arguments to hash
@@ -128,7 +130,12 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
128
  return hashlib.md5(args_str.encode()).hexdigest()
129
 
130
 
131
- def compute_mdhash_id(content, prefix: str = ""):
 
 
 
 
 
132
  return prefix + md5(content.encode()).hexdigest()
133
 
134
 
@@ -215,11 +222,13 @@ def clean_str(input: Any) -> str:
215
  return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
216
 
217
 
218
- def is_float_regex(value):
219
  return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
220
 
221
 
222
- def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
 
 
223
  """Truncate a list of data by token size"""
224
  if max_token_size <= 0:
225
  return []
@@ -231,7 +240,7 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
231
  return list_data
232
 
233
 
234
- def list_of_list_to_csv(data: List[List[str]]) -> str:
235
  output = io.StringIO()
236
  writer = csv.writer(
237
  output,
@@ -244,7 +253,7 @@ def list_of_list_to_csv(data: List[List[str]]) -> str:
244
  return output.getvalue()
245
 
246
 
247
- def csv_string_to_list(csv_string: str) -> List[List[str]]:
248
  # Clean the string by removing NUL characters
249
  cleaned_string = csv_string.replace("\0", "")
250
 
@@ -329,7 +338,7 @@ def xml_to_json(xml_file):
329
  return None
330
 
331
 
332
- def process_combine_contexts(hl, ll):
333
  header = None
334
  list_hl = csv_string_to_list(hl.strip())
335
  list_ll = csv_string_to_list(ll.strip())
@@ -375,7 +384,7 @@ async def get_best_cached_response(
375
  llm_func=None,
376
  original_prompt=None,
377
  cache_type=None,
378
- ) -> Union[str, None]:
379
  logger.debug(
380
  f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
381
  )
@@ -479,7 +488,7 @@ def cosine_similarity(v1, v2):
479
  return dot_product / (norm1 * norm2)
480
 
481
 
482
- def quantize_embedding(embedding: Union[np.ndarray, list], bits=8) -> tuple:
483
  """Quantize embedding to specified bits"""
484
  # Convert list to numpy array if needed
485
  if isinstance(embedding, list):
@@ -570,9 +579,9 @@ class CacheData:
570
  args_hash: str
571
  content: str
572
  prompt: str
573
- quantized: Optional[np.ndarray] = None
574
- min_val: Optional[float] = None
575
- max_val: Optional[float] = None
576
  mode: str = "default"
577
  cache_type: str = "query"
578
 
@@ -635,7 +644,9 @@ def exists_func(obj, func_name: str) -> bool:
635
  return False
636
 
637
 
638
- def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> str:
 
 
639
  """
640
  Process conversation history to get the specified number of complete turns.
641
 
@@ -647,8 +658,8 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
647
  Formatted string of the conversation history
648
  """
649
  # Group messages into turns
650
- turns = []
651
- messages = []
652
 
653
  # First, filter out keyword extraction messages
654
  for msg in conversation_history:
@@ -682,7 +693,7 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
682
  turns = turns[-num_turns:]
683
 
684
  # Format the turns into a string
685
- formatted_turns = []
686
  for turn in turns:
687
  formatted_turns.extend(
688
  [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
 
1
+ from __future__ import annotations
2
+
3
  import asyncio
4
  import html
5
  import io
 
11
  from dataclasses import dataclass
12
  from functools import wraps
13
  from hashlib import md5
14
+ from typing import Any, Callable
15
  import xml.etree.ElementTree as ET
16
  import bs4
17
 
 
69
 
70
  @dataclass
71
  class ReasoningResponse:
72
+ reasoning_content: str | None
73
  response_content: str
74
  tag: str
75
 
76
 
77
+ def locate_json_string_body_from_string(content: str) -> str | None:
78
  """Locate the JSON string body from a string"""
79
  try:
80
  maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
 
111
  raise e from None
112
 
113
 
114
+ def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
115
  """Compute a hash for the given arguments.
116
  Args:
117
  *args: Arguments to hash
 
130
  return hashlib.md5(args_str.encode()).hexdigest()
131
 
132
 
133
+ def compute_mdhash_id(content: str, prefix: str = "") -> str:
134
+ """
135
+ Compute a unique ID for a given content string.
136
+
137
+ The ID is a combination of the given prefix and the MD5 hash of the content string.
138
+ """
139
  return prefix + md5(content.encode()).hexdigest()
140
 
141
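
The helper is deterministic: the same content and prefix always yield the same ID, which is what makes it suitable for idempotent upserts. A quick usage sketch (the prefixes are illustrative):

```python
from lightrag.utils import compute_mdhash_id

chunk_id = compute_mdhash_id("Ada Lovelace wrote the first program.", prefix="chunk-")
entity_id = compute_mdhash_id("Ada Lovelace", prefix="ent-")
print(chunk_id)  # "chunk-" followed by the 32-character hex MD5 digest of the content
```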
 
 
222
  return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
223
 
224
 
225
+ def is_float_regex(value: str) -> bool:
226
  return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
227
 
228
 
229
+ def truncate_list_by_token_size(
230
+ list_data: list[Any], key: Callable[[Any], str], max_token_size: int
231
+ ) -> list[Any]:
232
  """Truncate a list of data by token size"""
233
  if max_token_size <= 0:
234
  return []
 
240
  return list_data
241
 
242
 
243
+ def list_of_list_to_csv(data: list[list[str]]) -> str:
244
  output = io.StringIO()
245
  writer = csv.writer(
246
  output,
 
253
  return output.getvalue()
254
 
255
 
256
+ def csv_string_to_list(csv_string: str) -> list[list[str]]:
257
  # Clean the string by removing NUL characters
258
  cleaned_string = csv_string.replace("\0", "")
259
 
 
338
  return None
339
 
340
 
341
+ def process_combine_contexts(hl: str, ll: str):
342
  header = None
343
  list_hl = csv_string_to_list(hl.strip())
344
  list_ll = csv_string_to_list(ll.strip())
 
384
  llm_func=None,
385
  original_prompt=None,
386
  cache_type=None,
387
+ ) -> str | None:
388
  logger.debug(
389
  f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
390
  )
 
488
  return dot_product / (norm1 * norm2)
489
 
490
 
491
+ def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
492
  """Quantize embedding to specified bits"""
493
  # Convert list to numpy array if needed
494
  if isinstance(embedding, list):
 
579
  args_hash: str
580
  content: str
581
  prompt: str
582
+ quantized: np.ndarray | None = None
583
+ min_val: float | None = None
584
+ max_val: float | None = None
585
  mode: str = "default"
586
  cache_type: str = "query"
587
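
`CacheData` keeps `quantized`, `min_val` and `max_val` together, which is the usual shape of min-max quantization: the float embedding is rescaled into `2**bits` integer levels, and the two floats are retained so it can be approximately reconstructed for similarity checks later. A sketch of that idea, not the repository's actual `quantize_embedding`:

```python
import numpy as np

def quantize_embedding_sketch(embedding: "np.ndarray | list[float]", bits: int = 8) -> tuple:
    """Min-max quantization sketch matching CacheData's (quantized, min_val, max_val) fields."""
    emb = np.asarray(embedding, dtype=np.float32)
    min_val, max_val = float(emb.min()), float(emb.max())
    span = max_val - min_val
    if span == 0.0:
        span = 1.0  # avoid division by zero for constant vectors
    levels = 2**bits - 1
    quantized = np.round((emb - min_val) / span * levels).astype(np.uint8)  # uint8 assumes bits <= 8
    return quantized, min_val, max_val

def dequantize_sketch(quantized: np.ndarray, min_val: float, max_val: float, bits: int = 8) -> np.ndarray:
    return quantized.astype(np.float32) / (2**bits - 1) * (max_val - min_val) + min_val
```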
 
 
644
  return False
645
 
646
 
647
+ def get_conversation_turns(
648
+ conversation_history: list[dict[str, Any]], num_turns: int
649
+ ) -> str:
650
  """
651
  Process conversation history to get the specified number of complete turns.
652
 
 
658
  Formatted string of the conversation history
659
  """
660
  # Group messages into turns
661
+ turns: list[list[dict[str, Any]]] = []
662
+ messages: list[dict[str, Any]] = []
663
 
664
  # First, filter out keyword extraction messages
665
  for msg in conversation_history:
 
693
  turns = turns[-num_turns:]
694
 
695
  # Format the turns into a string
696
+ formatted_turns: list[str] = []
697
  for turn in turns:
698
  formatted_turns.extend(
699
  [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
lightrag_webui/src/components/PropertiesView.tsx CHANGED
@@ -200,7 +200,7 @@ const EdgePropertiesView = ({ edge }: { edge: EdgeType }) => {
200
  <label className="text-md pl-1 font-bold tracking-wide text-teal-600">Relationship</label>
201
  <div className="bg-primary/5 max-h-96 overflow-auto rounded p-1">
202
  <PropertyRow name={'Id'} value={edge.id} />
203
- <PropertyRow name={'Type'} value={edge.type} />
204
  <PropertyRow
205
  name={'Source'}
206
  value={edge.sourceNode ? edge.sourceNode.labels.join(', ') : edge.source}
 
200
  <label className="text-md pl-1 font-bold tracking-wide text-teal-600">Relationship</label>
201
  <div className="bg-primary/5 max-h-96 overflow-auto rounded p-1">
202
  <PropertyRow name={'Id'} value={edge.id} />
203
+ {edge.type && <PropertyRow name={'Type'} value={edge.type} />}
204
  <PropertyRow
205
  name={'Source'}
206
  value={edge.sourceNode ? edge.sourceNode.labels.join(', ') : edge.source}
lightrag_webui/src/hooks/useLightragGraph.tsx CHANGED
@@ -24,7 +24,7 @@ const validateGraph = (graph: RawGraph) => {
24
  }
25
 
26
  for (const edge of graph.edges) {
27
- if (!edge.id || !edge.source || !edge.target || !edge.type || !edge.properties) {
28
  return false
29
  }
30
  }
@@ -88,6 +88,14 @@ const fetchGraph = async (label: string) => {
88
  if (source !== undefined && target !== undefined) {
89
  const sourceNode = rawData.nodes[source]
90
  const targetNode = rawData.nodes[target]
 
 
 
 
 
 
 
 
91
  sourceNode.degree += 1
92
  targetNode.degree += 1
93
  }
@@ -146,7 +154,7 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => {
146
 
147
  for (const rawEdge of rawGraph?.edges ?? []) {
148
  rawEdge.dynamicId = graph.addDirectedEdge(rawEdge.source, rawEdge.target, {
149
- label: rawEdge.type
150
  })
151
  }
152
 
 
24
  }
25
 
26
  for (const edge of graph.edges) {
27
+ if (!edge.id || !edge.source || !edge.target) {
28
  return false
29
  }
30
  }
 
88
  if (source !== undefined && target !== undefined) {
89
  const sourceNode = rawData.nodes[source]
90
  const targetNode = rawData.nodes[target]
91
+ if (!sourceNode) {
92
+ console.error(`Source node ${edge.source} is undefined`)
93
+ continue
94
+ }
95
+ if (!targetNode) {
96
+ console.error(`Target node ${edge.target} is undefined`)
97
+ continue
98
+ }
99
  sourceNode.degree += 1
100
  targetNode.degree += 1
101
  }
 
154
 
155
  for (const rawEdge of rawGraph?.edges ?? []) {
156
  rawEdge.dynamicId = graph.addDirectedEdge(rawEdge.source, rawEdge.target, {
157
+ label: rawEdge.type || undefined
158
  })
159
  }
160
 
lightrag_webui/src/stores/graph.ts CHANGED
@@ -19,7 +19,7 @@ export type RawEdgeType = {
19
  id: string
20
  source: string
21
  target: string
22
- type: string
23
  properties: Record<string, any>
24
 
25
  dynamicId: string
 
19
  id: string
20
  source: string
21
  target: string
22
+ type?: string
23
  properties: Record<string, any>
24
 
25
  dynamicId: string