yangdx
committed on
Commit
·
706f457
1
Parent(s):
1fbf326
Refactor Neo4JStorage to use entity_id for node identification, use entity_type for node label
Browse files- lightrag/kg/neo4j_impl.py +95 -192
lightrag/kg/neo4j_impl.py
CHANGED
|
@@ -176,23 +176,6 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 176 |
# Neo4J handles persistence automatically
|
| 177 |
pass
|
| 178 |
|
| 179 |
-
def _ensure_label(self, label: str) -> str:
|
| 180 |
-
"""Ensure a label is valid
|
| 181 |
-
|
| 182 |
-
Args:
|
| 183 |
-
label: The label to validate
|
| 184 |
-
|
| 185 |
-
Returns:
|
| 186 |
-
str: The cleaned label
|
| 187 |
-
|
| 188 |
-
Raises:
|
| 189 |
-
ValueError: If label is empty after cleaning
|
| 190 |
-
"""
|
| 191 |
-
clean_label = label.strip('"')
|
| 192 |
-
if not clean_label:
|
| 193 |
-
raise ValueError("Neo4j: Label cannot be empty")
|
| 194 |
-
return clean_label
|
| 195 |
-
|
| 196 |
async def has_node(self, node_id: str) -> bool:
|
| 197 |
"""
|
| 198 |
Check if a node with the given label exists in the database
|
|
@@ -207,19 +190,18 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 207 |
ValueError: If node_id is invalid
|
| 208 |
Exception: If there is an error executing the query
|
| 209 |
"""
|
| 210 |
-
entity_name_label = self._ensure_label(node_id)
|
| 211 |
async with self._driver.session(
|
| 212 |
database=self._DATABASE, default_access_mode="READ"
|
| 213 |
) as session:
|
| 214 |
try:
|
| 215 |
-
query =
|
| 216 |
-
result = await session.run(query)
|
| 217 |
single_result = await result.single()
|
| 218 |
await result.consume() # Ensure result is fully consumed
|
| 219 |
return single_result["node_exists"]
|
| 220 |
except Exception as e:
|
| 221 |
logger.error(
|
| 222 |
-
f"Error checking node existence for {
|
| 223 |
)
|
| 224 |
await result.consume() # Ensure results are consumed even on error
|
| 225 |
raise
|
|
@@ -239,24 +221,21 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 239 |
ValueError: If either node_id is invalid
|
| 240 |
Exception: If there is an error executing the query
|
| 241 |
"""
|
| 242 |
-
entity_name_label_source = self._ensure_label(source_node_id)
|
| 243 |
-
entity_name_label_target = self._ensure_label(target_node_id)
|
| 244 |
-
|
| 245 |
async with self._driver.session(
|
| 246 |
database=self._DATABASE, default_access_mode="READ"
|
| 247 |
) as session:
|
| 248 |
try:
|
| 249 |
query = (
|
| 250 |
-
|
| 251 |
"RETURN COUNT(r) > 0 AS edgeExists"
|
| 252 |
)
|
| 253 |
-
result = await session.run(query)
|
| 254 |
single_result = await result.single()
|
| 255 |
await result.consume() # Ensure result is fully consumed
|
| 256 |
return single_result["edgeExists"]
|
| 257 |
except Exception as e:
|
| 258 |
logger.error(
|
| 259 |
-
f"Error checking edge existence between {
|
| 260 |
)
|
| 261 |
await result.consume() # Ensure results are consumed even on error
|
| 262 |
raise
|
|
@@ -275,13 +254,12 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 275 |
ValueError: If node_id is invalid
|
| 276 |
Exception: If there is an error executing the query
|
| 277 |
"""
|
| 278 |
-
entity_name_label = self._ensure_label(node_id)
|
| 279 |
async with self._driver.session(
|
| 280 |
database=self._DATABASE, default_access_mode="READ"
|
| 281 |
) as session:
|
| 282 |
try:
|
| 283 |
-
query =
|
| 284 |
-
result = await session.run(query, entity_id=
|
| 285 |
try:
|
| 286 |
records = await result.fetch(
|
| 287 |
2
|
|
@@ -289,20 +267,21 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 289 |
|
| 290 |
if len(records) > 1:
|
| 291 |
logger.warning(
|
| 292 |
-
f"Multiple nodes found with label '{
|
| 293 |
)
|
| 294 |
if records:
|
| 295 |
node = records[0]["n"]
|
| 296 |
node_dict = dict(node)
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
|
|
|
| 300 |
return node_dict
|
| 301 |
return None
|
| 302 |
finally:
|
| 303 |
await result.consume() # Ensure result is fully consumed
|
| 304 |
except Exception as e:
|
| 305 |
-
logger.error(f"Error getting node for {
|
| 306 |
raise
|
| 307 |
|
| 308 |
async def node_degree(self, node_id: str) -> int:
|
|
@@ -320,42 +299,33 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 320 |
ValueError: If node_id is invalid
|
| 321 |
Exception: If there is an error executing the query
|
| 322 |
"""
|
| 323 |
-
entity_name_label = self._ensure_label(node_id)
|
| 324 |
-
|
| 325 |
async with self._driver.session(
|
| 326 |
database=self._DATABASE, default_access_mode="READ"
|
| 327 |
) as session:
|
| 328 |
try:
|
| 329 |
-
query =
|
| 330 |
-
MATCH (n
|
| 331 |
OPTIONAL MATCH (n)-[r]-()
|
| 332 |
-
RETURN
|
| 333 |
"""
|
| 334 |
-
result = await session.run(query)
|
| 335 |
try:
|
| 336 |
-
|
| 337 |
|
| 338 |
-
if not
|
| 339 |
logger.warning(
|
| 340 |
-
f"No node found with label '{
|
| 341 |
)
|
| 342 |
return 0
|
| 343 |
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree"
|
| 347 |
-
)
|
| 348 |
-
|
| 349 |
-
degree = records[0]["degree"]
|
| 350 |
-
logger.debug(
|
| 351 |
-
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}"
|
| 352 |
-
)
|
| 353 |
return degree
|
| 354 |
finally:
|
| 355 |
await result.consume() # Ensure result is fully consumed
|
| 356 |
except Exception as e:
|
| 357 |
logger.error(
|
| 358 |
-
f"Error getting node degree for {
|
| 359 |
)
|
| 360 |
raise
|
| 361 |
|
|
@@ -369,11 +339,8 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 369 |
Returns:
|
| 370 |
int: Sum of the degrees of both nodes
|
| 371 |
"""
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
src_degree = await self.node_degree(entity_name_label_source)
|
| 376 |
-
trg_degree = await self.node_degree(entity_name_label_target)
|
| 377 |
|
| 378 |
# Convert None to 0 for addition
|
| 379 |
src_degree = 0 if src_degree is None else src_degree
|
|
@@ -399,24 +366,20 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 399 |
Exception: If there is an error executing the query
|
| 400 |
"""
|
| 401 |
try:
|
| 402 |
-
entity_name_label_source = self._ensure_label(source_node_id)
|
| 403 |
-
entity_name_label_target = self._ensure_label(target_node_id)
|
| 404 |
-
|
| 405 |
async with self._driver.session(
|
| 406 |
database=self._DATABASE, default_access_mode="READ"
|
| 407 |
) as session:
|
| 408 |
-
query =
|
| 409 |
-
MATCH (start
|
| 410 |
RETURN properties(r) as edge_properties
|
| 411 |
"""
|
| 412 |
-
|
| 413 |
-
result = await session.run(query)
|
| 414 |
try:
|
| 415 |
records = await result.fetch(2)
|
| 416 |
|
| 417 |
if len(records) > 1:
|
| 418 |
logger.warning(
|
| 419 |
-
f"Multiple edges found between '{
|
| 420 |
)
|
| 421 |
if records:
|
| 422 |
try:
|
|
@@ -433,7 +396,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 433 |
if key not in edge_result:
|
| 434 |
edge_result[key] = default_value
|
| 435 |
logger.warning(
|
| 436 |
-
f"Edge between {
|
| 437 |
f"missing {key}, using default: {default_value}"
|
| 438 |
)
|
| 439 |
|
|
@@ -443,8 +406,8 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 443 |
return edge_result
|
| 444 |
except (KeyError, TypeError, ValueError) as e:
|
| 445 |
logger.error(
|
| 446 |
-
f"Error processing edge properties between {
|
| 447 |
-
f"and {
|
| 448 |
)
|
| 449 |
# Return default edge properties on error
|
| 450 |
return {
|
|
@@ -455,7 +418,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 455 |
}
|
| 456 |
|
| 457 |
logger.debug(
|
| 458 |
-
f"{inspect.currentframe().f_code.co_name}: No edge found between {
|
| 459 |
)
|
| 460 |
# Return default edge properties when no edge found
|
| 461 |
return {
|
|
@@ -488,30 +451,30 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 488 |
Exception: If there is an error executing the query
|
| 489 |
"""
|
| 490 |
try:
|
| 491 |
-
node_label = self._ensure_label(source_node_id)
|
| 492 |
-
|
| 493 |
-
query = f"""MATCH (n:`{node_label}`)
|
| 494 |
-
OPTIONAL MATCH (n)-[r]-(connected)
|
| 495 |
-
RETURN n, r, connected"""
|
| 496 |
-
|
| 497 |
async with self._driver.session(
|
| 498 |
database=self._DATABASE, default_access_mode="READ"
|
| 499 |
) as session:
|
| 500 |
try:
|
| 501 |
-
|
| 502 |
-
|
|
|
|
|
|
|
|
|
|
| 503 |
|
|
|
|
| 504 |
async for record in results:
|
| 505 |
source_node = record["n"]
|
| 506 |
connected_node = record["connected"]
|
| 507 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
source_label = (
|
| 509 |
-
|
| 510 |
)
|
| 511 |
target_label = (
|
| 512 |
-
|
| 513 |
-
if connected_node and connected_node.labels
|
| 514 |
-
else None
|
| 515 |
)
|
| 516 |
|
| 517 |
if source_label and target_label:
|
|
@@ -520,7 +483,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 520 |
await results.consume() # Ensure results are consumed
|
| 521 |
return edges
|
| 522 |
except Exception as e:
|
| 523 |
-
logger.error(f"Error getting edges for node {
|
| 524 |
await results.consume() # Ensure results are consumed even on error
|
| 525 |
raise
|
| 526 |
except Exception as e:
|
|
@@ -547,8 +510,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 547 |
node_id: The unique identifier for the node (used as label)
|
| 548 |
node_data: Dictionary of node properties
|
| 549 |
"""
|
| 550 |
-
label = self._ensure_label(node_id)
|
| 551 |
properties = node_data
|
|
|
|
|
|
|
| 552 |
if "entity_id" not in properties:
|
| 553 |
raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
|
| 554 |
|
|
@@ -556,13 +520,14 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 556 |
async with self._driver.session(database=self._DATABASE) as session:
|
| 557 |
|
| 558 |
async def execute_upsert(tx: AsyncManagedTransaction):
|
| 559 |
-
query =
|
| 560 |
-
MERGE (n
|
| 561 |
SET n += $properties
|
| 562 |
-
|
|
|
|
| 563 |
result = await tx.run(query, properties=properties)
|
| 564 |
logger.debug(
|
| 565 |
-
f"Upserted node with
|
| 566 |
)
|
| 567 |
await result.consume() # Ensure result is fully consumed
|
| 568 |
|
|
@@ -583,51 +548,6 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 583 |
)
|
| 584 |
),
|
| 585 |
)
|
| 586 |
-
async def _get_unique_node_entity_id(self, node_label: str) -> str:
|
| 587 |
-
"""
|
| 588 |
-
Get the entity_id of a node with the given label, ensuring the node is unique.
|
| 589 |
-
|
| 590 |
-
Args:
|
| 591 |
-
node_label (str): Label of the node to check
|
| 592 |
-
|
| 593 |
-
Returns:
|
| 594 |
-
str: The entity_id of the unique node
|
| 595 |
-
|
| 596 |
-
Raises:
|
| 597 |
-
ValueError: If no node with the given label exists or if multiple nodes have the same label
|
| 598 |
-
"""
|
| 599 |
-
async with self._driver.session(
|
| 600 |
-
database=self._DATABASE, default_access_mode="READ"
|
| 601 |
-
) as session:
|
| 602 |
-
query = f"""
|
| 603 |
-
MATCH (n:`{node_label}`)
|
| 604 |
-
RETURN n, count(n) as node_count
|
| 605 |
-
"""
|
| 606 |
-
result = await session.run(query)
|
| 607 |
-
try:
|
| 608 |
-
records = await result.fetch(
|
| 609 |
-
2
|
| 610 |
-
) # We only need to know if there are 0, 1, or >1 nodes
|
| 611 |
-
|
| 612 |
-
if not records or records[0]["node_count"] == 0:
|
| 613 |
-
raise ValueError(
|
| 614 |
-
f"Neo4j: node with label '{node_label}' does not exist"
|
| 615 |
-
)
|
| 616 |
-
|
| 617 |
-
if records[0]["node_count"] > 1:
|
| 618 |
-
raise ValueError(
|
| 619 |
-
f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node"
|
| 620 |
-
)
|
| 621 |
-
|
| 622 |
-
node = records[0]["n"]
|
| 623 |
-
if "entity_id" not in node:
|
| 624 |
-
raise ValueError(
|
| 625 |
-
f"Neo4j: node with label '{node_label}' does not have an entity_id property"
|
| 626 |
-
)
|
| 627 |
-
|
| 628 |
-
return node["entity_id"]
|
| 629 |
-
finally:
|
| 630 |
-
await result.consume() # Ensure result is fully consumed
|
| 631 |
|
| 632 |
@retry(
|
| 633 |
stop=stop_after_attempt(3),
|
|
@@ -657,38 +577,30 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 657 |
Raises:
|
| 658 |
ValueError: If either source or target node does not exist or is not unique
|
| 659 |
"""
|
| 660 |
-
source_label = self._ensure_label(source_node_id)
|
| 661 |
-
target_label = self._ensure_label(target_node_id)
|
| 662 |
-
edge_properties = edge_data
|
| 663 |
-
|
| 664 |
-
# Get entity_ids for source and target nodes, ensuring they are unique
|
| 665 |
-
source_entity_id = await self._get_unique_node_entity_id(source_label)
|
| 666 |
-
target_entity_id = await self._get_unique_node_entity_id(target_label)
|
| 667 |
-
|
| 668 |
try:
|
|
|
|
| 669 |
async with self._driver.session(database=self._DATABASE) as session:
|
| 670 |
|
| 671 |
async def execute_upsert(tx: AsyncManagedTransaction):
|
| 672 |
-
query =
|
| 673 |
-
MATCH (source
|
| 674 |
WITH source
|
| 675 |
-
MATCH (target
|
| 676 |
MERGE (source)-[r:DIRECTED]-(target)
|
| 677 |
SET r += $properties
|
| 678 |
RETURN r, source, target
|
| 679 |
"""
|
| 680 |
result = await tx.run(
|
| 681 |
query,
|
| 682 |
-
source_entity_id=
|
| 683 |
-
target_entity_id=
|
| 684 |
properties=edge_properties,
|
| 685 |
)
|
| 686 |
try:
|
| 687 |
-
records = await result.fetch(
|
| 688 |
if records:
|
| 689 |
logger.debug(
|
| 690 |
-
f"Upserted edge from '{
|
| 691 |
-
f"to '{target_label}' (entity_id: {target_entity_id}) "
|
| 692 |
f"with properties: {edge_properties}"
|
| 693 |
)
|
| 694 |
finally:
|
|
@@ -726,7 +638,6 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 726 |
Returns:
|
| 727 |
KnowledgeGraph: Complete connected subgraph for specified node
|
| 728 |
"""
|
| 729 |
-
label = node_label.strip('"')
|
| 730 |
result = KnowledgeGraph()
|
| 731 |
seen_nodes = set()
|
| 732 |
seen_edges = set()
|
|
@@ -735,7 +646,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 735 |
database=self._DATABASE, default_access_mode="READ"
|
| 736 |
) as session:
|
| 737 |
try:
|
| 738 |
-
if
|
| 739 |
main_query = """
|
| 740 |
MATCH (n)
|
| 741 |
OPTIONAL MATCH (n)-[r]-()
|
|
@@ -760,12 +671,11 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 760 |
# Main query uses partial matching
|
| 761 |
main_query = """
|
| 762 |
MATCH (start)
|
| 763 |
-
WHERE
|
| 764 |
CASE
|
| 765 |
-
WHEN $inclusive THEN
|
| 766 |
-
ELSE
|
| 767 |
END
|
| 768 |
-
)
|
| 769 |
WITH start
|
| 770 |
CALL apoc.path.subgraphAll(start, {
|
| 771 |
relationshipFilter: '',
|
|
@@ -799,7 +709,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 799 |
main_query,
|
| 800 |
{
|
| 801 |
"max_nodes": MAX_GRAPH_NODES,
|
| 802 |
-
"
|
| 803 |
"inclusive": inclusive,
|
| 804 |
"max_depth": max_depth,
|
| 805 |
"min_degree": min_degree,
|
|
@@ -818,7 +728,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 818 |
result.nodes.append(
|
| 819 |
KnowledgeGraphNode(
|
| 820 |
id=f"{node_id}",
|
| 821 |
-
labels=
|
| 822 |
properties=dict(node),
|
| 823 |
)
|
| 824 |
)
|
|
@@ -849,7 +759,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 849 |
|
| 850 |
except neo4jExceptions.ClientError as e:
|
| 851 |
logger.warning(f"APOC plugin error: {str(e)}")
|
| 852 |
-
if
|
| 853 |
logger.warning(
|
| 854 |
"Neo4j: falling back to basic Cypher recursive search..."
|
| 855 |
)
|
|
@@ -857,12 +767,12 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 857 |
logger.warning(
|
| 858 |
"Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
|
| 859 |
)
|
| 860 |
-
return await self._robust_fallback(
|
| 861 |
|
| 862 |
return result
|
| 863 |
|
| 864 |
async def _robust_fallback(
|
| 865 |
-
self,
|
| 866 |
) -> KnowledgeGraph:
|
| 867 |
"""
|
| 868 |
Fallback implementation when APOC plugin is not available or incompatible.
|
|
@@ -895,12 +805,11 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 895 |
database=self._DATABASE, default_access_mode="READ"
|
| 896 |
) as session:
|
| 897 |
query = """
|
| 898 |
-
MATCH (a)-[r]-(b)
|
| 899 |
-
WHERE id(a) = toInteger($node_id)
|
| 900 |
WITH r, b, id(r) as edge_id, id(b) as target_id
|
| 901 |
RETURN r, b, edge_id, target_id
|
| 902 |
"""
|
| 903 |
-
results = await session.run(query,
|
| 904 |
|
| 905 |
# Get all records and release database connection
|
| 906 |
records = await results.fetch(
|
|
@@ -928,14 +837,14 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 928 |
edge_id = str(record["edge_id"])
|
| 929 |
if edge_id not in visited_edges:
|
| 930 |
b_node = record["b"]
|
| 931 |
-
target_id =
|
| 932 |
|
| 933 |
-
if
|
| 934 |
# Create KnowledgeGraphNode for target
|
| 935 |
target_node = KnowledgeGraphNode(
|
| 936 |
id=f"{target_id}",
|
| 937 |
-
labels=
|
| 938 |
-
properties=dict(b_node),
|
| 939 |
)
|
| 940 |
|
| 941 |
# Create KnowledgeGraphEdge
|
|
@@ -961,11 +870,11 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 961 |
async with self._driver.session(
|
| 962 |
database=self._DATABASE, default_access_mode="READ"
|
| 963 |
) as session:
|
| 964 |
-
query =
|
| 965 |
-
MATCH (n
|
| 966 |
RETURN id(n) as node_id, n
|
| 967 |
"""
|
| 968 |
-
node_result = await session.run(query)
|
| 969 |
try:
|
| 970 |
node_record = await node_result.single()
|
| 971 |
if not node_record:
|
|
@@ -973,9 +882,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 973 |
|
| 974 |
# Create initial KnowledgeGraphNode
|
| 975 |
start_node = KnowledgeGraphNode(
|
| 976 |
-
id=f"{node_record['
|
| 977 |
-
labels=
|
| 978 |
-
properties=dict(node_record["n"]),
|
| 979 |
)
|
| 980 |
finally:
|
| 981 |
await node_result.consume() # Ensure results are consumed
|
|
@@ -999,11 +908,10 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 999 |
|
| 1000 |
# Method 2: Query compatible with older versions
|
| 1001 |
query = """
|
| 1002 |
-
|
| 1003 |
-
|
| 1004 |
-
|
| 1005 |
-
|
| 1006 |
-
ORDER BY label
|
| 1007 |
"""
|
| 1008 |
result = await session.run(query)
|
| 1009 |
labels = []
|
|
@@ -1034,15 +942,13 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 1034 |
Args:
|
| 1035 |
node_id: The label of the node to delete
|
| 1036 |
"""
|
| 1037 |
-
label = self._ensure_label(node_id)
|
| 1038 |
-
|
| 1039 |
async def _do_delete(tx: AsyncManagedTransaction):
|
| 1040 |
-
query =
|
| 1041 |
-
MATCH (n
|
| 1042 |
DETACH DELETE n
|
| 1043 |
"""
|
| 1044 |
-
result = await tx.run(query)
|
| 1045 |
-
logger.debug(f"Deleted node with label '{
|
| 1046 |
await result.consume() # Ensure result is fully consumed
|
| 1047 |
|
| 1048 |
try:
|
|
@@ -1092,16 +998,13 @@ class Neo4JStorage(BaseGraphStorage):
|
|
| 1092 |
edges: List of edges to be deleted, each edge is a (source, target) tuple
|
| 1093 |
"""
|
| 1094 |
for source, target in edges:
|
| 1095 |
-
source_label = self._ensure_label(source)
|
| 1096 |
-
target_label = self._ensure_label(target)
|
| 1097 |
-
|
| 1098 |
async def _do_delete_edge(tx: AsyncManagedTransaction):
|
| 1099 |
-
query =
|
| 1100 |
-
MATCH (source
|
| 1101 |
DELETE r
|
| 1102 |
"""
|
| 1103 |
-
result = await tx.run(query)
|
| 1104 |
-
logger.debug(f"Deleted edge from '{
|
| 1105 |
await result.consume() # Ensure result is fully consumed
|
| 1106 |
|
| 1107 |
try:
|
|
|
|
| 176 |
# Neo4J handles persistence automatically
|
| 177 |
pass
|
| 178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
async def has_node(self, node_id: str) -> bool:
|
| 180 |
"""
|
| 181 |
Check if a node with the given label exists in the database
|
|
|
|
| 190 |
ValueError: If node_id is invalid
|
| 191 |
Exception: If there is an error executing the query
|
| 192 |
"""
|
|
|
|
| 193 |
async with self._driver.session(
|
| 194 |
database=self._DATABASE, default_access_mode="READ"
|
| 195 |
) as session:
|
| 196 |
try:
|
| 197 |
+
query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
|
| 198 |
+
result = await session.run(query, entity_id = node_id)
|
| 199 |
single_result = await result.single()
|
| 200 |
await result.consume() # Ensure result is fully consumed
|
| 201 |
return single_result["node_exists"]
|
| 202 |
except Exception as e:
|
| 203 |
logger.error(
|
| 204 |
+
f"Error checking node existence for {node_id}: {str(e)}"
|
| 205 |
)
|
| 206 |
await result.consume() # Ensure results are consumed even on error
|
| 207 |
raise
|
|
|
|
| 221 |
ValueError: If either node_id is invalid
|
| 222 |
Exception: If there is an error executing the query
|
| 223 |
"""
|
|
|
|
|
|
|
|
|
|
| 224 |
async with self._driver.session(
|
| 225 |
database=self._DATABASE, default_access_mode="READ"
|
| 226 |
) as session:
|
| 227 |
try:
|
| 228 |
query = (
|
| 229 |
+
"MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
|
| 230 |
"RETURN COUNT(r) > 0 AS edgeExists"
|
| 231 |
)
|
| 232 |
+
result = await session.run(query, source_entity_id = source_node_id, target_entity_id = target_node_id)
|
| 233 |
single_result = await result.single()
|
| 234 |
await result.consume() # Ensure result is fully consumed
|
| 235 |
return single_result["edgeExists"]
|
| 236 |
except Exception as e:
|
| 237 |
logger.error(
|
| 238 |
+
f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
|
| 239 |
)
|
| 240 |
await result.consume() # Ensure results are consumed even on error
|
| 241 |
raise
|
|
|
|
| 254 |
ValueError: If node_id is invalid
|
| 255 |
Exception: If there is an error executing the query
|
| 256 |
"""
|
|
|
|
| 257 |
async with self._driver.session(
|
| 258 |
database=self._DATABASE, default_access_mode="READ"
|
| 259 |
) as session:
|
| 260 |
try:
|
| 261 |
+
query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
|
| 262 |
+
result = await session.run(query, entity_id=node_id)
|
| 263 |
try:
|
| 264 |
records = await result.fetch(
|
| 265 |
2
|
|
|
|
| 267 |
|
| 268 |
if len(records) > 1:
|
| 269 |
logger.warning(
|
| 270 |
+
f"Multiple nodes found with label '{node_id}'. Using first node."
|
| 271 |
)
|
| 272 |
if records:
|
| 273 |
node = records[0]["n"]
|
| 274 |
node_dict = dict(node)
|
| 275 |
+
# Remove base label from labels list if it exists
|
| 276 |
+
if "labels" in node_dict:
|
| 277 |
+
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
|
| 278 |
+
logger.debug(f"Neo4j query node {query} return: {node_dict}")
|
| 279 |
return node_dict
|
| 280 |
return None
|
| 281 |
finally:
|
| 282 |
await result.consume() # Ensure result is fully consumed
|
| 283 |
except Exception as e:
|
| 284 |
+
logger.error(f"Error getting node for {node_id}: {str(e)}")
|
| 285 |
raise
|
| 286 |
|
| 287 |
async def node_degree(self, node_id: str) -> int:
|
|
|
|
| 299 |
ValueError: If node_id is invalid
|
| 300 |
Exception: If there is an error executing the query
|
| 301 |
"""
|
|
|
|
|
|
|
| 302 |
async with self._driver.session(
|
| 303 |
database=self._DATABASE, default_access_mode="READ"
|
| 304 |
) as session:
|
| 305 |
try:
|
| 306 |
+
query = """
|
| 307 |
+
MATCH (n:base {entity_id: $entity_id})
|
| 308 |
OPTIONAL MATCH (n)-[r]-()
|
| 309 |
+
RETURN COUNT(r) AS degree
|
| 310 |
"""
|
| 311 |
+
result = await session.run(query, entity_id = node_id)
|
| 312 |
try:
|
| 313 |
+
record = await result.single()
|
| 314 |
|
| 315 |
+
if not record:
|
| 316 |
logger.warning(
|
| 317 |
+
f"No node found with label '{node_id}'"
|
| 318 |
)
|
| 319 |
return 0
|
| 320 |
|
| 321 |
+
degree = record["degree"]
|
| 322 |
+
logger.debug("Neo4j query node degree for {node_id} return: {degree}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
return degree
|
| 324 |
finally:
|
| 325 |
await result.consume() # Ensure result is fully consumed
|
| 326 |
except Exception as e:
|
| 327 |
logger.error(
|
| 328 |
+
f"Error getting node degree for {node_id}: {str(e)}"
|
| 329 |
)
|
| 330 |
raise
|
| 331 |
|
|
|
|
| 339 |
Returns:
|
| 340 |
int: Sum of the degrees of both nodes
|
| 341 |
"""
|
| 342 |
+
src_degree = await self.node_degree(src_id)
|
| 343 |
+
trg_degree = await self.node_degree(tgt_id)
|
|
|
|
|
|
|
|
|
|
| 344 |
|
| 345 |
# Convert None to 0 for addition
|
| 346 |
src_degree = 0 if src_degree is None else src_degree
|
|
|
|
| 366 |
Exception: If there is an error executing the query
|
| 367 |
"""
|
| 368 |
try:
|
|
|
|
|
|
|
|
|
|
| 369 |
async with self._driver.session(
|
| 370 |
database=self._DATABASE, default_access_mode="READ"
|
| 371 |
) as session:
|
| 372 |
+
query = """
|
| 373 |
+
MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
|
| 374 |
RETURN properties(r) as edge_properties
|
| 375 |
"""
|
| 376 |
+
result = await session.run(query, source_entity_id=source_node_id, target_entity_id=target_node_id)
|
|
|
|
| 377 |
try:
|
| 378 |
records = await result.fetch(2)
|
| 379 |
|
| 380 |
if len(records) > 1:
|
| 381 |
logger.warning(
|
| 382 |
+
f"Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge."
|
| 383 |
)
|
| 384 |
if records:
|
| 385 |
try:
|
|
|
|
| 396 |
if key not in edge_result:
|
| 397 |
edge_result[key] = default_value
|
| 398 |
logger.warning(
|
| 399 |
+
f"Edge between {source_node_id} and {target_node_id} "
|
| 400 |
f"missing {key}, using default: {default_value}"
|
| 401 |
)
|
| 402 |
|
|
|
|
| 406 |
return edge_result
|
| 407 |
except (KeyError, TypeError, ValueError) as e:
|
| 408 |
logger.error(
|
| 409 |
+
f"Error processing edge properties between {source_node_id} "
|
| 410 |
+
f"and {target_node_id}: {str(e)}"
|
| 411 |
)
|
| 412 |
# Return default edge properties on error
|
| 413 |
return {
|
|
|
|
| 418 |
}
|
| 419 |
|
| 420 |
logger.debug(
|
| 421 |
+
f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
|
| 422 |
)
|
| 423 |
# Return default edge properties when no edge found
|
| 424 |
return {
|
|
|
|
| 451 |
Exception: If there is an error executing the query
|
| 452 |
"""
|
| 453 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
async with self._driver.session(
|
| 455 |
database=self._DATABASE, default_access_mode="READ"
|
| 456 |
) as session:
|
| 457 |
try:
|
| 458 |
+
query = """MATCH (n:base {entity_id: $entity_id})
|
| 459 |
+
OPTIONAL MATCH (n)-[r]-(connected:base)
|
| 460 |
+
WHERE connected.entity_id IS NOT NULL
|
| 461 |
+
RETURN n, r, connected"""
|
| 462 |
+
results = await session.run(query, entity_id=source_node_id)
|
| 463 |
|
| 464 |
+
edges = []
|
| 465 |
async for record in results:
|
| 466 |
source_node = record["n"]
|
| 467 |
connected_node = record["connected"]
|
| 468 |
|
| 469 |
+
# Skip if either node is None
|
| 470 |
+
if not source_node or not connected_node:
|
| 471 |
+
continue
|
| 472 |
+
|
| 473 |
source_label = (
|
| 474 |
+
source_node.get("entity_id") if source_node.get("entity_id") else None
|
| 475 |
)
|
| 476 |
target_label = (
|
| 477 |
+
connected_node.get("entity_id") if connected_node.get("entity_id") else None
|
|
|
|
|
|
|
| 478 |
)
|
| 479 |
|
| 480 |
if source_label and target_label:
|
|
|
|
| 483 |
await results.consume() # Ensure results are consumed
|
| 484 |
return edges
|
| 485 |
except Exception as e:
|
| 486 |
+
logger.error(f"Error getting edges for node {source_node_id}: {str(e)}")
|
| 487 |
await results.consume() # Ensure results are consumed even on error
|
| 488 |
raise
|
| 489 |
except Exception as e:
|
|
|
|
| 510 |
node_id: The unique identifier for the node (used as label)
|
| 511 |
node_data: Dictionary of node properties
|
| 512 |
"""
|
|
|
|
| 513 |
properties = node_data
|
| 514 |
+
entity_type = properties["entity_type"]
|
| 515 |
+
entity_id = properties["entity_id"]
|
| 516 |
if "entity_id" not in properties:
|
| 517 |
raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
|
| 518 |
|
|
|
|
| 520 |
async with self._driver.session(database=self._DATABASE) as session:
|
| 521 |
|
| 522 |
async def execute_upsert(tx: AsyncManagedTransaction):
|
| 523 |
+
query = """
|
| 524 |
+
MERGE (n:base {entity_id: $properties.entity_id})
|
| 525 |
SET n += $properties
|
| 526 |
+
SET n:`%s`
|
| 527 |
+
""" % entity_type
|
| 528 |
result = await tx.run(query, properties=properties)
|
| 529 |
logger.debug(
|
| 530 |
+
f"Upserted node with entity_id '{entity_id}' and properties: {properties}"
|
| 531 |
)
|
| 532 |
await result.consume() # Ensure result is fully consumed
|
| 533 |
|
|
|
|
| 548 |
)
|
| 549 |
),
|
| 550 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
|
| 552 |
@retry(
|
| 553 |
stop=stop_after_attempt(3),
|
|
|
|
| 577 |
Raises:
|
| 578 |
ValueError: If either source or target node does not exist or is not unique
|
| 579 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 580 |
try:
|
| 581 |
+
edge_properties = edge_data
|
| 582 |
async with self._driver.session(database=self._DATABASE) as session:
|
| 583 |
|
| 584 |
async def execute_upsert(tx: AsyncManagedTransaction):
|
| 585 |
+
query = """
|
| 586 |
+
MATCH (source:base {entity_id: $source_entity_id})
|
| 587 |
WITH source
|
| 588 |
+
MATCH (target:base {entity_id: $target_entity_id})
|
| 589 |
MERGE (source)-[r:DIRECTED]-(target)
|
| 590 |
SET r += $properties
|
| 591 |
RETURN r, source, target
|
| 592 |
"""
|
| 593 |
result = await tx.run(
|
| 594 |
query,
|
| 595 |
+
source_entity_id=source_node_id,
|
| 596 |
+
target_entity_id=target_node_id,
|
| 597 |
properties=edge_properties,
|
| 598 |
)
|
| 599 |
try:
|
| 600 |
+
records = await result.fetch(2)
|
| 601 |
if records:
|
| 602 |
logger.debug(
|
| 603 |
+
f"Upserted edge from '{source_node_id}' to '{target_node_id}'"
|
|
|
|
| 604 |
f"with properties: {edge_properties}"
|
| 605 |
)
|
| 606 |
finally:
|
|
|
|
| 638 |
Returns:
|
| 639 |
KnowledgeGraph: Complete connected subgraph for specified node
|
| 640 |
"""
|
|
|
|
| 641 |
result = KnowledgeGraph()
|
| 642 |
seen_nodes = set()
|
| 643 |
seen_edges = set()
|
|
|
|
| 646 |
database=self._DATABASE, default_access_mode="READ"
|
| 647 |
) as session:
|
| 648 |
try:
|
| 649 |
+
if node_label == "*":
|
| 650 |
main_query = """
|
| 651 |
MATCH (n)
|
| 652 |
OPTIONAL MATCH (n)-[r]-()
|
|
|
|
| 671 |
# Main query uses partial matching
|
| 672 |
main_query = """
|
| 673 |
MATCH (start)
|
| 674 |
+
WHERE
|
| 675 |
CASE
|
| 676 |
+
WHEN $inclusive THEN start.entity_id CONTAINS $entity_id
|
| 677 |
+
ELSE start.entity_id = $entity_id
|
| 678 |
END
|
|
|
|
| 679 |
WITH start
|
| 680 |
CALL apoc.path.subgraphAll(start, {
|
| 681 |
relationshipFilter: '',
|
|
|
|
| 709 |
main_query,
|
| 710 |
{
|
| 711 |
"max_nodes": MAX_GRAPH_NODES,
|
| 712 |
+
"entity_id": node_label,
|
| 713 |
"inclusive": inclusive,
|
| 714 |
"max_depth": max_depth,
|
| 715 |
"min_degree": min_degree,
|
|
|
|
| 728 |
result.nodes.append(
|
| 729 |
KnowledgeGraphNode(
|
| 730 |
id=f"{node_id}",
|
| 731 |
+
labels=[label for label in node.labels if label != "base"],
|
| 732 |
properties=dict(node),
|
| 733 |
)
|
| 734 |
)
|
|
|
|
| 759 |
|
| 760 |
except neo4jExceptions.ClientError as e:
|
| 761 |
logger.warning(f"APOC plugin error: {str(e)}")
|
| 762 |
+
if node_label != "*":
|
| 763 |
logger.warning(
|
| 764 |
"Neo4j: falling back to basic Cypher recursive search..."
|
| 765 |
)
|
|
|
|
| 767 |
logger.warning(
|
| 768 |
"Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
|
| 769 |
)
|
| 770 |
+
return await self._robust_fallback(node_label, max_depth, min_degree)
|
| 771 |
|
| 772 |
return result
|
| 773 |
|
| 774 |
async def _robust_fallback(
|
| 775 |
+
self, node_label: str, max_depth: int, min_degree: int = 0
|
| 776 |
) -> KnowledgeGraph:
|
| 777 |
"""
|
| 778 |
Fallback implementation when APOC plugin is not available or incompatible.
|
|
|
|
| 805 |
database=self._DATABASE, default_access_mode="READ"
|
| 806 |
) as session:
|
| 807 |
query = """
|
| 808 |
+
MATCH (a:base {entity_id: $entity_id})-[r]-(b)
|
|
|
|
| 809 |
WITH r, b, id(r) as edge_id, id(b) as target_id
|
| 810 |
RETURN r, b, edge_id, target_id
|
| 811 |
"""
|
| 812 |
+
results = await session.run(query, entity_id=node.id)
|
| 813 |
|
| 814 |
# Get all records and release database connection
|
| 815 |
records = await results.fetch(
|
|
|
|
| 837 |
edge_id = str(record["edge_id"])
|
| 838 |
if edge_id not in visited_edges:
|
| 839 |
b_node = record["b"]
|
| 840 |
+
target_id = b_node.get("entity_id")
|
| 841 |
|
| 842 |
+
if target_id: # Only process if target node has entity_id
|
| 843 |
# Create KnowledgeGraphNode for target
|
| 844 |
target_node = KnowledgeGraphNode(
|
| 845 |
id=f"{target_id}",
|
| 846 |
+
labels=[label for label in b_node.labels if label != "base"],
|
| 847 |
+
properties=dict(b_node.properties),
|
| 848 |
)
|
| 849 |
|
| 850 |
# Create KnowledgeGraphEdge
|
|
|
|
| 870 |
async with self._driver.session(
|
| 871 |
database=self._DATABASE, default_access_mode="READ"
|
| 872 |
) as session:
|
| 873 |
+
query = """
|
| 874 |
+
MATCH (n:base {entity_id: $entity_id})
|
| 875 |
RETURN id(n) as node_id, n
|
| 876 |
"""
|
| 877 |
+
node_result = await session.run(query, entity_id=node_label)
|
| 878 |
try:
|
| 879 |
node_record = await node_result.single()
|
| 880 |
if not node_record:
|
|
|
|
| 882 |
|
| 883 |
# Create initial KnowledgeGraphNode
|
| 884 |
start_node = KnowledgeGraphNode(
|
| 885 |
+
id=f"{node_record['n'].get('entity_id')}",
|
| 886 |
+
labels=[label for label in node_record["n"].labels if label != "base"],
|
| 887 |
+
properties=dict(node_record["n"].properties),
|
| 888 |
)
|
| 889 |
finally:
|
| 890 |
await node_result.consume() # Ensure results are consumed
|
|
|
|
| 908 |
|
| 909 |
# Method 2: Query compatible with older versions
|
| 910 |
query = """
|
| 911 |
+
MATCH (n)
|
| 912 |
+
WHERE n.entity_id IS NOT NULL
|
| 913 |
+
RETURN DISTINCT n.entity_id AS label
|
| 914 |
+
ORDER BY label
|
|
|
|
| 915 |
"""
|
| 916 |
result = await session.run(query)
|
| 917 |
labels = []
|
|
|
|
| 942 |
Args:
|
| 943 |
node_id: The label of the node to delete
|
| 944 |
"""
|
|
|
|
|
|
|
| 945 |
async def _do_delete(tx: AsyncManagedTransaction):
|
| 946 |
+
query = """
|
| 947 |
+
MATCH (n:base {entity_id: $entity_id})
|
| 948 |
DETACH DELETE n
|
| 949 |
"""
|
| 950 |
+
result = await tx.run(query, entity_id=node_id)
|
| 951 |
+
logger.debug(f"Deleted node with label '{node_id}'")
|
| 952 |
await result.consume() # Ensure result is fully consumed
|
| 953 |
|
| 954 |
try:
|
|
|
|
| 998 |
edges: List of edges to be deleted, each edge is a (source, target) tuple
|
| 999 |
"""
|
| 1000 |
for source, target in edges:
|
|
|
|
|
|
|
|
|
|
| 1001 |
async def _do_delete_edge(tx: AsyncManagedTransaction):
|
| 1002 |
+
query = """
|
| 1003 |
+
MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
|
| 1004 |
DELETE r
|
| 1005 |
"""
|
| 1006 |
+
result = await tx.run(query, source_entity_id=source, target_entity_id=target)
|
| 1007 |
+
logger.debug(f"Deleted edge from '{source}' to '{target}'")
|
| 1008 |
await result.consume() # Ensure result is fully consumed
|
| 1009 |
|
| 1010 |
try:
|