Pankaj Kaushal committed on
Commit
87da96c
·
1 Parent(s): 07feee7

Enhance Neo4j graph storage with error handling and label validation

Browse files

- Add label existence check and validation methods in Neo4j implementation
- Improve error handling in get_node, get_edge, and upsert methods
- Add default values and logging for missing edge properties
- Ensure consistent label processing across graph storage methods

Files changed (2) hide show
  1. lightrag/kg/neo4j_impl.py +98 -34
  2. lightrag/operate.py +51 -11
lightrag/kg/neo4j_impl.py CHANGED
@@ -143,9 +143,27 @@ class Neo4JStorage(BaseGraphStorage):
143
  async def index_done_callback(self):
144
  print("KG successfully indexed.")
145
 
146
- async def has_node(self, node_id: str) -> bool:
147
- entity_name_label = node_id.strip('"')
 
 
 
 
 
 
 
 
 
148
 
 
 
 
 
 
 
 
 
 
149
  async with self._driver.session(database=self._DATABASE) as session:
150
  query = (
151
  f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
@@ -174,8 +192,17 @@ class Neo4JStorage(BaseGraphStorage):
174
  return single_result["edgeExists"]
175
 
176
  async def get_node(self, node_id: str) -> Union[dict, None]:
 
 
 
 
 
 
 
 
 
177
  async with self._driver.session(database=self._DATABASE) as session:
178
- entity_name_label = node_id.strip('"')
179
  query = f"MATCH (n:`{entity_name_label}`) RETURN n"
180
  result = await session.run(query)
181
  record = await result.single()
@@ -226,38 +253,73 @@ class Neo4JStorage(BaseGraphStorage):
226
  async def get_edge(
227
  self, source_node_id: str, target_node_id: str
228
  ) -> Union[dict, None]:
229
- entity_name_label_source = source_node_id.strip('"')
230
- entity_name_label_target = target_node_id.strip('"')
231
- """
232
- Find all edges between nodes of two given labels
233
 
234
  Args:
235
- source_node_label (str): Label of the source nodes
236
- target_node_label (str): Label of the target nodes
237
 
238
  Returns:
239
- list: List of all relationships/edges found
 
240
  """
241
- async with self._driver.session(database=self._DATABASE) as session:
242
- query = f"""
243
- MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
244
- RETURN properties(r) as edge_properties
245
- LIMIT 1
246
- """.format(
247
- entity_name_label_source=entity_name_label_source,
248
- entity_name_label_target=entity_name_label_target,
249
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
- result = await session.run(query)
252
- record = await result.single()
253
- if record:
254
- result = dict(record["edge_properties"])
255
  logger.debug(
256
- f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
257
  )
258
- return result
259
- else:
260
- return None
 
 
 
 
 
 
261
 
262
  async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
263
  node_label = source_node_id.strip('"')
@@ -310,7 +372,7 @@ class Neo4JStorage(BaseGraphStorage):
310
  node_id: The unique identifier for the node (used as label)
311
  node_data: Dictionary of node properties
312
  """
313
- label = node_id.strip('"')
314
  properties = node_data
315
 
316
  async def _do_upsert(tx: AsyncManagedTransaction):
@@ -338,6 +400,7 @@ class Neo4JStorage(BaseGraphStorage):
338
  neo4jExceptions.ServiceUnavailable,
339
  neo4jExceptions.TransientError,
340
  neo4jExceptions.WriteServiceUnavailable,
 
341
  )
342
  ),
343
  )
@@ -352,22 +415,23 @@ class Neo4JStorage(BaseGraphStorage):
352
  target_node_id (str): Label of the target node (used as identifier)
353
  edge_data (dict): Dictionary of properties to set on the edge
354
  """
355
- source_node_label = source_node_id.strip('"')
356
- target_node_label = target_node_id.strip('"')
357
  edge_properties = edge_data
358
 
359
  async def _do_upsert_edge(tx: AsyncManagedTransaction):
360
  query = f"""
361
- MATCH (source:`{source_node_label}`)
362
  WITH source
363
- MATCH (target:`{target_node_label}`)
364
  MERGE (source)-[r:DIRECTED]->(target)
365
  SET r += $properties
366
  RETURN r
367
  """
368
- await tx.run(query, properties=edge_properties)
 
369
  logger.debug(
370
- f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
371
  )
372
 
373
  try:
 
143
  async def index_done_callback(self):
144
  print("KG successfully indexed.")
145
 
146
+ async def _label_exists(self, label: str) -> bool:
147
+ """Check if a label exists in the Neo4j database."""
148
+ query = "CALL db.labels() YIELD label RETURN label"
149
+ try:
150
+ async with self._driver.session(database=self._DATABASE) as session:
151
+ result = await session.run(query)
152
+ labels = [record["label"] for record in await result.data()]
153
+ return label in labels
154
+ except Exception as e:
155
+ logger.error(f"Error checking label existence: {e}")
156
+ return False
157
 
158
+ async def _ensure_label(self, label: str) -> str:
159
+ """Ensure a label exists by validating it."""
160
+ clean_label = label.strip('"')
161
+ if not await self._label_exists(clean_label):
162
+ logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
163
+ return clean_label
164
+
165
+ async def has_node(self, node_id: str) -> bool:
166
+ entity_name_label = await self._ensure_label(node_id)
167
  async with self._driver.session(database=self._DATABASE) as session:
168
  query = (
169
  f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
 
192
  return single_result["edgeExists"]
193
 
194
  async def get_node(self, node_id: str) -> Union[dict, None]:
195
+ """Get node by its label identifier.
196
+
197
+ Args:
198
+ node_id: The node label to look up
199
+
200
+ Returns:
201
+ dict: Node properties if found
202
+ None: If node not found
203
+ """
204
  async with self._driver.session(database=self._DATABASE) as session:
205
+ entity_name_label = await self._ensure_label(node_id)
206
  query = f"MATCH (n:`{entity_name_label}`) RETURN n"
207
  result = await session.run(query)
208
  record = await result.single()
 
253
  async def get_edge(
254
  self, source_node_id: str, target_node_id: str
255
  ) -> Union[dict, None]:
256
+ """Find edge between two nodes identified by their labels.
 
 
 
257
 
258
  Args:
259
+ source_node_id (str): Label of the source node
260
+ target_node_id (str): Label of the target node
261
 
262
  Returns:
263
+ dict: Edge properties if found, with at least {"weight": 0.0}
264
+ None: If error occurs
265
  """
266
+ try:
267
+ entity_name_label_source = source_node_id.strip('"')
268
+ entity_name_label_target = target_node_id.strip('"')
269
+
270
+ async with self._driver.session(database=self._DATABASE) as session:
271
+ query = f"""
272
+ MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
273
+ RETURN properties(r) as edge_properties
274
+ LIMIT 1
275
+ """.format(
276
+ entity_name_label_source=entity_name_label_source,
277
+ entity_name_label_target=entity_name_label_target,
278
+ )
279
+
280
+ result = await session.run(query)
281
+ record = await result.single()
282
+ if record and "edge_properties" in record:
283
+ try:
284
+ result = dict(record["edge_properties"])
285
+ # Ensure required keys exist with defaults
286
+ required_keys = {
287
+ "weight": 0.0,
288
+ "source_id": None,
289
+ "target_id": None,
290
+ }
291
+ for key, default_value in required_keys.items():
292
+ if key not in result:
293
+ result[key] = default_value
294
+ logger.warning(
295
+ f"Edge between {entity_name_label_source} and {entity_name_label_target} "
296
+ f"missing {key}, using default: {default_value}"
297
+ )
298
+
299
+ logger.debug(
300
+ f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
301
+ )
302
+ return result
303
+ except (KeyError, TypeError, ValueError) as e:
304
+ logger.error(
305
+ f"Error processing edge properties between {entity_name_label_source} "
306
+ f"and {entity_name_label_target}: {str(e)}"
307
+ )
308
+ # Return default edge properties on error
309
+ return {"weight": 0.0, "source_id": None, "target_id": None}
310
 
 
 
 
 
311
  logger.debug(
312
+ f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
313
  )
314
+ # Return default edge properties when no edge found
315
+ return {"weight": 0.0, "source_id": None, "target_id": None}
316
+
317
+ except Exception as e:
318
+ logger.error(
319
+ f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
320
+ )
321
+ # Return default edge properties on error
322
+ return {"weight": 0.0, "source_id": None, "target_id": None}
323
 
324
  async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
325
  node_label = source_node_id.strip('"')
 
372
  node_id: The unique identifier for the node (used as label)
373
  node_data: Dictionary of node properties
374
  """
375
+ label = await self._ensure_label(node_id)
376
  properties = node_data
377
 
378
  async def _do_upsert(tx: AsyncManagedTransaction):
 
400
  neo4jExceptions.ServiceUnavailable,
401
  neo4jExceptions.TransientError,
402
  neo4jExceptions.WriteServiceUnavailable,
403
+ neo4jExceptions.ClientError,
404
  )
405
  ),
406
  )
 
415
  target_node_id (str): Label of the target node (used as identifier)
416
  edge_data (dict): Dictionary of properties to set on the edge
417
  """
418
+ source_label = await self._ensure_label(source_node_id)
419
+ target_label = await self._ensure_label(target_node_id)
420
  edge_properties = edge_data
421
 
422
  async def _do_upsert_edge(tx: AsyncManagedTransaction):
423
  query = f"""
424
+ MATCH (source:`{source_label}`)
425
  WITH source
426
+ MATCH (target:`{target_label}`)
427
  MERGE (source)-[r:DIRECTED]->(target)
428
  SET r += $properties
429
  RETURN r
430
  """
431
+ result = await tx.run(query, properties=edge_properties)
432
+ record = await result.single()
433
  logger.debug(
434
+ f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
435
  )
436
 
437
  try:
lightrag/operate.py CHANGED
@@ -237,25 +237,65 @@ async def _merge_edges_then_upsert(
237
 
238
  if await knowledge_graph_inst.has_edge(src_id, tgt_id):
239
  already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
240
- already_weights.append(already_edge["weight"])
241
- already_source_ids.extend(
242
- split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
243
- )
244
- already_description.append(already_edge["description"])
245
- already_keywords.extend(
246
- split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
247
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
 
249
  weight = sum([dp["weight"] for dp in edges_data] + already_weights)
250
  description = GRAPH_FIELD_SEP.join(
251
- sorted(set([dp["description"] for dp in edges_data] + already_description))
 
 
 
 
 
252
  )
253
  keywords = GRAPH_FIELD_SEP.join(
254
- sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
 
 
 
 
 
255
  )
256
  source_id = GRAPH_FIELD_SEP.join(
257
- set([dp["source_id"] for dp in edges_data] + already_source_ids)
 
 
 
258
  )
 
259
  for need_insert_id in [src_id, tgt_id]:
260
  if not (await knowledge_graph_inst.has_node(need_insert_id)):
261
  await knowledge_graph_inst.upsert_node(
 
237
 
238
  if await knowledge_graph_inst.has_edge(src_id, tgt_id):
239
  already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
240
+ # Handle the case where get_edge returns None or missing fields
241
+ if already_edge:
242
+ # Get weight with default 0.0 if missing
243
+ if "weight" in already_edge:
244
+ already_weights.append(already_edge["weight"])
245
+ else:
246
+ logger.warning(
247
+ f"Edge between {src_id} and {tgt_id} missing weight field"
248
+ )
249
+ already_weights.append(0.0)
250
+
251
+ # Get source_id with empty string default if missing or None
252
+ if "source_id" in already_edge and already_edge["source_id"] is not None:
253
+ already_source_ids.extend(
254
+ split_string_by_multi_markers(
255
+ already_edge["source_id"], [GRAPH_FIELD_SEP]
256
+ )
257
+ )
258
+
259
+ # Get description with empty string default if missing or None
260
+ if (
261
+ "description" in already_edge
262
+ and already_edge["description"] is not None
263
+ ):
264
+ already_description.append(already_edge["description"])
265
+
266
+ # Get keywords with empty string default if missing or None
267
+ if "keywords" in already_edge and already_edge["keywords"] is not None:
268
+ already_keywords.extend(
269
+ split_string_by_multi_markers(
270
+ already_edge["keywords"], [GRAPH_FIELD_SEP]
271
+ )
272
+ )
273
 
274
+ # Process edges_data with None checks
275
  weight = sum([dp["weight"] for dp in edges_data] + already_weights)
276
  description = GRAPH_FIELD_SEP.join(
277
+ sorted(
278
+ set(
279
+ [dp["description"] for dp in edges_data if dp.get("description")]
280
+ + already_description
281
+ )
282
+ )
283
  )
284
  keywords = GRAPH_FIELD_SEP.join(
285
+ sorted(
286
+ set(
287
+ [dp["keywords"] for dp in edges_data if dp.get("keywords")]
288
+ + already_keywords
289
+ )
290
+ )
291
  )
292
  source_id = GRAPH_FIELD_SEP.join(
293
+ set(
294
+ [dp["source_id"] for dp in edges_data if dp.get("source_id")]
295
+ + already_source_ids
296
+ )
297
  )
298
+
299
  for need_insert_id in [src_id, tgt_id]:
300
  if not (await knowledge_graph_inst.has_node(need_insert_id)):
301
  await knowledge_graph_inst.upsert_node(