"""
Multi-Stage Clustering Knowledge Extraction Method
Implements a multi-stage clustering approach inspired by KGGen research.
This method performs initial extraction followed by iterative clustering
of entities and relationships to improve semantic consistency and reduce redundancy.
"""
import os
import sys
# Add the repository root (four directory levels up) to sys.path so the shared packages resolve
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
# Import the LiteLLM fix FIRST, before any other imports that might use LiteLLM
from utils.fix_litellm_stop_param import *  # This applies the patches  # noqa: F403
import logging
import time
from datetime import datetime
from typing import Any, Dict
from crewai import Agent, Crew, Process, Task
from evaluation.knowledge_extraction.baselines.base_method import BaseKnowledgeExtractionMethod
from evaluation.knowledge_extraction.utils.models import KnowledgeGraph
# Import shared prompt templates
from evaluation.knowledge_extraction.utils.prompts import (
    ENTITY_EXTRACTION_INSTRUCTION_PROMPT,
    ENTITY_EXTRACTION_SYSTEM_PROMPT,
    GRAPH_BUILDER_SYSTEM_PROMPT,
    RELATION_EXTRACTION_INSTRUCTION_PROMPT,
    RELATION_EXTRACTION_SYSTEM_PROMPT,
)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set higher log levels for noisy libraries
logging.getLogger("openai").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("litellm").setLevel(logging.WARNING)
logging.getLogger("chromadb").setLevel(logging.WARNING)
# Default verbosity level for the CrewAI agents and crew (0 = quiet)
verbose_level = 0
# Default model for every agent created below (overrides any externally set OPENAI_MODEL_NAME)
os.environ["OPENAI_MODEL_NAME"] = "gpt-5-mini"
class ClusteringKnowledgeExtractionMethod(BaseKnowledgeExtractionMethod):
"""Multi-stage clustering knowledge extraction method using CrewAI."""
def __init__(self, **kwargs):
super().__init__("clustering_method", **kwargs)
self._setup_agents_and_tasks()
def _setup_agents_and_tasks(self):
"""Set up the CrewAI agents and tasks."""
# Create extraction agent (similar to unified approach)
self.extraction_agent = Agent(
role="Knowledge Graph Extractor",
goal="Extract comprehensive entities and relationships from agent system data",
backstory=f"{ENTITY_EXTRACTION_SYSTEM_PROMPT}\n\n{RELATION_EXTRACTION_SYSTEM_PROMPT}",
verbose=bool(verbose_level),
llm=os.environ["OPENAI_MODEL_NAME"],
)
# Create clustering agent for entity deduplication
self.entity_clustering_agent = Agent(
role="Entity Clustering Specialist",
goal="Identify and merge duplicate or similar entities to improve graph consistency",
backstory="""You are an expert in entity resolution and clustering. You can identify when
different entity mentions refer to the same real-world entity, even when they have slight
variations in naming, description, or representation.
You excel at:
- Identifying semantic equivalence between entities
- Merging entities with different tenses, plurality, or capitalization
- Consolidating entities that represent the same concept
- Maintaining entity relationships during clustering
You ensure the final entity set is clean, consistent, and free of redundancy.""",
verbose=bool(verbose_level),
llm=os.environ["OPENAI_MODEL_NAME"],
)
# Create relationship clustering agent
self.relationship_clustering_agent = Agent(
role="Relationship Clustering Specialist",
goal="Identify and merge duplicate or similar relationships to improve graph coherence",
backstory="""You are an expert in relationship analysis and clustering. You can identify when
different relationship expressions refer to the same underlying connection between entities.
You excel at:
- Identifying semantically equivalent relationships
- Merging relationships with different phrasings but same meaning
- Consolidating relationships that represent the same connection type
- Ensuring relationship consistency across the knowledge graph
You maintain the integrity of entity connections while improving relationship clarity.""",
verbose=bool(verbose_level),
llm=os.environ["OPENAI_MODEL_NAME"],
)
# Create validation agent
self.validation_agent = Agent(
role="Knowledge Graph Validator",
goal="Validate and finalize the clustered knowledge graph for quality and completeness",
backstory=GRAPH_BUILDER_SYSTEM_PROMPT,
verbose=bool(verbose_level),
llm=os.environ["OPENAI_MODEL_NAME"],
)
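        # Each task below sets output_pydantic=KnowledgeGraph, so CrewAI parses each stage's
        # final answer into the shared schema; process_text() later reads the last task's
        # parsed model via result.pydantic.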
# Create extraction task
self.extraction_task = Task(
description=f"""
Extract comprehensive entities and relationships from the provided agent system data.
{ENTITY_EXTRACTION_INSTRUCTION_PROMPT}
Also extract relationships:
{RELATION_EXTRACTION_INSTRUCTION_PROMPT}
Output a complete initial knowledge graph with all extracted entities and relationships.
Focus on thoroughness and accuracy - clustering will happen in subsequent steps.
""",
agent=self.extraction_agent,
expected_output="A complete initial knowledge graph with comprehensive entities and relationships",
output_pydantic=KnowledgeGraph,
)
# Create entity clustering task
self.entity_clustering_task = Task(
description="""
Analyze the extracted entities and identify clusters of entities that represent the same concept.
You will receive the knowledge graph from the previous extraction task.
Your task is to:
1. ENTITY ANALYSIS - Group similar entities:
- Identify entities with same meaning but different expressions
- Look for variations in tense, plurality, capitalization
- Find entities that represent the same real-world concept
- Consider semantic similarity and contextual equivalence
2. CLUSTERING DECISIONS - For each cluster:
- Select the most representative entity as the canonical form
- Merge descriptions and properties from all cluster members
- Preserve all relevant information from clustered entities
- Maintain entity type consistency
3. RELATIONSHIP UPDATES - Update relationships:
- Replace clustered entity IDs with canonical entity IDs
- Ensure all relationships remain valid after clustering
- Remove duplicate relationships that may result from clustering
Output an updated knowledge graph with clustered entities and updated relationships.
Ensure no information is lost during the clustering process.
""",
agent=self.entity_clustering_agent,
expected_output="Knowledge graph with clustered entities and updated relationships",
context=[self.extraction_task],
output_pydantic=KnowledgeGraph,
)
# Create relationship clustering task
self.relationship_clustering_task = Task(
description="""
Analyze the relationships and identify clusters of relationships that represent the same connection type.
You will receive the knowledge graph from the previous entity clustering task.
Your task is to:
1. RELATIONSHIP ANALYSIS - Group similar relationships:
- Identify relationships with same meaning but different expressions
- Look for variations in phrasing, tense, or description
- Find relationships that represent the same connection type
- Consider semantic equivalence between relationship descriptions
2. CLUSTERING DECISIONS - For each relationship cluster:
- Select the most clear and representative relationship type
- Merge descriptions from all cluster members
- Preserve the most informative relationship description
- Maintain relationship directionality and constraints
3. GRAPH OPTIMIZATION - Optimize the relationship structure:
- Remove redundant relationships between same entity pairs
- Ensure relationship consistency across the graph
- Maintain logical coherence in relationship types
Output an optimized knowledge graph with clustered relationships and improved consistency.
""",
agent=self.relationship_clustering_agent,
expected_output="Knowledge graph with clustered relationships and improved consistency",
context=[self.entity_clustering_task],
output_pydantic=KnowledgeGraph,
)
# Create validation task
self.validation_task = Task(
description="""
Validate and finalize the clustered knowledge graph for quality and completeness.
You will receive the knowledge graph from the previous relationship clustering task.
Your task is to:
1. QUALITY VALIDATION - Check graph quality:
- Ensure all entities are properly connected
- Validate relationship consistency and logic
- Check for orphaned entities or broken connections
- Verify entity-relationship type compatibility
2. COMPLETENESS CHECK - Ensure completeness:
- Verify all important entities are captured
- Check that key relationships are represented
- Ensure system functionality is properly modeled
- Validate that the graph tells a complete story
3. FINALIZATION - Create final knowledge graph:
- Generate descriptive system name (3-7 words)
- Write comprehensive 2-3 sentence system summary
- Include metadata with processing statistics
- Ensure all components are reachable and connected
Output the final, validated knowledge graph ready for use.
""",
agent=self.validation_agent,
expected_output="Final validated knowledge graph with system summary and metadata",
context=[self.relationship_clustering_task],
output_pydantic=KnowledgeGraph,
)
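        # Process.sequential runs the four tasks in the order listed below; the `context`
        # lists above chain each stage's structured output into the next stage's prompt.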
# Create crew
self.clustering_crew = Crew(
agents=[self.extraction_agent, self.entity_clustering_agent, self.relationship_clustering_agent, self.validation_agent],
tasks=[self.extraction_task, self.entity_clustering_task, self.relationship_clustering_task, self.validation_task],
verbose=bool(verbose_level),
memory=False,
planning=False,
process=Process.sequential,
)
def _calculate_token_cost(self, total_tokens: int, prompt_tokens: int, completion_tokens: int, model_name: str) -> float:
"""
Calculate token cost based on model pricing.
Args:
total_tokens: Total number of tokens
prompt_tokens: Number of input/prompt tokens
completion_tokens: Number of output/completion tokens
model_name: Name of the model used
Returns:
Total cost in USD
"""
        # Model pricing in USD per 1,000 tokens (as of 2024)
pricing = {
"gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
"gpt-4o": {"input": 0.005, "output": 0.015},
"gpt-4": {"input": 0.03, "output": 0.06},
"gpt-4-turbo": {"input": 0.01, "output": 0.03},
"gpt-3.5-turbo": {"input": 0.0015, "output": 0.002},
"claude-3-opus": {"input": 0.015, "output": 0.075},
"claude-3-sonnet": {"input": 0.003, "output": 0.015},
"claude-3-haiku": {"input": 0.00025, "output": 0.00125},
"claude-3.5-sonnet": {"input": 0.003, "output": 0.015},
"claude-3.5-haiku": {"input": 0.0008, "output": 0.004},
}
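        # Illustrative arithmetic (hypothetical token counts): 10,000 prompt tokens plus
        # 2,000 completion tokens on gpt-4o-mini cost (10000 / 1000) * 0.00015 +
        # (2000 / 1000) * 0.0006 = 0.0015 + 0.0012 = $0.0027.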
# Normalize model name to match pricing keys
model_key = model_name.lower()
if "gpt-5-mini" in model_key:
model_key = "gpt-5-mini"
elif "gpt-4o-mini" in model_key:
model_key = "gpt-4o-mini"
elif "gpt-4o" in model_key:
model_key = "gpt-4o"
elif "gpt-4-turbo" in model_key or "gpt-4-1106" in model_key:
model_key = "gpt-4-turbo"
elif "gpt-4" in model_key:
model_key = "gpt-4"
elif "gpt-3.5" in model_key:
model_key = "gpt-3.5-turbo"
elif "claude-3.5-sonnet" in model_key:
model_key = "claude-3.5-sonnet"
elif "claude-3.5-haiku" in model_key:
model_key = "claude-3.5-haiku"
elif "claude-3-opus" in model_key:
model_key = "claude-3-opus"
elif "claude-3-sonnet" in model_key:
model_key = "claude-3-sonnet"
elif "claude-3-haiku" in model_key:
model_key = "claude-3-haiku"
        if model_key not in pricing:
            # Fall back to gpt-4o-mini rates for models missing from the table (e.g. gpt-5-mini)
            model_key = "gpt-4o-mini"
rates = pricing[model_key]
# Calculate cost: (tokens / 1000) * rate_per_1k_tokens
input_cost = (prompt_tokens / 1000) * rates["input"]
output_cost = (completion_tokens / 1000) * rates["output"]
return input_cost + output_cost
def process_text(self, text: str) -> Dict[str, Any]:
"""
Process input text using the multi-stage clustering approach.
Args:
text: Input text to process
Returns:
Dictionary with kg_data, metadata, success, and optional error
"""
start_time = time.time()
try:
logger.info(f"process_text called with text length: {len(text)}")
logger.info(f"text first 200 chars: {repr(text[:200])}")
logger.info("Starting clustering crew execution with input_data...")
# Run the crew with proper input mechanism
result = self.clustering_crew.kickoff(inputs={"input_data": text})
logger.info(f"Clustering crew execution completed, result type: {type(result)}")
processing_time = time.time() - start_time
# Extract token usage from crew
token_usage = {
"total_tokens": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"total_cost_usd": 0.0,
"model_used": "gpt-4o-mini",
"usage_available": False,
}
try:
if hasattr(self.clustering_crew, "usage_metrics") and self.clustering_crew.usage_metrics:
usage_metrics = self.clustering_crew.usage_metrics
logger.info(f"Found usage metrics: {usage_metrics}")
if isinstance(usage_metrics, dict):
token_usage.update(
{
"total_tokens": usage_metrics.get("total_tokens", 0),
"prompt_tokens": usage_metrics.get("prompt_tokens", 0),
"completion_tokens": usage_metrics.get("completion_tokens", 0),
"total_cost_usd": float(usage_metrics.get("total_cost", 0.0)),
"model_used": usage_metrics.get("model", "gpt-4o-mini"),
"usage_available": True,
}
)
# If cost is 0.0, calculate it manually
if token_usage["total_cost_usd"] == 0.0 and token_usage["total_tokens"] > 0:
calculated_cost = self._calculate_token_cost(
token_usage["total_tokens"], token_usage["prompt_tokens"], token_usage["completion_tokens"], token_usage["model_used"]
)
token_usage["total_cost_usd"] = calculated_cost
logger.info(
f"💰 Calculated cost: ${calculated_cost:.4f} for {token_usage['total_tokens']} tokens ({token_usage['model_used']})"
)
else:
# Handle object-style usage metrics
token_usage.update(
{
"total_tokens": getattr(usage_metrics, "total_tokens", 0),
"prompt_tokens": getattr(usage_metrics, "prompt_tokens", 0),
"completion_tokens": getattr(usage_metrics, "completion_tokens", 0),
"total_cost_usd": float(getattr(usage_metrics, "total_cost", 0.0)),
"model_used": getattr(usage_metrics, "model", "gpt-4o-mini"),
"usage_available": True,
}
)
# If cost is 0.0, calculate it manually
if token_usage["total_cost_usd"] == 0.0 and token_usage["total_tokens"] > 0:
calculated_cost = self._calculate_token_cost(
token_usage["total_tokens"], token_usage["prompt_tokens"], token_usage["completion_tokens"], token_usage["model_used"]
)
token_usage["total_cost_usd"] = calculated_cost
logger.info(
f"💰 Calculated cost: ${calculated_cost:.4f} for {token_usage['total_tokens']} tokens ({token_usage['model_used']})"
)
else:
logger.warning("No usage metrics found in crew")
except Exception as e:
logger.error(f"Error extracting token usage: {e}")
# Extract the knowledge graph from the result
if hasattr(result, "pydantic") and result.pydantic:
kg_data = result.pydantic.dict()
logger.info(
f"Successfully extracted KG with {len(kg_data.get('entities', []))} entities and {len(kg_data.get('relations', []))} relations"
)
# Add processing metadata
if "metadata" not in kg_data:
kg_data["metadata"] = {}
kg_data["metadata"].update(
{
"timestamp": datetime.now().isoformat(),
"processing_info": {
"method": "multi_stage_clustering",
"processing_time_seconds": processing_time,
"processed_at": datetime.now().isoformat(),
"agent_count": 4,
"task_count": 4,
"stages": ["extraction", "entity_clustering", "relationship_clustering", "validation"],
},
"token_usage": token_usage,
}
)
return {"kg_data": kg_data, "metadata": kg_data["metadata"], "token_usage": token_usage, "success": True}
else:
# Handle case where result doesn't have pydantic attribute
logger.warning(f"Result doesn't have pydantic attribute, result type: {type(result)}")
if hasattr(result, "raw"):
logger.info(f"Raw result: {result.raw[:500]}...")
return {
"kg_data": None,
"metadata": {
"timestamp": datetime.now().isoformat(),
"processing_time_seconds": processing_time,
"method": "multi_stage_clustering",
"token_usage": token_usage,
},
"token_usage": token_usage,
"success": False,
"error": f"Failed to extract pydantic result from crew output: {type(result)}",
}
except Exception as e:
processing_time = time.time() - start_time
logger.error(f"Error in clustering method processing: {e}")
# Try to extract token usage even on error
token_usage = {
"total_tokens": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"total_cost_usd": 0.0,
"model_used": "gpt-4o-mini",
"usage_available": False,
}
try:
if hasattr(self.clustering_crew, "usage_metrics") and self.clustering_crew.usage_metrics:
usage_metrics = self.clustering_crew.usage_metrics
if isinstance(usage_metrics, dict):
token_usage.update(
{
"total_tokens": usage_metrics.get("total_tokens", 0),
"prompt_tokens": usage_metrics.get("prompt_tokens", 0),
"completion_tokens": usage_metrics.get("completion_tokens", 0),
"total_cost_usd": float(usage_metrics.get("total_cost", 0.0)),
"model_used": usage_metrics.get("model", "gpt-4o-mini"),
"usage_available": True,
}
)
# If cost is 0.0, calculate it manually
if token_usage["total_cost_usd"] == 0.0 and token_usage["total_tokens"] > 0:
calculated_cost = self._calculate_token_cost(
token_usage["total_tokens"], token_usage["prompt_tokens"], token_usage["completion_tokens"], token_usage["model_used"]
)
token_usage["total_cost_usd"] = calculated_cost
logger.info(
f"💰 Calculated cost: ${calculated_cost:.4f} for {token_usage['total_tokens']} tokens ({token_usage['model_used']})"
)
else:
# Handle object-style usage metrics
token_usage.update(
{
"total_tokens": getattr(usage_metrics, "total_tokens", 0),
"prompt_tokens": getattr(usage_metrics, "prompt_tokens", 0),
"completion_tokens": getattr(usage_metrics, "completion_tokens", 0),
"total_cost_usd": float(getattr(usage_metrics, "total_cost", 0.0)),
"model_used": getattr(usage_metrics, "model", "gpt-4o-mini"),
"usage_available": True,
}
)
# If cost is 0.0, calculate it manually
if token_usage["total_cost_usd"] == 0.0 and token_usage["total_tokens"] > 0:
calculated_cost = self._calculate_token_cost(
token_usage["total_tokens"], token_usage["prompt_tokens"], token_usage["completion_tokens"], token_usage["model_used"]
)
token_usage["total_cost_usd"] = calculated_cost
logger.info(
f"💰 Calculated cost: ${calculated_cost:.4f} for {token_usage['total_tokens']} tokens ({token_usage['model_used']})"
)
            except Exception as usage_error:
                # Use a distinct name so the outer exception `e` stays available for the return below
                logger.error(f"Error extracting token usage: {usage_error}")
return {
"kg_data": None,
"metadata": {
"timestamp": datetime.now().isoformat(),
"processing_time_seconds": processing_time,
"method": "multi_stage_clustering",
"token_usage": token_usage,
},
"token_usage": token_usage,
"success": False,
"error": str(e),
}
def extract_knowledge_graph(self, trace_data: str) -> Dict[str, Any]:
"""
Extract knowledge graph from trace data using multi-stage clustering.
Args:
trace_data: Input trace data as string
Returns:
Dictionary containing the extracted knowledge graph
"""
logger.info(f"extract_knowledge_graph called with trace_data type: {type(trace_data)}")
logger.info(f"trace_data length: {len(trace_data)}")
logger.info(f"trace_data first 200 chars: {repr(trace_data[:200])}")
# Process the text using our clustering approach
result = self.process_text(trace_data)
if result["success"] and result["kg_data"]:
logger.info("Successfully processed trace data")
return result["kg_data"]
else:
logger.error(f"Failed to process trace data: {result.get('error', 'Unknown error')}")
# Return a minimal structure to avoid breaking the evaluation
return {
"entities": [],
"relations": [],
"system_name": "Failed Extraction",
"system_summary": "Knowledge graph extraction failed",
"metadata": result.get("metadata", {}),
}
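

# Illustrative usage sketch (assumptions: OPENAI_API_KEY is configured, the base class needs
# no extra constructor arguments, and the sample trace below is a hypothetical stand-in for a
# real agent-system trace).
if __name__ == "__main__":
    method = ClusteringKnowledgeExtractionMethod()
    sample_trace = "Planner agent delegated a search task to the Researcher agent, which called the web_search tool and returned three results."
    kg = method.extract_knowledge_graph(sample_trace)
    print(f"Extracted {len(kg.get('entities', []))} entities and {len(kg.get('relations', []))} relations for '{kg.get('system_name', '')}'")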