yangdx
committed on
Commit
·
fa91f10
1
Parent(s):
016d00c
Add missing await consume
Browse files - lightrag/kg/neo4j_impl.py +129 -119
lightrag/kg/neo4j_impl.py
CHANGED
@@ -64,19 +64,19 @@ class Neo4JStorage(BaseGraphStorage):
|
|
64 |
MAX_CONNECTION_POOL_SIZE = int(
|
65 |
os.environ.get(
|
66 |
"NEO4J_MAX_CONNECTION_POOL_SIZE",
|
67 |
-
config.get("neo4j", "connection_pool_size", fallback=50),
|
68 |
)
|
69 |
)
|
70 |
CONNECTION_TIMEOUT = float(
|
71 |
os.environ.get(
|
72 |
"NEO4J_CONNECTION_TIMEOUT",
|
73 |
-
config.get("neo4j", "connection_timeout", fallback=30.0),
|
74 |
),
|
75 |
)
|
76 |
CONNECTION_ACQUISITION_TIMEOUT = float(
|
77 |
os.environ.get(
|
78 |
"NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
|
79 |
-
config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),
|
80 |
),
|
81 |
)
|
82 |
MAX_TRANSACTION_RETRY_TIME = float(
|
@@ -188,23 +188,24 @@ class Neo4JStorage(BaseGraphStorage):
|
|
188 |
|
189 |
async def has_node(self, node_id: str) -> bool:
|
190 |
entity_name_label = await self._ensure_label(node_id)
|
191 |
-
async with self._driver.session(
|
|
|
|
|
192 |
query = (
|
193 |
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
|
194 |
)
|
195 |
result = await session.run(query)
|
196 |
single_result = await result.single()
|
197 |
await result.consume() # Ensure result is fully consumed
|
198 |
-
logger.debug(
|
199 |
-
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
|
200 |
-
)
|
201 |
return single_result["node_exists"]
|
202 |
|
203 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
204 |
entity_name_label_source = source_node_id.strip('"')
|
205 |
entity_name_label_target = target_node_id.strip('"')
|
206 |
|
207 |
-
async with self._driver.session(
|
|
|
|
|
208 |
query = (
|
209 |
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
|
210 |
"RETURN COUNT(r) > 0 AS edgeExists"
|
@@ -212,9 +213,6 @@ class Neo4JStorage(BaseGraphStorage):
|
|
212 |
result = await session.run(query)
|
213 |
single_result = await result.single()
|
214 |
await result.consume() # Ensure result is fully consumed
|
215 |
-
logger.debug(
|
216 |
-
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}"
|
217 |
-
)
|
218 |
return single_result["edgeExists"]
|
219 |
|
220 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
@@ -227,14 +225,20 @@ class Neo4JStorage(BaseGraphStorage):
|
|
227 |
dict: Node properties if found
|
228 |
None: If node not found
|
229 |
"""
|
230 |
-
async with self._driver.session(
|
|
|
|
|
231 |
entity_name_label = await self._ensure_label(node_id)
|
232 |
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
|
233 |
result = await session.run(query)
|
234 |
-
records = await result.fetch(
|
|
|
|
|
235 |
await result.consume() # Ensure result is fully consumed
|
236 |
if len(records) > 1:
|
237 |
-
logger.warning(
|
|
|
|
|
238 |
if records:
|
239 |
node = records[0]["n"]
|
240 |
node_dict = dict(node)
|
@@ -248,16 +252,18 @@ class Neo4JStorage(BaseGraphStorage):
|
|
248 |
"""Get the degree (number of relationships) of a node with the given label.
|
249 |
If multiple nodes have the same label, returns the degree of the first node.
|
250 |
If no node is found, returns 0.
|
251 |
-
|
252 |
Args:
|
253 |
node_id: The label of the node
|
254 |
-
|
255 |
Returns:
|
256 |
int: The number of relationships the node has, or 0 if no node found
|
257 |
"""
|
258 |
entity_name_label = node_id.strip('"')
|
259 |
|
260 |
-
async with self._driver.session(
|
|
|
|
|
261 |
query = f"""
|
262 |
MATCH (n:`{entity_name_label}`)
|
263 |
OPTIONAL MATCH (n)-[r]-()
|
@@ -266,14 +272,16 @@ class Neo4JStorage(BaseGraphStorage):
|
|
266 |
result = await session.run(query)
|
267 |
records = await result.fetch(100)
|
268 |
await result.consume() # Ensure result is fully consumed
|
269 |
-
|
270 |
if not records:
|
271 |
logger.warning(f"No node found with label '{entity_name_label}'")
|
272 |
return 0
|
273 |
-
|
274 |
if len(records) > 1:
|
275 |
-
logger.warning(
|
276 |
-
|
|
|
|
|
277 |
degree = records[0]["degree"]
|
278 |
logger.debug(
|
279 |
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}"
|
@@ -296,30 +304,6 @@ class Neo4JStorage(BaseGraphStorage):
|
|
296 |
)
|
297 |
return degrees
|
298 |
|
299 |
-
async def check_duplicate_nodes(self) -> list[tuple[str, int]]:
|
300 |
-
"""Find all labels that have multiple nodes
|
301 |
-
|
302 |
-
Returns:
|
303 |
-
list[tuple[str, int]]: List of tuples containing (label, node_count) for labels with multiple nodes
|
304 |
-
"""
|
305 |
-
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
306 |
-
query = """
|
307 |
-
MATCH (n)
|
308 |
-
WITH labels(n) as nodeLabels
|
309 |
-
UNWIND nodeLabels as label
|
310 |
-
WITH label, count(*) as node_count
|
311 |
-
WHERE node_count > 1
|
312 |
-
RETURN label, node_count
|
313 |
-
ORDER BY node_count DESC
|
314 |
-
"""
|
315 |
-
result = await session.run(query)
|
316 |
-
duplicates = []
|
317 |
-
async for record in result:
|
318 |
-
label = record["label"]
|
319 |
-
count = record["node_count"]
|
320 |
-
logger.info(f"Found {count} nodes with label: {label}")
|
321 |
-
duplicates.append((label, count))
|
322 |
-
return duplicates
|
323 |
|
324 |
async def get_edge(
|
325 |
self, source_node_id: str, target_node_id: str
|
@@ -328,64 +312,69 @@ class Neo4JStorage(BaseGraphStorage):
|
|
328 |
entity_name_label_source = source_node_id.strip('"')
|
329 |
entity_name_label_target = target_node_id.strip('"')
|
330 |
|
331 |
-
async with self._driver.session(
|
|
|
|
|
332 |
query = f"""
|
333 |
MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`)
|
334 |
RETURN properties(r) as edge_properties
|
335 |
"""
|
336 |
|
337 |
result = await session.run(query)
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
if records:
|
344 |
-
try:
|
345 |
-
result = dict(records[0]["edge_properties"])
|
346 |
-
logger.debug(f"Result: {result}")
|
347 |
-
# Ensure required keys exist with defaults
|
348 |
-
required_keys = {
|
349 |
-
"weight": 0.0,
|
350 |
-
"source_id": None,
|
351 |
-
"description": None,
|
352 |
-
"keywords": None,
|
353 |
-
}
|
354 |
-
for key, default_value in required_keys.items():
|
355 |
-
if key not in result:
|
356 |
-
result[key] = default_value
|
357 |
-
logger.warning(
|
358 |
-
f"Edge between {entity_name_label_source} and {entity_name_label_target} "
|
359 |
-
f"missing {key}, using default: {default_value}"
|
360 |
-
)
|
361 |
-
|
362 |
-
logger.debug(
|
363 |
-
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
|
364 |
-
)
|
365 |
-
return result
|
366 |
-
except (KeyError, TypeError, ValueError) as e:
|
367 |
-
logger.error(
|
368 |
-
f"Error processing edge properties between {entity_name_label_source} "
|
369 |
-
f"and {entity_name_label_target}: {str(e)}"
|
370 |
)
|
371 |
-
|
372 |
-
|
373 |
-
"
|
374 |
-
"
|
375 |
-
|
376 |
-
|
377 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
378 |
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
389 |
|
390 |
except Exception as e:
|
391 |
logger.error(
|
@@ -409,7 +398,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|
409 |
query = f"""MATCH (n:`{node_label}`)
|
410 |
OPTIONAL MATCH (n)-[r]-(connected)
|
411 |
RETURN n, r, connected"""
|
412 |
-
async with self._driver.session(
|
|
|
|
|
413 |
results = await session.run(query)
|
414 |
edges = []
|
415 |
try:
|
@@ -429,7 +420,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|
429 |
if source_label and target_label:
|
430 |
edges.append((source_label, target_label))
|
431 |
finally:
|
432 |
-
await
|
|
|
|
|
433 |
|
434 |
return edges
|
435 |
|
@@ -461,10 +454,11 @@ class Neo4JStorage(BaseGraphStorage):
|
|
461 |
MERGE (n:`{label}`)
|
462 |
SET n += $properties
|
463 |
"""
|
464 |
-
await tx.run(query, properties=properties)
|
465 |
logger.debug(
|
466 |
f"Upserted node with label '{label}' and properties: {properties}"
|
467 |
)
|
|
|
468 |
|
469 |
try:
|
470 |
async with self._driver.session(database=self._DATABASE) as session:
|
@@ -509,9 +503,13 @@ class Neo4JStorage(BaseGraphStorage):
|
|
509 |
target_exists = await self.has_node(target_label)
|
510 |
|
511 |
if not source_exists:
|
512 |
-
raise ValueError(
|
|
|
|
|
513 |
if not target_exists:
|
514 |
-
raise ValueError(
|
|
|
|
|
515 |
|
516 |
async def _do_upsert_edge(tx: AsyncManagedTransaction):
|
517 |
query = f"""
|
@@ -570,7 +568,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|
570 |
seen_nodes = set()
|
571 |
seen_edges = set()
|
572 |
|
573 |
-
async with self._driver.session(
|
|
|
|
|
574 |
try:
|
575 |
if label == "*":
|
576 |
main_query = """
|
@@ -728,11 +728,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|
728 |
visited_nodes.add(node_id)
|
729 |
|
730 |
# Add node data with label as ID
|
731 |
-
result["nodes"].append(
|
732 |
-
"id": current_label,
|
733 |
-
|
734 |
-
"properties": node
|
735 |
-
})
|
736 |
|
737 |
# Get connected nodes that meet the degree requirement
|
738 |
# Note: We don't need to check a's degree since it's the current node
|
@@ -744,7 +742,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|
744 |
WHERE b_degree >= $min_degree OR EXISTS((a)--(b))
|
745 |
RETURN r, b
|
746 |
"""
|
747 |
-
async with self._driver.session(
|
|
|
|
|
748 |
results = await session.run(query, {"min_degree": min_degree})
|
749 |
async for record in results:
|
750 |
# Handle edges
|
@@ -754,19 +754,23 @@ class Neo4JStorage(BaseGraphStorage):
|
|
754 |
b_node = record["b"]
|
755 |
if b_node.labels: # Only process if target node has labels
|
756 |
target_label = list(b_node.labels)[0]
|
757 |
-
result["edges"].append(
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
|
|
|
|
|
764 |
visited_edges.add(edge_id)
|
765 |
|
766 |
# Continue traversal
|
767 |
await traverse(target_label, current_depth + 1)
|
768 |
else:
|
769 |
-
logger.warning(
|
|
|
|
|
770 |
|
771 |
await traverse(label, 0)
|
772 |
return result
|
@@ -777,7 +781,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|
777 |
Returns:
|
778 |
["Person", "Company", ...] # Alphabetically sorted label list
|
779 |
"""
|
780 |
-
async with self._driver.session(
|
|
|
|
|
781 |
# Method 1: Direct metadata query (Available for Neo4j 4.3+)
|
782 |
# query = "CALL db.labels() YIELD label RETURN label"
|
783 |
|
@@ -796,7 +802,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|
796 |
async for record in result:
|
797 |
labels.append(record["label"])
|
798 |
finally:
|
799 |
-
await
|
|
|
|
|
800 |
return labels
|
801 |
|
802 |
@retry(
|
@@ -824,8 +832,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|
824 |
MATCH (n:`{label}`)
|
825 |
DETACH DELETE n
|
826 |
"""
|
827 |
-
await tx.run(query)
|
828 |
logger.debug(f"Deleted node with label '{label}'")
|
|
|
829 |
|
830 |
try:
|
831 |
async with self._driver.session(database=self._DATABASE) as session:
|
@@ -882,8 +891,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|
882 |
MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`)
|
883 |
DELETE r
|
884 |
"""
|
885 |
-
await tx.run(query)
|
886 |
logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'")
|
|
|
887 |
|
888 |
try:
|
889 |
async with self._driver.session(database=self._DATABASE) as session:
|
|
|
64 |
MAX_CONNECTION_POOL_SIZE = int(
|
65 |
os.environ.get(
|
66 |
"NEO4J_MAX_CONNECTION_POOL_SIZE",
|
67 |
+
config.get("neo4j", "connection_pool_size", fallback=50),
|
68 |
)
|
69 |
)
|
70 |
CONNECTION_TIMEOUT = float(
|
71 |
os.environ.get(
|
72 |
"NEO4J_CONNECTION_TIMEOUT",
|
73 |
+
config.get("neo4j", "connection_timeout", fallback=30.0),
|
74 |
),
|
75 |
)
|
76 |
CONNECTION_ACQUISITION_TIMEOUT = float(
|
77 |
os.environ.get(
|
78 |
"NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
|
79 |
+
config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),
|
80 |
),
|
81 |
)
|
82 |
MAX_TRANSACTION_RETRY_TIME = float(
|
|
|
188 |
|
189 |
async def has_node(self, node_id: str) -> bool:
|
190 |
entity_name_label = await self._ensure_label(node_id)
|
191 |
+
async with self._driver.session(
|
192 |
+
database=self._DATABASE, default_access_mode="READ"
|
193 |
+
) as session:
|
194 |
query = (
|
195 |
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
|
196 |
)
|
197 |
result = await session.run(query)
|
198 |
single_result = await result.single()
|
199 |
await result.consume() # Ensure result is fully consumed
|
|
|
|
|
|
|
200 |
return single_result["node_exists"]
|
201 |
|
202 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
203 |
entity_name_label_source = source_node_id.strip('"')
|
204 |
entity_name_label_target = target_node_id.strip('"')
|
205 |
|
206 |
+
async with self._driver.session(
|
207 |
+
database=self._DATABASE, default_access_mode="READ"
|
208 |
+
) as session:
|
209 |
query = (
|
210 |
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
|
211 |
"RETURN COUNT(r) > 0 AS edgeExists"
|
|
|
213 |
result = await session.run(query)
|
214 |
single_result = await result.single()
|
215 |
await result.consume() # Ensure result is fully consumed
|
|
|
|
|
|
|
216 |
return single_result["edgeExists"]
|
217 |
|
218 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
|
|
225 |
dict: Node properties if found
|
226 |
None: If node not found
|
227 |
"""
|
228 |
+
async with self._driver.session(
|
229 |
+
database=self._DATABASE, default_access_mode="READ"
|
230 |
+
) as session:
|
231 |
entity_name_label = await self._ensure_label(node_id)
|
232 |
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
|
233 |
result = await session.run(query)
|
234 |
+
records = await result.fetch(
|
235 |
+
2
|
236 |
+
) # Get up to 2 records to check for duplicates
|
237 |
await result.consume() # Ensure result is fully consumed
|
238 |
if len(records) > 1:
|
239 |
+
logger.warning(
|
240 |
+
f"Multiple nodes found with label '{entity_name_label}'. Using first node."
|
241 |
+
)
|
242 |
if records:
|
243 |
node = records[0]["n"]
|
244 |
node_dict = dict(node)
|
|
|
252 |
"""Get the degree (number of relationships) of a node with the given label.
|
253 |
If multiple nodes have the same label, returns the degree of the first node.
|
254 |
If no node is found, returns 0.
|
255 |
+
|
256 |
Args:
|
257 |
node_id: The label of the node
|
258 |
+
|
259 |
Returns:
|
260 |
int: The number of relationships the node has, or 0 if no node found
|
261 |
"""
|
262 |
entity_name_label = node_id.strip('"')
|
263 |
|
264 |
+
async with self._driver.session(
|
265 |
+
database=self._DATABASE, default_access_mode="READ"
|
266 |
+
) as session:
|
267 |
query = f"""
|
268 |
MATCH (n:`{entity_name_label}`)
|
269 |
OPTIONAL MATCH (n)-[r]-()
|
|
|
272 |
result = await session.run(query)
|
273 |
records = await result.fetch(100)
|
274 |
await result.consume() # Ensure result is fully consumed
|
275 |
+
|
276 |
if not records:
|
277 |
logger.warning(f"No node found with label '{entity_name_label}'")
|
278 |
return 0
|
279 |
+
|
280 |
if len(records) > 1:
|
281 |
+
logger.warning(
|
282 |
+
f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree"
|
283 |
+
)
|
284 |
+
|
285 |
degree = records[0]["degree"]
|
286 |
logger.debug(
|
287 |
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}"
|
|
|
304 |
)
|
305 |
return degrees
|
306 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
307 |
|
308 |
async def get_edge(
|
309 |
self, source_node_id: str, target_node_id: str
|
|
|
312 |
entity_name_label_source = source_node_id.strip('"')
|
313 |
entity_name_label_target = target_node_id.strip('"')
|
314 |
|
315 |
+
async with self._driver.session(
|
316 |
+
database=self._DATABASE, default_access_mode="READ"
|
317 |
+
) as session:
|
318 |
query = f"""
|
319 |
MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`)
|
320 |
RETURN properties(r) as edge_properties
|
321 |
"""
|
322 |
|
323 |
result = await session.run(query)
|
324 |
+
try:
|
325 |
+
records = await result.fetch(2) # Get up to 2 records to check for duplicates
|
326 |
+
if len(records) > 1:
|
327 |
+
logger.warning(
|
328 |
+
f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
329 |
)
|
330 |
+
if records:
|
331 |
+
try:
|
332 |
+
result = dict(records[0]["edge_properties"])
|
333 |
+
logger.debug(f"Result: {result}")
|
334 |
+
# Ensure required keys exist with defaults
|
335 |
+
required_keys = {
|
336 |
+
"weight": 0.0,
|
337 |
+
"source_id": None,
|
338 |
+
"description": None,
|
339 |
+
"keywords": None,
|
340 |
+
}
|
341 |
+
for key, default_value in required_keys.items():
|
342 |
+
if key not in result:
|
343 |
+
result[key] = default_value
|
344 |
+
logger.warning(
|
345 |
+
f"Edge between {entity_name_label_source} and {entity_name_label_target} "
|
346 |
+
f"missing {key}, using default: {default_value}"
|
347 |
+
)
|
348 |
|
349 |
+
logger.debug(
|
350 |
+
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
|
351 |
+
)
|
352 |
+
return result
|
353 |
+
except (KeyError, TypeError, ValueError) as e:
|
354 |
+
logger.error(
|
355 |
+
f"Error processing edge properties between {entity_name_label_source} "
|
356 |
+
f"and {entity_name_label_target}: {str(e)}"
|
357 |
+
)
|
358 |
+
# Return default edge properties on error
|
359 |
+
return {
|
360 |
+
"weight": 0.0,
|
361 |
+
"description": None,
|
362 |
+
"keywords": None,
|
363 |
+
"source_id": None,
|
364 |
+
}
|
365 |
+
|
366 |
+
logger.debug(
|
367 |
+
f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
|
368 |
+
)
|
369 |
+
# Return default edge properties when no edge found
|
370 |
+
return {
|
371 |
+
"weight": 0.0,
|
372 |
+
"description": None,
|
373 |
+
"keywords": None,
|
374 |
+
"source_id": None,
|
375 |
+
}
|
376 |
+
finally:
|
377 |
+
await result.consume() # Ensure result is fully consumed
|
378 |
|
379 |
except Exception as e:
|
380 |
logger.error(
|
|
|
398 |
query = f"""MATCH (n:`{node_label}`)
|
399 |
OPTIONAL MATCH (n)-[r]-(connected)
|
400 |
RETURN n, r, connected"""
|
401 |
+
async with self._driver.session(
|
402 |
+
database=self._DATABASE, default_access_mode="READ"
|
403 |
+
) as session:
|
404 |
results = await session.run(query)
|
405 |
edges = []
|
406 |
try:
|
|
|
420 |
if source_label and target_label:
|
421 |
edges.append((source_label, target_label))
|
422 |
finally:
|
423 |
+
await (
|
424 |
+
results.consume()
|
425 |
+
) # Ensure results are consumed even if processing fails
|
426 |
|
427 |
return edges
|
428 |
|
|
|
454 |
MERGE (n:`{label}`)
|
455 |
SET n += $properties
|
456 |
"""
|
457 |
+
result = await tx.run(query, properties=properties)
|
458 |
logger.debug(
|
459 |
f"Upserted node with label '{label}' and properties: {properties}"
|
460 |
)
|
461 |
+
await result.consume() # Ensure result is fully consumed
|
462 |
|
463 |
try:
|
464 |
async with self._driver.session(database=self._DATABASE) as session:
|
|
|
503 |
target_exists = await self.has_node(target_label)
|
504 |
|
505 |
if not source_exists:
|
506 |
+
raise ValueError(
|
507 |
+
f"Neo4j: source node with label '{source_label}' does not exist"
|
508 |
+
)
|
509 |
if not target_exists:
|
510 |
+
raise ValueError(
|
511 |
+
f"Neo4j: target node with label '{target_label}' does not exist"
|
512 |
+
)
|
513 |
|
514 |
async def _do_upsert_edge(tx: AsyncManagedTransaction):
|
515 |
query = f"""
|
|
|
568 |
seen_nodes = set()
|
569 |
seen_edges = set()
|
570 |
|
571 |
+
async with self._driver.session(
|
572 |
+
database=self._DATABASE, default_access_mode="READ"
|
573 |
+
) as session:
|
574 |
try:
|
575 |
if label == "*":
|
576 |
main_query = """
|
|
|
728 |
visited_nodes.add(node_id)
|
729 |
|
730 |
# Add node data with label as ID
|
731 |
+
result["nodes"].append(
|
732 |
+
{"id": current_label, "labels": current_label, "properties": node}
|
733 |
+
)
|
|
|
|
|
734 |
|
735 |
# Get connected nodes that meet the degree requirement
|
736 |
# Note: We don't need to check a's degree since it's the current node
|
|
|
742 |
WHERE b_degree >= $min_degree OR EXISTS((a)--(b))
|
743 |
RETURN r, b
|
744 |
"""
|
745 |
+
async with self._driver.session(
|
746 |
+
database=self._DATABASE, default_access_mode="READ"
|
747 |
+
) as session:
|
748 |
results = await session.run(query, {"min_degree": min_degree})
|
749 |
async for record in results:
|
750 |
# Handle edges
|
|
|
754 |
b_node = record["b"]
|
755 |
if b_node.labels: # Only process if target node has labels
|
756 |
target_label = list(b_node.labels)[0]
|
757 |
+
result["edges"].append(
|
758 |
+
{
|
759 |
+
"id": f"{current_label}_{target_label}",
|
760 |
+
"type": rel.type,
|
761 |
+
"source": current_label,
|
762 |
+
"target": target_label,
|
763 |
+
"properties": dict(rel),
|
764 |
+
}
|
765 |
+
)
|
766 |
visited_edges.add(edge_id)
|
767 |
|
768 |
# Continue traversal
|
769 |
await traverse(target_label, current_depth + 1)
|
770 |
else:
|
771 |
+
logger.warning(
|
772 |
+
f"Skipping edge {edge_id} due to missing labels on target node"
|
773 |
+
)
|
774 |
|
775 |
await traverse(label, 0)
|
776 |
return result
|
|
|
781 |
Returns:
|
782 |
["Person", "Company", ...] # Alphabetically sorted label list
|
783 |
"""
|
784 |
+
async with self._driver.session(
|
785 |
+
database=self._DATABASE, default_access_mode="READ"
|
786 |
+
) as session:
|
787 |
# Method 1: Direct metadata query (Available for Neo4j 4.3+)
|
788 |
# query = "CALL db.labels() YIELD label RETURN label"
|
789 |
|
|
|
802 |
async for record in result:
|
803 |
labels.append(record["label"])
|
804 |
finally:
|
805 |
+
await (
|
806 |
+
result.consume()
|
807 |
+
) # Ensure results are consumed even if processing fails
|
808 |
return labels
|
809 |
|
810 |
@retry(
|
|
|
832 |
MATCH (n:`{label}`)
|
833 |
DETACH DELETE n
|
834 |
"""
|
835 |
+
result = await tx.run(query)
|
836 |
logger.debug(f"Deleted node with label '{label}'")
|
837 |
+
await result.consume() # Ensure result is fully consumed
|
838 |
|
839 |
try:
|
840 |
async with self._driver.session(database=self._DATABASE) as session:
|
|
|
891 |
MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`)
|
892 |
DELETE r
|
893 |
"""
|
894 |
+
result = await tx.run(query)
|
895 |
logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'")
|
896 |
+
await result.consume() # Ensure result is fully consumed
|
897 |
|
898 |
try:
|
899 |
async with self._driver.session(database=self._DATABASE) as session:
|