gzdaniel committed on
Commit 7b4414a
2 parents: f951108 05a9e1a

Merge branch 'main' into rerank
README-zh.md CHANGED
@@ -30,7 +30,7 @@
  <a href="https://github.com/HKUDS/LightRAG/issues/285"><img src="https://img.shields.io/badge/💬微信群-交流-07c160?style=for-the-badge&logo=wechat&logoColor=white&labelColor=1a1a2e"></a>
  </p>
  <p>
- <a href="README_zh.md"><img src="https://img.shields.io/badge/🇨🇳中文版-1a1a2e?style=for-the-badge"></a>
+ <a href="README-zh.md"><img src="https://img.shields.io/badge/🇨🇳中文版-1a1a2e?style=for-the-badge"></a>
  <a href="README.md"><img src="https://img.shields.io/badge/🇺🇸English-1a1a2e?style=for-the-badge"></a>
  </p>
  </div>
README.md CHANGED
@@ -870,6 +870,41 @@ rag = LightRAG(

  </details>

+ <details>
+ <summary> <b>Using Memgraph for Storage</b> </summary>
+
+ * Memgraph is a high-performance, in-memory graph database compatible with the Neo4j Bolt protocol.
+ * You can run Memgraph locally with Docker for easy testing; see https://memgraph.com/download
+
+ ```python
+ # Set the connection URI before launching, e.g.:
+ #   export MEMGRAPH_URI="bolt://localhost:7687"
+
+ # Set up the logger for LightRAG
+ setup_logger("lightrag", level="INFO")
+
+ # When you launch the project, override the default KG (NetworkX)
+ # by specifying graph_storage="MemgraphStorage".
+
+ # Note: Default settings use NetworkX
+ # Initialize LightRAG with the Memgraph implementation.
+ async def initialize_rag():
+     rag = LightRAG(
+         working_dir=WORKING_DIR,
+         llm_model_func=gpt_4o_mini_complete,  # Use the gpt_4o_mini_complete LLM model
+         graph_storage="MemgraphStorage",  # <---------- override the KG default
+     )
+
+     # Initialize database connections
+     await rag.initialize_storages()
+     # Initialize pipeline status for document processing
+     await initialize_pipeline_status()
+
+     return rag
+ ```
+
+ </details>
+
  ## Edit Entities and Relations

  LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph.
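To see the new storage option end to end, here is a minimal usage sketch built on the snippet above. It assumes the imports used by the other LightRAG README examples (`LightRAG`, `QueryParam`, `gpt_4o_mini_complete`, `setup_logger`, `initialize_pipeline_status`) and a Memgraph instance reachable at the default URI; treat it as an illustration, not repo code.

```python
import asyncio
import os

from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.utils import setup_logger

WORKING_DIR = "./memgraph_demo"
os.makedirs(WORKING_DIR, exist_ok=True)
os.environ.setdefault("MEMGRAPH_URI", "bolt://localhost:7687")


async def main():
    setup_logger("lightrag", level="INFO")
    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=gpt_4o_mini_complete,
        graph_storage="MemgraphStorage",  # route the knowledge graph to Memgraph
    )
    await rag.initialize_storages()
    await initialize_pipeline_status()

    # Index a document, then query it; the extracted graph lands in Memgraph
    await rag.ainsert("Memgraph speaks the Bolt protocol, so LightRAG reuses its Neo4j driver.")
    print(await rag.aquery("What protocol does Memgraph speak?", param=QueryParam(mode="hybrid")))


if __name__ == "__main__":
    asyncio.run(main())
```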
config.ini.example CHANGED
@@ -21,3 +21,6 @@ password = your_password
  database = your_database
  workspace = default  # optional, defaults to "default"
  max_connections = 12
+
+ [memgraph]
+ uri = bolt://localhost:7687
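The new `[memgraph]` block mirrors how `MemgraphStorage.initialize()` (added below) resolves settings: environment variables win, `config.ini` is the fallback. A small sketch of that precedence; the snippet itself is illustrative:

```python
import configparser
import os

config = configparser.ConfigParser()
config.read("config.ini", encoding="utf-8")

# Environment variable first, then the [memgraph] section, then a hard default
uri = os.environ.get(
    "MEMGRAPH_URI",
    config.get("memgraph", "uri", fallback="bolt://localhost:7687"),
)
print(uri)
```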
env.example CHANGED
@@ -147,13 +147,14 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
  # LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage
  ### Graph Storage (Recommended for production deployment)
  # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
+ # LIGHTRAG_GRAPH_STORAGE=MemgraphStorage

  ####################################################################
  ### Default workspace for all storage types
  ### For the purpose of isolation of data for each LightRAG instance
  ### Valid characters: a-z, A-Z, 0-9, and _
  ####################################################################
- # WORKSPACE=doc—
+ # WORKSPACE=space1

  ### PostgreSQL Configuration
  POSTGRES_HOST=localhost
@@ -192,3 +193,10 @@ QDRANT_URL=http://localhost:6333
  ### Redis
  REDIS_URI=redis://localhost:6379
  # REDIS_WORKSPACE=forced_workspace_name
+
+ ### Memgraph Configuration
+ MEMGRAPH_URI=bolt://localhost:7687
+ MEMGRAPH_USERNAME=
+ MEMGRAPH_PASSWORD=
+ MEMGRAPH_DATABASE=memgraph
+ # MEMGRAPH_WORKSPACE=forced_workspace_name
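Because Memgraph speaks the Neo4j Bolt protocol, these variables can be smoke-tested with the plain `neo4j` driver before starting LightRAG. A quick connectivity check, illustrative and not part of the repo:

```python
import os

from neo4j import GraphDatabase  # Memgraph is Bolt-compatible, so the Neo4j driver works

uri = os.environ.get("MEMGRAPH_URI", "bolt://localhost:7687")
auth = (
    os.environ.get("MEMGRAPH_USERNAME", ""),
    os.environ.get("MEMGRAPH_PASSWORD", ""),
)

with GraphDatabase.driver(uri, auth=auth) as driver:
    driver.verify_connectivity()  # raises if Memgraph is unreachable
    print("Memgraph is reachable at", uri)
```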
lightrag/kg/__init__.py CHANGED
@@ -15,6 +15,7 @@ STORAGE_IMPLEMENTATIONS = {
         "Neo4JStorage",
         "PGGraphStorage",
         "MongoGraphStorage",
+        "MemgraphStorage",
         # "AGEStorage",
         # "TiDBGraphStorage",
         # "GremlinStorage",
@@ -57,6 +58,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
     "NetworkXStorage": [],
     "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
     "MongoGraphStorage": [],
+    "MemgraphStorage": ["MEMGRAPH_URI"],
     # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
     "AGEStorage": [
         "AGE_POSTGRES_DB",
@@ -111,6 +113,7 @@ STORAGES = {
     "PGDocStatusStorage": ".kg.postgres_impl",
     "FaissVectorDBStorage": ".kg.faiss_impl",
     "QdrantVectorDBStorage": ".kg.qdrant_impl",
+    "MemgraphStorage": ".kg.memgraph_impl",
 }
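The three dictionaries work together as a registry: `STORAGE_IMPLEMENTATIONS` lists what a user may select, `STORAGE_ENV_REQUIREMENTS` names the variables that must be set for each backend, and `STORAGES` maps each name to the module that provides it. A sketch of how such a registry is typically consumed; the real loader in LightRAG may differ in its details:

```python
import importlib
import os

# Hypothetical excerpt of the registry shapes defined in lightrag/kg/__init__.py
STORAGE_ENV_REQUIREMENTS = {"MemgraphStorage": ["MEMGRAPH_URI"]}
STORAGES = {"MemgraphStorage": ".kg.memgraph_impl"}


def load_storage_class(name: str, package: str = "lightrag"):
    # Fail early if required configuration is missing
    missing = [var for var in STORAGE_ENV_REQUIREMENTS.get(name, []) if not os.getenv(var)]
    if missing:
        raise ValueError(f"{name} requires environment variables: {missing}")
    # Import the implementation module only when the backend is actually requested
    module = importlib.import_module(STORAGES[name], package=package)
    return getattr(module, name)
```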
lightrag/kg/memgraph_impl.py ADDED
@@ -0,0 +1,906 @@
+ import os
+ from dataclasses import dataclass
+ from typing import final
+ import configparser
+
+ from ..utils import logger
+ from ..base import BaseGraphStorage
+ from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
+ from ..constants import GRAPH_FIELD_SEP
+ import pipmaster as pm
+
+ # Memgraph is Bolt-compatible, so the async Neo4j driver is reused here
+ if not pm.is_installed("neo4j"):
+     pm.install("neo4j")
+
+ from neo4j import (
+     AsyncGraphDatabase,
+     AsyncManagedTransaction,
+ )
+
+ from dotenv import load_dotenv
+
+ # use the .env that is inside the current folder
+ load_dotenv(dotenv_path=".env", override=False)
+
+ MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
+
+ config = configparser.ConfigParser()
+ config.read("config.ini", "utf-8")
+
+
+ @final
+ @dataclass
+ class MemgraphStorage(BaseGraphStorage):
+     def __init__(self, namespace, global_config, embedding_func, workspace=None):
+         # An explicit MEMGRAPH_WORKSPACE environment variable overrides the
+         # workspace passed in by the caller
+         memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE")
+         if memgraph_workspace and memgraph_workspace.strip():
+             workspace = memgraph_workspace
+         super().__init__(
+             namespace=namespace,
+             workspace=workspace or "",
+             global_config=global_config,
+             embedding_func=embedding_func,
+         )
+         self._driver = None
+
+     def _get_workspace_label(self) -> str:
+         """Get the workspace label; return 'base' for compatibility when the workspace is empty"""
+         workspace = getattr(self, "workspace", None)
+         return workspace if workspace else "base"
+
+     async def initialize(self):
+         # Environment variables take precedence over config.ini values
+         URI = os.environ.get(
+             "MEMGRAPH_URI",
+             config.get("memgraph", "uri", fallback="bolt://localhost:7687"),
+         )
+         USERNAME = os.environ.get(
+             "MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")
+         )
+         PASSWORD = os.environ.get(
+             "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")
+         )
+         DATABASE = os.environ.get(
+             "MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph")
+         )
+
+         self._driver = AsyncGraphDatabase.driver(
+             URI,
+             auth=(USERNAME, PASSWORD),
+         )
+         self._DATABASE = DATABASE
+         workspace_label = self._get_workspace_label()
+         try:
+             async with self._driver.session(database=DATABASE) as session:
+                 # Create an index for base nodes on entity_id if it doesn't exist
+                 try:
+                     await session.run(
+                         f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
+                     )
+                     logger.info(
+                         f"Created index on :{workspace_label}(entity_id) in Memgraph."
+                     )
+                 except Exception as e:
+                     # The index may already exist, which is not an error
+                     logger.warning(
+                         f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
+                     )
+                 await session.run("RETURN 1")
+                 logger.info(f"Connected to Memgraph at {URI}")
+         except Exception as e:
+             logger.error(f"Failed to connect to Memgraph at {URI}: {e}")
+             raise
+
+     async def finalize(self):
+         if self._driver is not None:
+             await self._driver.close()
+             self._driver = None
+
+     async def __aexit__(self, exc_type, exc, tb):
+         await self.finalize()
+
+     async def index_done_callback(self):
+         # Memgraph handles persistence automatically
+         pass
+
+     async def has_node(self, node_id: str) -> bool:
+         """
+         Check if a node exists in the graph.
+
+         Args:
+             node_id: The ID of the node to check.
+
+         Returns:
+             bool: True if the node exists, False otherwise.
+
+         Raises:
+             Exception: If there is an error checking the node existence.
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+         async with self._driver.session(
+             database=self._DATABASE, default_access_mode="READ"
+         ) as session:
+             result = None
+             try:
+                 workspace_label = self._get_workspace_label()
+                 query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
+                 result = await session.run(query, entity_id=node_id)
+                 single_result = await result.single()
+                 await result.consume()  # Ensure result is fully consumed
+                 return (
+                     single_result["node_exists"] if single_result is not None else False
+                 )
+             except Exception as e:
+                 logger.error(f"Error checking node existence for {node_id}: {str(e)}")
+                 if result is not None:
+                     await result.consume()  # Ensure the result is consumed even on error
+                 raise
+
+     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
+         """
+         Check if an edge exists between two nodes in the graph.
+
+         Args:
+             source_node_id: The ID of the source node.
+             target_node_id: The ID of the target node.
+
+         Returns:
+             bool: True if the edge exists, False otherwise.
+
+         Raises:
+             Exception: If there is an error checking the edge existence.
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+         async with self._driver.session(
+             database=self._DATABASE, default_access_mode="READ"
+         ) as session:
+             result = None
+             try:
+                 workspace_label = self._get_workspace_label()
+                 query = (
+                     f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
+                     "RETURN COUNT(r) > 0 AS edgeExists"
+                 )
+                 result = await session.run(
+                     query,
+                     source_entity_id=source_node_id,
+                     target_entity_id=target_node_id,
+                 )  # type: ignore
+                 single_result = await result.single()
+                 await result.consume()  # Ensure result is fully consumed
+                 return (
+                     single_result["edgeExists"] if single_result is not None else False
+                 )
+             except Exception as e:
+                 logger.error(
+                     f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
+                 )
+                 if result is not None:
+                     await result.consume()  # Ensure the result is consumed even on error
+                 raise
+
+     async def get_node(self, node_id: str) -> dict[str, str] | None:
+         """Get node by its label identifier, return only node properties
+
+         Args:
+             node_id: The node label to look up
+
+         Returns:
+             dict: Node properties if found
+             None: If node not found
+
+         Raises:
+             Exception: If there is an error executing the query
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+         async with self._driver.session(
+             database=self._DATABASE, default_access_mode="READ"
+         ) as session:
+             try:
+                 workspace_label = self._get_workspace_label()
+                 query = (
+                     f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
+                 )
+                 result = await session.run(query, entity_id=node_id)
+                 try:
+                     records = await result.fetch(2)  # Fetch 2 records to detect duplicates
+
+                     if len(records) > 1:
+                         logger.warning(
+                             f"Multiple nodes found with label '{node_id}'. Using first node."
+                         )
+                     if records:
+                         node = records[0]["n"]
+                         node_dict = dict(node)
+                         # Remove the workspace label from the labels list if it exists
+                         if "labels" in node_dict:
+                             node_dict["labels"] = [
+                                 label
+                                 for label in node_dict["labels"]
+                                 if label != workspace_label
+                             ]
+                         return node_dict
+                     return None
+                 finally:
+                     await result.consume()  # Ensure result is fully consumed
+             except Exception as e:
+                 logger.error(f"Error getting node for {node_id}: {str(e)}")
+                 raise
+
+     async def node_degree(self, node_id: str) -> int:
+         """Get the degree (number of relationships) of a node with the given label.
+         If multiple nodes have the same label, returns the degree of the first node.
+         If no node is found, returns 0.
+
+         Args:
+             node_id: The label of the node
+
+         Returns:
+             int: The number of relationships the node has, or 0 if no node found
+
+         Raises:
+             Exception: If there is an error executing the query
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+         async with self._driver.session(
+             database=self._DATABASE, default_access_mode="READ"
+         ) as session:
+             try:
+                 workspace_label = self._get_workspace_label()
+                 query = f"""
+                     MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
+                     OPTIONAL MATCH (n)-[r]-()
+                     RETURN COUNT(r) AS degree
+                 """
+                 result = await session.run(query, entity_id=node_id)
+                 try:
+                     record = await result.single()
+
+                     if not record:
+                         logger.warning(f"No node found with label '{node_id}'")
+                         return 0
+
+                     degree = record["degree"]
+                     return degree
+                 finally:
+                     await result.consume()  # Ensure result is fully consumed
+             except Exception as e:
+                 logger.error(f"Error getting node degree for {node_id}: {str(e)}")
+                 raise
+
+     async def get_all_labels(self) -> list[str]:
+         """
+         Get all existing node labels in the database
+
+         Returns:
+             ["Person", "Company", ...]  # Alphabetically sorted label list
+
+         Raises:
+             Exception: If there is an error executing the query
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+         async with self._driver.session(
+             database=self._DATABASE, default_access_mode="READ"
+         ) as session:
+             result = None
+             try:
+                 workspace_label = self._get_workspace_label()
+                 query = f"""
+                     MATCH (n:`{workspace_label}`)
+                     WHERE n.entity_id IS NOT NULL
+                     RETURN DISTINCT n.entity_id AS label
+                     ORDER BY label
+                 """
+                 result = await session.run(query)
+                 labels = []
+                 async for record in result:
+                     labels.append(record["label"])
+                 await result.consume()
+                 return labels
+             except Exception as e:
+                 logger.error(f"Error getting all labels: {str(e)}")
+                 if result is not None:
+                     await result.consume()  # Ensure the result is consumed even on error
+                 raise
+
+     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
+         """Retrieves all edges (relationships) for a particular node identified by its label.
+
+         Args:
+             source_node_id: Label of the node to get edges for
+
+         Returns:
+             list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges,
+                 empty if the node has no edges
+
+         Raises:
+             Exception: If there is an error executing the query
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+         try:
+             async with self._driver.session(
+                 database=self._DATABASE, default_access_mode="READ"
+             ) as session:
+                 results = None
+                 try:
+                     workspace_label = self._get_workspace_label()
+                     query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
+                             OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`)
+                             WHERE connected.entity_id IS NOT NULL
+                             RETURN n, r, connected"""
+                     results = await session.run(query, entity_id=source_node_id)
+
+                     edges = []
+                     async for record in results:
+                         source_node = record["n"]
+                         connected_node = record["connected"]
+
+                         # Skip if either node is None
+                         if not source_node or not connected_node:
+                             continue
+
+                         source_label = source_node.get("entity_id") or None
+                         target_label = connected_node.get("entity_id") or None
+
+                         if source_label and target_label:
+                             edges.append((source_label, target_label))
+
+                     await results.consume()  # Ensure results are consumed
+                     return edges
+                 except Exception as e:
+                     logger.error(
+                         f"Error getting edges for node {source_node_id}: {str(e)}"
+                     )
+                     if results is not None:
+                         await results.consume()  # Ensure results are consumed even on error
+                     raise
+         except Exception as e:
+             logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
+             raise
+
+     async def get_edge(
+         self, source_node_id: str, target_node_id: str
+     ) -> dict[str, str] | None:
+         """Get edge properties between two nodes.
+
+         Args:
+             source_node_id: Label of the source node
+             target_node_id: Label of the target node
+
+         Returns:
+             dict: Edge properties if found, with defaults filled in for missing keys
+             None: If no edge is found
+
+         Raises:
+             Exception: If there is an error executing the query
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+         async with self._driver.session(
+             database=self._DATABASE, default_access_mode="READ"
+         ) as session:
+             result = None
+             try:
+                 workspace_label = self._get_workspace_label()
+                 query = f"""
+                     MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}})
+                     RETURN properties(r) as edge_properties
+                 """
+                 result = await session.run(
+                     query,
+                     source_entity_id=source_node_id,
+                     target_entity_id=target_node_id,
+                 )
+                 records = await result.fetch(2)
+                 await result.consume()
+                 if records:
+                     edge_result = dict(records[0]["edge_properties"])
+                     # Fill in defaults for any expected properties that are missing
+                     for key, default_value in {
+                         "weight": 0.0,
+                         "source_id": None,
+                         "description": None,
+                         "keywords": None,
+                     }.items():
+                         if key not in edge_result:
+                             edge_result[key] = default_value
+                             logger.warning(
+                                 f"Edge between {source_node_id} and {target_node_id} "
+                                 f"is missing property: {key}. Using default value: {default_value}"
+                             )
+                     return edge_result
+                 return None
+             except Exception as e:
+                 logger.error(
+                     f"Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
+                 )
+                 if result is not None:
+                     await result.consume()  # Ensure the result is consumed even on error
+                 raise
+
+     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
+         """
+         Upsert a node in the Memgraph database.
+
+         Args:
+             node_id: The unique identifier for the node (used as label)
+             node_data: Dictionary of node properties
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+         properties = node_data
+         entity_type = properties["entity_type"]
+         if "entity_id" not in properties:
+             raise ValueError(
+                 "Memgraph: node properties must contain an 'entity_id' field"
+             )
+
+         try:
+             async with self._driver.session(database=self._DATABASE) as session:
+                 workspace_label = self._get_workspace_label()
+
+                 async def execute_upsert(tx: AsyncManagedTransaction):
+                     query = f"""
+                         MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
+                         SET n += $properties
+                         SET n:`{entity_type}`
+                     """
+                     result = await tx.run(query, entity_id=node_id, properties=properties)
+                     await result.consume()  # Ensure result is fully consumed
+
+                 await session.execute_write(execute_upsert)
+         except Exception as e:
+             logger.error(f"Error during upsert: {str(e)}")
+             raise
+
+     async def upsert_edge(
+         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+     ) -> None:
+         """
+         Upsert an edge and its properties between two nodes identified by their labels.
+         Ensures both source and target nodes exist and are unique before creating the edge.
+         Uses the entity_id property to uniquely identify nodes.
+
+         Args:
+             source_node_id (str): Label of the source node (used as identifier)
+             target_node_id (str): Label of the target node (used as identifier)
+             edge_data (dict): Dictionary of properties to set on the edge
+
+         Raises:
+             Exception: If there is an error executing the query
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+         try:
+             edge_properties = edge_data
+             async with self._driver.session(database=self._DATABASE) as session:
+
+                 async def execute_upsert(tx: AsyncManagedTransaction):
+                     workspace_label = self._get_workspace_label()
+                     query = f"""
+                         MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
+                         WITH source
+                         MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
+                         MERGE (source)-[r:DIRECTED]-(target)
+                         SET r += $properties
+                         RETURN r, source, target
+                     """
+                     result = await tx.run(
+                         query,
+                         source_entity_id=source_node_id,
+                         target_entity_id=target_node_id,
+                         properties=edge_properties,
+                     )
+                     try:
+                         await result.fetch(2)
+                     finally:
+                         await result.consume()  # Ensure result is consumed
+
+                 await session.execute_write(execute_upsert)
+         except Exception as e:
+             logger.error(f"Error during edge upsert: {str(e)}")
+             raise
+
+     async def delete_node(self, node_id: str) -> None:
+         """Delete a node with the specified label
+
+         Args:
+             node_id: The label of the node to delete
+
+         Raises:
+             Exception: If there is an error executing the query
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+
+         async def _do_delete(tx: AsyncManagedTransaction):
+             workspace_label = self._get_workspace_label()
+             query = f"""
+                 MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
+                 DETACH DELETE n
+             """
+             result = await tx.run(query, entity_id=node_id)
+             logger.debug(f"Deleted node with label {node_id}")
+             await result.consume()
+
+         try:
+             async with self._driver.session(database=self._DATABASE) as session:
+                 await session.execute_write(_do_delete)
+         except Exception as e:
+             logger.error(f"Error during node deletion: {str(e)}")
+             raise
+
+     async def remove_nodes(self, nodes: list[str]):
+         """Delete multiple nodes
+
+         Args:
+             nodes: List of node labels to be deleted
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+         for node in nodes:
+             await self.delete_node(node)
+
+     async def remove_edges(self, edges: list[tuple[str, str]]):
+         """Delete multiple edges
+
+         Args:
+             edges: List of edges to be deleted, each edge is a (source, target) tuple
+
+         Raises:
+             Exception: If there is an error executing the query
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+         for source, target in edges:
+
+             async def _do_delete_edge(tx: AsyncManagedTransaction):
+                 workspace_label = self._get_workspace_label()
+                 query = f"""
+                     MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}})
+                     DELETE r
+                 """
+                 result = await tx.run(
+                     query, source_entity_id=source, target_entity_id=target
+                 )
+                 logger.debug(f"Deleted edge from '{source}' to '{target}'")
+                 await result.consume()  # Ensure result is fully consumed
+
+             try:
+                 async with self._driver.session(database=self._DATABASE) as session:
+                     await session.execute_write(_do_delete_edge)
+             except Exception as e:
+                 logger.error(f"Error during edge deletion: {str(e)}")
+                 raise
+
+     async def drop(self) -> dict[str, str]:
+         """Drop all data from the current workspace and clean up resources
+
+         This method deletes all nodes and relationships for the workspace in the Memgraph database.
+
+         Returns:
+             dict[str, str]: Operation status and message
+             - On success: {"status": "success", "message": "workspace data dropped"}
+             - On failure: {"status": "error", "message": "<error details>"}
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+         workspace_label = self._get_workspace_label()
+         try:
+             async with self._driver.session(database=self._DATABASE) as session:
+                 query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
+                 result = await session.run(query)
+                 await result.consume()
+                 logger.info(
+                     f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
+                 )
+                 return {"status": "success", "message": "workspace data dropped"}
+         except Exception as e:
+             logger.error(
+                 f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
+             )
+             return {"status": "error", "message": str(e)}
+
+     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
+         """Get the total degree (sum of relationships) of two nodes.
+
+         Args:
+             src_id: Label of the source node
+             tgt_id: Label of the target node
+
+         Returns:
+             int: Sum of the degrees of both nodes
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+         src_degree = await self.node_degree(src_id)
+         trg_degree = await self.node_degree(tgt_id)
+
+         # Convert None to 0 for addition
+         src_degree = 0 if src_degree is None else src_degree
+         trg_degree = 0 if trg_degree is None else trg_degree
+
+         degrees = int(src_degree) + int(trg_degree)
+         return degrees
+
+     async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+         """Get all nodes that are associated with the given chunk_ids.
+
+         Args:
+             chunk_ids: List of chunk IDs to find associated nodes for
+
+         Returns:
+             list[dict]: A list of nodes, where each node is a dictionary of its properties.
+                 An empty list if no matching nodes are found.
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+         workspace_label = self._get_workspace_label()
+         async with self._driver.session(
+             database=self._DATABASE, default_access_mode="READ"
+         ) as session:
+             query = f"""
+                 UNWIND $chunk_ids AS chunk_id
+                 MATCH (n:`{workspace_label}`)
+                 WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
+                 RETURN DISTINCT n
+             """
+             result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
+             nodes = []
+             async for record in result:
+                 node = record["n"]
+                 node_dict = dict(node)
+                 node_dict["id"] = node_dict.get("entity_id")
+                 nodes.append(node_dict)
+             await result.consume()
+             return nodes
+
+     async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
+         """Get all edges that are associated with the given chunk_ids.
+
+         Args:
+             chunk_ids: List of chunk IDs to find associated edges for
+
+         Returns:
+             list[dict]: A list of edges, where each edge is a dictionary of its properties.
+                 An empty list if no matching edges are found.
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+         workspace_label = self._get_workspace_label()
+         async with self._driver.session(
+             database=self._DATABASE, default_access_mode="READ"
+         ) as session:
+             query = f"""
+                 UNWIND $chunk_ids AS chunk_id
+                 MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
+                 WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
+                 WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
+                 // Ensure we only return each unique edge once by ordering the source and target
+                 WITH a, b, r,
+                      CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
+                      CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
+                 RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
+             """
+             result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
+             edges = []
+             async for record in result:
+                 edge_properties = record["properties"]
+                 edge_properties["source"] = record["source"]
+                 edge_properties["target"] = record["target"]
+                 edges.append(edge_properties)
+             await result.consume()
+             return edges
+
+     async def get_knowledge_graph(
+         self,
+         node_label: str,
+         max_depth: int = 3,
+         max_nodes: int = MAX_GRAPH_NODES,
+     ) -> KnowledgeGraph:
+         """
+         Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
+
+         Args:
+             node_label: Label of the starting node; * means all nodes
+             max_depth: Maximum depth of the subgraph. Defaults to 3
+             max_nodes: Maximum nodes to return by BFS. Defaults to 1000
+
+         Returns:
+             KnowledgeGraph object containing nodes and edges, with an is_truncated flag
+             indicating whether the graph was truncated due to the max_nodes limit
+
+         Raises:
+             Exception: If there is an error executing the query
+         """
+         if self._driver is None:
+             raise RuntimeError(
+                 "Memgraph driver is not initialized. Call 'await initialize()' first."
+             )
+
+         result = KnowledgeGraph()
+         seen_nodes = set()
+         seen_edges = set()
+         workspace_label = self._get_workspace_label()
+         async with self._driver.session(
+             database=self._DATABASE, default_access_mode="READ"
+         ) as session:
+             try:
+                 if node_label == "*":
+                     # First check if the database has any nodes
+                     count_query = "MATCH (n) RETURN count(n) as total"
+                     count_result = None
+                     total_count = 0
+                     try:
+                         count_result = await session.run(count_query)
+                         count_record = await count_result.single()
+                         if count_record:
+                             total_count = count_record["total"]
+                             if total_count == 0:
+                                 logger.debug("No nodes found in database")
+                                 return result
+                             if total_count > max_nodes:
+                                 result.is_truncated = True
+                                 logger.info(
+                                     f"Graph truncated: {total_count} nodes found, limited to {max_nodes}"
+                                 )
+                     finally:
+                         if count_result:
+                             await count_result.consume()
+
+                     # Run the main query to get the nodes with the highest degree
+                     main_query = f"""
+                         MATCH (n:`{workspace_label}`)
+                         OPTIONAL MATCH (n)-[r]-()
+                         WITH n, COALESCE(count(r), 0) AS degree
+                         ORDER BY degree DESC
+                         LIMIT $max_nodes
+                         WITH collect(n) AS kept_nodes
+                         MATCH (a)-[r]-(b)
+                         WHERE a IN kept_nodes AND b IN kept_nodes
+                         RETURN [node IN kept_nodes | {{node: node}}] AS node_info,
+                                collect(DISTINCT r) AS relationships
+                     """
+                     result_set = None
+                     try:
+                         result_set = await session.run(
+                             main_query, {"max_nodes": max_nodes}
+                         )
+                         record = await result_set.single()
+                         if not record:
+                             logger.debug("No record returned from main query")
+                             return result
+                     finally:
+                         if result_set:
+                             await result_set.consume()
+
+                 else:
+                     bfs_query = f"""
+                         MATCH (start:`{workspace_label}`)
+                         WHERE start.entity_id = $entity_id
+                         WITH start
+                         CALL {{
+                             WITH start
+                             MATCH path = (start)-[*0..{max_depth}]-(node)
+                             WITH nodes(path) AS path_nodes, relationships(path) AS path_rels
+                             UNWIND path_nodes AS n
+                             WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists
+                             WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels
+                             RETURN all_nodes, all_rels
+                         }}
+                         WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
+                         WITH
+                             CASE
+                                 WHEN total_nodes <= {max_nodes} THEN nodes
+                                 ELSE nodes[0..{max_nodes}]
+                             END AS limited_nodes,
+                             relationships,
+                             total_nodes,
+                             total_nodes > {max_nodes} AS is_truncated
+                         RETURN
+                             [node IN limited_nodes | {{node: node}}] AS node_info,
+                             relationships,
+                             total_nodes,
+                             is_truncated
+                     """
+                     result_set = None
+                     try:
+                         result_set = await session.run(
+                             bfs_query,
+                             {
+                                 "entity_id": node_label,
+                             },
+                         )
+                         record = await result_set.single()
+                         if not record:
+                             logger.debug(f"No nodes found for entity_id: {node_label}")
+                             return result
+
+                         # Check if the query indicates truncation
+                         if "is_truncated" in record and record["is_truncated"]:
+                             result.is_truncated = True
+                             logger.info(
+                                 f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
+                             )
+
+                     finally:
+                         if result_set:
+                             await result_set.consume()
+
+                 # Process the record if it exists
+                 if record and record["node_info"]:
+                     for node_info in record["node_info"]:
+                         node = node_info["node"]
+                         node_id = node.id
+                         if node_id not in seen_nodes:
+                             seen_nodes.add(node_id)
+                             result.nodes.append(
+                                 KnowledgeGraphNode(
+                                     id=f"{node_id}",
+                                     labels=[node.get("entity_id")],
+                                     properties=dict(node),
+                                 )
+                             )
+
+                     for rel in record["relationships"]:
+                         edge_id = rel.id
+                         if edge_id not in seen_edges:
+                             seen_edges.add(edge_id)
+                             start = rel.start_node
+                             end = rel.end_node
+                             result.edges.append(
+                                 KnowledgeGraphEdge(
+                                     id=f"{edge_id}",
+                                     type=rel.type,
+                                     source=f"{start.id}",
+                                     target=f"{end.id}",
+                                     properties=dict(rel),
+                                 )
+                             )
+
+                 logger.info(
+                     f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+                 )
+
+             except Exception as e:
+                 logger.error(f"Error getting knowledge graph: {str(e)}")
+                 # Return an empty but properly initialized KnowledgeGraph on error
+                 return KnowledgeGraph()
+
+         return result
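The class above can also be exercised directly, without a full LightRAG pipeline, which is roughly what `tests/test_graph_storage.py` does. A minimal sketch; the `global_config` and `embedding_func` placeholders here are illustrative, not canonical values:

```python
import asyncio

from lightrag.kg.memgraph_impl import MemgraphStorage


async def main():
    storage = MemgraphStorage(
        namespace="chunk_entity_relation",
        global_config={},        # placeholder; real configs carry more keys
        embedding_func=None,     # the graph storage does no embedding itself
        workspace="demo",
    )
    await storage.initialize()
    try:
        await storage.upsert_node(
            "Memgraph", {"entity_id": "Memgraph", "entity_type": "Product", "description": "graph DB"}
        )
        await storage.upsert_node(
            "LightRAG", {"entity_id": "LightRAG", "entity_type": "Project", "description": "RAG library"}
        )
        await storage.upsert_edge("LightRAG", "Memgraph", {"weight": 1.0, "description": "uses"})
        print(await storage.has_edge("LightRAG", "Memgraph"))  # True
        print(await storage.node_degree("Memgraph"))           # 1
    finally:
        await storage.finalize()


asyncio.run(main())
```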
lightrag/llm/openai.py CHANGED
@@ -210,9 +210,18 @@ async def openai_complete_if_cache(
     async def inner():
         # Track if we've started iterating
         iteration_started = False
+        final_chunk_usage = None
+
         try:
             iteration_started = True
             async for chunk in response:
+                # Check if this chunk has usage information (final chunk)
+                if hasattr(chunk, "usage") and chunk.usage:
+                    final_chunk_usage = chunk.usage
+                    logger.debug(
+                        f"Received usage info in streaming chunk: {chunk.usage}"
+                    )
+
                 # Check if choices exists and is not empty
                 if not hasattr(chunk, "choices") or not chunk.choices:
                     logger.warning(f"Received chunk without choices: {chunk}")
@@ -222,16 +231,31 @@ async def openai_complete_if_cache(
                 if not hasattr(chunk.choices[0], "delta") or not hasattr(
                     chunk.choices[0].delta, "content"
                 ):
-                    logger.warning(
-                        f"Received chunk without delta content: {chunk.choices[0]}"
-                    )
+                    # This might be the final chunk; continue so usage can still be captured
                     continue
+
                 content = chunk.choices[0].delta.content
                 if content is None:
                     continue
                 if r"\u" in content:
                     content = safe_unicode_decode(content.encode("utf-8"))
+
                 yield content
+
+            # After streaming is complete, track token usage
+            if token_tracker and final_chunk_usage:
+                # Use actual usage from the API
+                token_counts = {
+                    "prompt_tokens": getattr(final_chunk_usage, "prompt_tokens", 0),
+                    "completion_tokens": getattr(
+                        final_chunk_usage, "completion_tokens", 0
+                    ),
+                    "total_tokens": getattr(final_chunk_usage, "total_tokens", 0),
+                }
+                token_tracker.add_usage(token_counts)
+                logger.debug(f"Streaming token usage (from API): {token_counts}")
+            elif token_tracker:
+                logger.debug("No usage information available in streaming response")
         except Exception as e:
             logger.error(f"Error in stream response: {str(e)}")
             # Try to clean up resources if possible
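The streaming path only calls `token_tracker.add_usage(...)` with a plain dict, so any object with that method satisfies the contract. A minimal compatible tracker, shown purely as an illustration (LightRAG ships its own `TokenTracker` in `lightrag.utils`, whose internals may differ):

```python
class SimpleTokenTracker:
    """Accumulates the usage dicts emitted after each streamed completion."""

    def __init__(self):
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0

    def add_usage(self, token_counts: dict) -> None:
        self.prompt_tokens += token_counts.get("prompt_tokens", 0)
        self.completion_tokens += token_counts.get("completion_tokens", 0)
        self.total_tokens += token_counts.get("total_tokens", 0)


tracker = SimpleTokenTracker()
tracker.add_usage({"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46})
assert tracker.total_tokens == 46
```

Note that the OpenAI API only attaches `usage` to a streamed response when the request sets `stream_options={"include_usage": True}`; otherwise the `elif token_tracker:` branch simply logs that no usage was available.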
lightrag/operate.py CHANGED
@@ -26,6 +26,7 @@ from .utils import (
     get_conversation_turns,
     use_llm_func_with_cache,
     update_chunk_cache_list,
+    remove_think_tags,
 )
 from .base import (
     BaseGraphStorage,
@@ -1704,7 +1705,8 @@ async def extract_keywords_only(
     result = await use_model_func(kw_prompt, keyword_extraction=True)

     # 6. Parse out JSON from the LLM response
-    match = re.search(r"\{.*\}", result, re.DOTALL)
+    result = remove_think_tags(result)
+    match = re.search(r"\{.*?\}", result, re.DOTALL)
     if not match:
         logger.error("No JSON-like structure found in the LLM respond.")
         return [], []
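Two things change here: reasoning-model `<think>` blocks are stripped before parsing, and the brace regex becomes non-greedy so it stops at the first `}` instead of running to the last one in the response. That is safe in this context because the keyword JSON is flat (no nested objects). A quick illustration of the regex difference:

```python
import re

# A response after remove_think_tags() has stripped the reasoning block,
# with some trailing chatter that happens to contain braces
cleaned = '{"high_level_keywords": ["graph storage"], "low_level_keywords": ["memgraph"]} Done! {oops}'

greedy = re.search(r"\{.*\}", cleaned, re.DOTALL)   # runs to the LAST '}'
lazy = re.search(r"\{.*?\}", cleaned, re.DOTALL)    # stops at the FIRST '}'

print(greedy.group(0))  # ...} Done! {oops}  <- swallows the trailing junk
print(lazy.group(0))    # just the keyword JSON object
```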
lightrag/utils.py CHANGED
@@ -1465,6 +1465,11 @@ async def update_chunk_cache_list(
     )


+ def remove_think_tags(text: str) -> str:
+     """Remove a leading <think> block (or dangling <think> tag) from the text"""
+     return re.sub(r"^(<think>.*?</think>|<think>)", "", text, flags=re.DOTALL).strip()
+
+
 async def use_llm_func_with_cache(
     input_text: str,
     use_llm_func: callable,
@@ -1531,6 +1536,7 @@ async def use_llm_func_with_cache(
             kwargs["max_tokens"] = max_tokens

         res: str = await use_llm_func(input_text, **kwargs)
+        res = remove_think_tags(res)

         if llm_response_cache.global_config.get("enable_llm_cache_for_entity_extract"):
             await save_to_cache(
@@ -1557,8 +1563,9 @@ async def use_llm_func_with_cache(
     if max_tokens is not None:
         kwargs["max_tokens"] = max_tokens

-    logger.info(f"Call LLM function with query text lenght: {len(input_text)}")
-    return await use_llm_func(input_text, **kwargs)
+    logger.info(f"Call LLM function with query text length: {len(input_text)}")
+    res = await use_llm_func(input_text, **kwargs)
+    return remove_think_tags(res)


 def get_content_summary(content: str, max_length: int = 250) -> str:
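The anchored alternation in `remove_think_tags` covers both a completed reasoning block and a dangling open tag at the start of a response; restated standalone so the cases can be checked directly:

```python
import re


def remove_think_tags(text: str) -> str:
    # Same pattern as the new helper in lightrag/utils.py
    return re.sub(r"^(<think>.*?</think>|<think>)", "", text, flags=re.DOTALL).strip()


# Case 1: a completed <think>...</think> block is removed wholesale
assert remove_think_tags("<think>step by step</think>answer") == "answer"
# Case 2: a dangling open tag (e.g., a cut-off stream) loses just the tag itself
assert remove_think_tags("<think>partial reasoning") == "partial reasoning"
# Case 3: tags not at the start are left alone by the anchored pattern
assert remove_think_tags("answer <think>x</think>") == "answer <think>x</think>"
```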
tests/test_graph_storage.py CHANGED
@@ -10,6 +10,7 @@
  - Neo4JStorage
  - MongoDBStorage
  - PGGraphStorage
+ - MemgraphStorage
  """

  import asyncio