yangdx
		
	commited on
		
		
					Commit 
							
							·
						
						7764b80
	
1
								Parent(s):
							
							803d479
								
fix: duplicate nodes for same entity(label) problem in Neo4j
Browse files- Add entity_id field as key in Neo4j nodes
- Use  entity_id for nodes retrival and upsert
- lightrag/kg/neo4j_impl.py +72 -34
 - lightrag/operate.py +2 -0
 
    	
        lightrag/kg/neo4j_impl.py
    CHANGED
    
    | 
         @@ -280,12 +280,10 @@ class Neo4JStorage(BaseGraphStorage): 
     | 
|
| 280 | 
         
             
                        database=self._DATABASE, default_access_mode="READ"
         
     | 
| 281 | 
         
             
                    ) as session:
         
     | 
| 282 | 
         
             
                        try:
         
     | 
| 283 | 
         
            -
                            query = f"MATCH (n:`{entity_name_label}`) RETURN n"
         
     | 
| 284 | 
         
            -
                            result = await session.run(query)
         
     | 
| 285 | 
         
             
                            try:
         
     | 
| 286 | 
         
            -
                                records = await result.fetch(
         
     | 
| 287 | 
         
            -
                                    2
         
     | 
| 288 | 
         
            -
                                )  # Get up to 2 records to check for duplicates
         
     | 
| 289 | 
         | 
| 290 | 
         
             
                                if len(records) > 1:
         
     | 
| 291 | 
         
             
                                    logger.warning(
         
     | 
| 
         @@ -549,12 +547,14 @@ class Neo4JStorage(BaseGraphStorage): 
     | 
|
| 549 | 
         
             
                    """
         
     | 
| 550 | 
         
             
                    label = self._ensure_label(node_id)
         
     | 
| 551 | 
         
             
                    properties = node_data
         
     | 
| 
         | 
|
| 
         | 
|
| 552 | 
         | 
| 553 | 
         
             
                    try:
         
     | 
| 554 | 
         
             
                        async with self._driver.session(database=self._DATABASE) as session:
         
     | 
| 555 | 
         
             
                            async def execute_upsert(tx: AsyncManagedTransaction):
         
     | 
| 556 | 
         
             
                                query = f"""
         
     | 
| 557 | 
         
            -
                                MERGE (n:`{label}`)
         
     | 
| 558 | 
         
             
                                SET n += $properties
         
     | 
| 559 | 
         
             
                                """
         
     | 
| 560 | 
         
             
                                result = await tx.run(query, properties=properties)
         
     | 
| 
         @@ -568,6 +568,56 @@ class Neo4JStorage(BaseGraphStorage): 
     | 
|
| 568 | 
         
             
                        logger.error(f"Error during upsert: {str(e)}")
         
     | 
| 569 | 
         
             
                        raise
         
     | 
| 570 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 571 | 
         
             
                @retry(
         
     | 
| 572 | 
         
             
                    stop=stop_after_attempt(3),
         
     | 
| 573 | 
         
             
                    wait=wait_exponential(multiplier=1, min=4, max=10),
         
     | 
| 
         @@ -585,7 +635,8 @@ class Neo4JStorage(BaseGraphStorage): 
     | 
|
| 585 | 
         
             
                ) -> None:
         
     | 
| 586 | 
         
             
                    """
         
     | 
| 587 | 
         
             
                    Upsert an edge and its properties between two nodes identified by their labels.
         
     | 
| 588 | 
         
            -
                     
     | 
| 
         | 
|
| 589 | 
         | 
| 590 | 
         
             
                    Args:
         
     | 
| 591 | 
         
             
                        source_node_id (str): Label of the source node (used as identifier)
         
     | 
| 
         @@ -593,52 +644,39 @@ class Neo4JStorage(BaseGraphStorage): 
     | 
|
| 593 | 
         
             
                        edge_data (dict): Dictionary of properties to set on the edge
         
     | 
| 594 | 
         | 
| 595 | 
         
             
                    Raises:
         
     | 
| 596 | 
         
            -
                        ValueError: If either source or target node does not exist
         
     | 
| 597 | 
         
             
                    """
         
     | 
| 598 | 
         
             
                    source_label = self._ensure_label(source_node_id)
         
     | 
| 599 | 
         
             
                    target_label = self._ensure_label(target_node_id)
         
     | 
| 600 | 
         
             
                    edge_properties = edge_data
         
     | 
| 601 | 
         | 
| 602 | 
         
            -
                    #  
     | 
| 603 | 
         
            -
                     
     | 
| 604 | 
         
            -
                     
     | 
| 605 | 
         
            -
             
     | 
| 606 | 
         
            -
                    if not source_exists:
         
     | 
| 607 | 
         
            -
                        raise ValueError(
         
     | 
| 608 | 
         
            -
                            f"Neo4j: source node with label '{source_label}' does not exist"
         
     | 
| 609 | 
         
            -
                        )
         
     | 
| 610 | 
         
            -
                    if not target_exists:
         
     | 
| 611 | 
         
            -
                        raise ValueError(
         
     | 
| 612 | 
         
            -
                            f"Neo4j: target node with label '{target_label}' does not exist"
         
     | 
| 613 | 
         
            -
                        )
         
     | 
| 614 | 
         | 
| 615 | 
         
             
                    try:
         
     | 
| 616 | 
         
             
                        async with self._driver.session(database=self._DATABASE) as session:
         
     | 
| 617 | 
         
             
                            async def execute_upsert(tx: AsyncManagedTransaction):
         
     | 
| 618 | 
         
             
                                query = f"""
         
     | 
| 619 | 
         
            -
                                MATCH (source:`{source_label}`)
         
     | 
| 620 | 
         
             
                                WITH source
         
     | 
| 621 | 
         
            -
                                MATCH (target:`{target_label}`)
         
     | 
| 622 | 
         
             
                                MERGE (source)-[r:DIRECTED]-(target)
         
     | 
| 623 | 
         
             
                                SET r += $properties
         
     | 
| 624 | 
         
             
                                RETURN r, source, target
         
     | 
| 625 | 
         
             
                                """
         
     | 
| 626 | 
         
            -
                                result = await tx.run( 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 627 | 
         
             
                                try:
         
     | 
| 628 | 
         
             
                                    records = await result.fetch(100)
         
     | 
| 629 | 
         
            -
                                    if len(records) > 1:
         
     | 
| 630 | 
         
            -
                                        source_nodes = [dict(r['source']) for r in records]
         
     | 
| 631 | 
         
            -
                                        target_nodes = [dict(r['target']) for r in records]
         
     | 
| 632 | 
         
            -
                                        logger.warning(
         
     | 
| 633 | 
         
            -
                                            f"Multiple edges created: found {len(records)} results for edge between "
         
     | 
| 634 | 
         
            -
                                            f"source label '{source_label}' and target label '{target_label}'. "
         
     | 
| 635 | 
         
            -
                                            f"Source nodes: {source_nodes}, "
         
     | 
| 636 | 
         
            -
                                            f"Target nodes: {target_nodes}. "
         
     | 
| 637 | 
         
            -
                                            "Using first edge only."
         
     | 
| 638 | 
         
            -
                                        )
         
     | 
| 639 | 
         
             
                                    if records:
         
     | 
| 640 | 
         
             
                                        logger.debug(
         
     | 
| 641 | 
         
            -
                                            f"Upserted edge from '{source_label}'  
     | 
| 
         | 
|
| 642 | 
         
             
                                            f"with properties: {edge_properties}"
         
     | 
| 643 | 
         
             
                                        )
         
     | 
| 644 | 
         
             
                                finally:
         
     | 
| 
         | 
|
| 280 | 
         
             
                        database=self._DATABASE, default_access_mode="READ"
         
     | 
| 281 | 
         
             
                    ) as session:
         
     | 
| 282 | 
         
             
                        try:
         
     | 
| 283 | 
         
            +
                            query = f"MATCH (n:`{entity_name_label}` {{entity_id: $entity_id}}) RETURN n"
         
     | 
| 284 | 
         
            +
                            result = await session.run(query, entity_id=entity_name_label)
         
     | 
| 285 | 
         
             
                            try:
         
     | 
| 286 | 
         
            +
                                records = await result.fetch(2)  # Get 2 records for duplication check
         
     | 
| 
         | 
|
| 
         | 
|
| 287 | 
         | 
| 288 | 
         
             
                                if len(records) > 1:
         
     | 
| 289 | 
         
             
                                    logger.warning(
         
     | 
| 
         | 
|
| 547 | 
         
             
                    """
         
     | 
| 548 | 
         
             
                    label = self._ensure_label(node_id)
         
     | 
| 549 | 
         
             
                    properties = node_data
         
     | 
| 550 | 
         
            +
                    if "entity_id" not in properties:
         
     | 
| 551 | 
         
            +
                        raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
         
     | 
| 552 | 
         | 
| 553 | 
         
             
                    try:
         
     | 
| 554 | 
         
             
                        async with self._driver.session(database=self._DATABASE) as session:
         
     | 
| 555 | 
         
             
                            async def execute_upsert(tx: AsyncManagedTransaction):
         
     | 
| 556 | 
         
             
                                query = f"""
         
     | 
| 557 | 
         
            +
                                MERGE (n:`{label}` {{entity_id: $properties.entity_id}})
         
     | 
| 558 | 
         
             
                                SET n += $properties
         
     | 
| 559 | 
         
             
                                """
         
     | 
| 560 | 
         
             
                                result = await tx.run(query, properties=properties)
         
     | 
| 
         | 
|
| 568 | 
         
             
                        logger.error(f"Error during upsert: {str(e)}")
         
     | 
| 569 | 
         
             
                        raise
         
     | 
| 570 | 
         | 
| 571 | 
         
            +
                @retry(
         
     | 
| 572 | 
         
            +
                    stop=stop_after_attempt(3),
         
     | 
| 573 | 
         
            +
                    wait=wait_exponential(multiplier=1, min=4, max=10),
         
     | 
| 574 | 
         
            +
                    retry=retry_if_exception_type(
         
     | 
| 575 | 
         
            +
                        (
         
     | 
| 576 | 
         
            +
                            neo4jExceptions.ServiceUnavailable,
         
     | 
| 577 | 
         
            +
                            neo4jExceptions.TransientError,
         
     | 
| 578 | 
         
            +
                            neo4jExceptions.WriteServiceUnavailable,
         
     | 
| 579 | 
         
            +
                            neo4jExceptions.ClientError,
         
     | 
| 580 | 
         
            +
                        )
         
     | 
| 581 | 
         
            +
                    ),
         
     | 
| 582 | 
         
            +
                )
         
     | 
| 583 | 
         
            +
                async def _get_unique_node_entity_id(self, node_label: str) -> str:
         
     | 
| 584 | 
         
            +
                    """
         
     | 
| 585 | 
         
            +
                    Get the entity_id of a node with the given label, ensuring the node is unique.
         
     | 
| 586 | 
         
            +
             
     | 
| 587 | 
         
            +
                    Args:
         
     | 
| 588 | 
         
            +
                        node_label (str): Label of the node to check
         
     | 
| 589 | 
         
            +
             
     | 
| 590 | 
         
            +
                    Returns:
         
     | 
| 591 | 
         
            +
                        str: The entity_id of the unique node
         
     | 
| 592 | 
         
            +
             
     | 
| 593 | 
         
            +
                    Raises:
         
     | 
| 594 | 
         
            +
                        ValueError: If no node with the given label exists or if multiple nodes have the same label
         
     | 
| 595 | 
         
            +
                    """
         
     | 
| 596 | 
         
            +
                    async with self._driver.session(
         
     | 
| 597 | 
         
            +
                        database=self._DATABASE, default_access_mode="READ"
         
     | 
| 598 | 
         
            +
                    ) as session:
         
     | 
| 599 | 
         
            +
                        query = f"""
         
     | 
| 600 | 
         
            +
                        MATCH (n:`{node_label}`)
         
     | 
| 601 | 
         
            +
                        RETURN n, count(n) as node_count
         
     | 
| 602 | 
         
            +
                        """
         
     | 
| 603 | 
         
            +
                        result = await session.run(query)
         
     | 
| 604 | 
         
            +
                        try:
         
     | 
| 605 | 
         
            +
                            records = await result.fetch(2)  # We only need to know if there are 0, 1, or >1 nodes
         
     | 
| 606 | 
         
            +
                            
         
     | 
| 607 | 
         
            +
                            if not records or records[0]["node_count"] == 0:
         
     | 
| 608 | 
         
            +
                                raise ValueError(f"Neo4j: node with label '{node_label}' does not exist")
         
     | 
| 609 | 
         
            +
                            
         
     | 
| 610 | 
         
            +
                            if records[0]["node_count"] > 1:
         
     | 
| 611 | 
         
            +
                                raise ValueError(f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node")
         
     | 
| 612 | 
         
            +
                            
         
     | 
| 613 | 
         
            +
                            node = records[0]["n"]
         
     | 
| 614 | 
         
            +
                            if "entity_id" not in node:
         
     | 
| 615 | 
         
            +
                                raise ValueError(f"Neo4j: node with label '{node_label}' does not have an entity_id property")
         
     | 
| 616 | 
         
            +
                            
         
     | 
| 617 | 
         
            +
                            return node["entity_id"]
         
     | 
| 618 | 
         
            +
                        finally:
         
     | 
| 619 | 
         
            +
                            await result.consume()  # Ensure result is fully consumed
         
     | 
| 620 | 
         
            +
             
     | 
| 621 | 
         
             
                @retry(
         
     | 
| 622 | 
         
             
                    stop=stop_after_attempt(3),
         
     | 
| 623 | 
         
             
                    wait=wait_exponential(multiplier=1, min=4, max=10),
         
     | 
| 
         | 
|
| 635 | 
         
             
                ) -> None:
         
     | 
| 636 | 
         
             
                    """
         
     | 
| 637 | 
         
             
                    Upsert an edge and its properties between two nodes identified by their labels.
         
     | 
| 638 | 
         
            +
                    Ensures both source and target nodes exist and are unique before creating the edge.
         
     | 
| 639 | 
         
            +
                    Uses entity_id property to uniquely identify nodes.
         
     | 
| 640 | 
         | 
| 641 | 
         
             
                    Args:
         
     | 
| 642 | 
         
             
                        source_node_id (str): Label of the source node (used as identifier)
         
     | 
| 
         | 
|
| 644 | 
         
             
                        edge_data (dict): Dictionary of properties to set on the edge
         
     | 
| 645 | 
         | 
| 646 | 
         
             
                    Raises:
         
     | 
| 647 | 
         
            +
                        ValueError: If either source or target node does not exist or is not unique
         
     | 
| 648 | 
         
             
                    """
         
     | 
| 649 | 
         
             
                    source_label = self._ensure_label(source_node_id)
         
     | 
| 650 | 
         
             
                    target_label = self._ensure_label(target_node_id)
         
     | 
| 651 | 
         
             
                    edge_properties = edge_data
         
     | 
| 652 | 
         | 
| 653 | 
         
            +
                    # Get entity_ids for source and target nodes, ensuring they are unique
         
     | 
| 654 | 
         
            +
                    source_entity_id = await self._get_unique_node_entity_id(source_label)
         
     | 
| 655 | 
         
            +
                    target_entity_id = await self._get_unique_node_entity_id(target_label)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 656 | 
         | 
| 657 | 
         
             
                    try:
         
     | 
| 658 | 
         
             
                        async with self._driver.session(database=self._DATABASE) as session:
         
     | 
| 659 | 
         
             
                            async def execute_upsert(tx: AsyncManagedTransaction):
         
     | 
| 660 | 
         
             
                                query = f"""
         
     | 
| 661 | 
         
            +
                                MATCH (source:`{source_label}` {{entity_id: $source_entity_id}})
         
     | 
| 662 | 
         
             
                                WITH source
         
     | 
| 663 | 
         
            +
                                MATCH (target:`{target_label}` {{entity_id: $target_entity_id}})
         
     | 
| 664 | 
         
             
                                MERGE (source)-[r:DIRECTED]-(target)
         
     | 
| 665 | 
         
             
                                SET r += $properties
         
     | 
| 666 | 
         
             
                                RETURN r, source, target
         
     | 
| 667 | 
         
             
                                """
         
     | 
| 668 | 
         
            +
                                result = await tx.run(
         
     | 
| 669 | 
         
            +
                                    query, 
         
     | 
| 670 | 
         
            +
                                    source_entity_id=source_entity_id,
         
     | 
| 671 | 
         
            +
                                    target_entity_id=target_entity_id,
         
     | 
| 672 | 
         
            +
                                    properties=edge_properties
         
     | 
| 673 | 
         
            +
                                )
         
     | 
| 674 | 
         
             
                                try:
         
     | 
| 675 | 
         
             
                                    records = await result.fetch(100)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 676 | 
         
             
                                    if records:
         
     | 
| 677 | 
         
             
                                        logger.debug(
         
     | 
| 678 | 
         
            +
                                            f"Upserted edge from '{source_label}' (entity_id: {source_entity_id}) "
         
     | 
| 679 | 
         
            +
                                            f"to '{target_label}' (entity_id: {target_entity_id}) "
         
     | 
| 680 | 
         
             
                                            f"with properties: {edge_properties}"
         
     | 
| 681 | 
         
             
                                        )
         
     | 
| 682 | 
         
             
                                finally:
         
     | 
    	
        lightrag/operate.py
    CHANGED
    
    | 
         @@ -220,6 +220,7 @@ async def _merge_nodes_then_upsert( 
     | 
|
| 220 | 
         
             
                    entity_name, description, global_config
         
     | 
| 221 | 
         
             
                )
         
     | 
| 222 | 
         
             
                node_data = dict(
         
     | 
| 
         | 
|
| 223 | 
         
             
                    entity_type=entity_type,
         
     | 
| 224 | 
         
             
                    description=description,
         
     | 
| 225 | 
         
             
                    source_id=source_id,
         
     | 
| 
         @@ -301,6 +302,7 @@ async def _merge_edges_then_upsert( 
     | 
|
| 301 | 
         
             
                        await knowledge_graph_inst.upsert_node(
         
     | 
| 302 | 
         
             
                            need_insert_id,
         
     | 
| 303 | 
         
             
                            node_data={
         
     | 
| 
         | 
|
| 304 | 
         
             
                                "source_id": source_id,
         
     | 
| 305 | 
         
             
                                "description": description,
         
     | 
| 306 | 
         
             
                                "entity_type": "UNKNOWN",
         
     | 
| 
         | 
|
| 220 | 
         
             
                    entity_name, description, global_config
         
     | 
| 221 | 
         
             
                )
         
     | 
| 222 | 
         
             
                node_data = dict(
         
     | 
| 223 | 
         
            +
                    entity_id=entity_name,
         
     | 
| 224 | 
         
             
                    entity_type=entity_type,
         
     | 
| 225 | 
         
             
                    description=description,
         
     | 
| 226 | 
         
             
                    source_id=source_id,
         
     | 
| 
         | 
|
| 302 | 
         
             
                        await knowledge_graph_inst.upsert_node(
         
     | 
| 303 | 
         
             
                            need_insert_id,
         
     | 
| 304 | 
         
             
                            node_data={
         
     | 
| 305 | 
         
            +
                                "entity_id": need_insert_id,
         
     | 
| 306 | 
         
             
                                "source_id": source_id,
         
     | 
| 307 | 
         
             
                                "description": description,
         
     | 
| 308 | 
         
             
                                "entity_type": "UNKNOWN",
         
     |