"""
PlannerAgent for SPARKNET - LangChain Version
Breaks down complex VISTA scenarios into executable workflows
Uses LangChain chains for structured task decomposition
"""
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, field
from loguru import logger
import json
import networkx as nx
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.messages import HumanMessage, SystemMessage
from .base_agent import BaseAgent, Task, Message
from ..llm.langchain_ollama_client import LangChainOllamaClient
from ..workflow.langgraph_state import SubTask as SubTaskModel, TaskStatus
# Pydantic model for planning output
class TaskDecomposition(BaseModel):
"""Structured output from planning chain"""
subtasks: List[Dict[str, Any]] = Field(description="List of subtasks with dependencies")
reasoning: str = Field(description="Explanation of the planning strategy")
estimated_total_duration: float = Field(description="Total estimated duration in seconds")
@dataclass
class TaskGraph:
"""Directed acyclic graph of tasks with dependencies."""
subtasks: Dict[str, SubTaskModel] = field(default_factory=dict)
graph: nx.DiGraph = field(default_factory=nx.DiGraph)
def add_subtask(self, subtask: SubTaskModel):
"""Add a subtask to the graph."""
self.subtasks[subtask.id] = subtask
self.graph.add_node(subtask.id, task=subtask)
# Add dependency edges. Edges are only created for dependencies that are already
# registered, so subtasks must be added in dependency order (prerequisites first);
# unknown dependency ids are silently skipped.
for dep_id in subtask.dependencies:
if dep_id in self.subtasks:
self.graph.add_edge(dep_id, subtask.id)
def get_execution_order(self) -> List[List[str]]:
"""
Get tasks in execution order (topological sort).
Returns list of lists - inner lists can be executed in parallel.
"""
try:
generations = list(nx.topological_generations(self.graph))
return generations
except nx.NetworkXError as e:
logger.error(f"Error in topological sort: {e}")
return []
def validate(self) -> bool:
"""Validate graph has no cycles."""
return nx.is_directed_acyclic_graph(self.graph)
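# Illustrative sketch (assumes the SubTaskModel fields used throughout this
# module are sufficient to construct instances directly): a diamond-shaped
# dependency graph whose execution order is [["a"], ["b", "c"], ["d"]], i.e.
# "b" and "c" can run in parallel once "a" has completed.
def _example_diamond_task_graph() -> TaskGraph:
    """Build a small TaskGraph by hand to show how dependencies become parallel stages."""
    graph = TaskGraph()
    for task_id, deps in [("a", []), ("b", ["a"]), ("c", ["a"]), ("d", ["b", "c"])]:
        graph.add_subtask(SubTaskModel(
            id=task_id,
            description=f"Illustrative step {task_id}",
            agent_type="ExecutorAgent",
            dependencies=deps,
            estimated_duration=30.0,
            priority=1,
            parameters={},
            status=TaskStatus.PENDING,
        ))
    assert graph.validate()  # the diamond has no cycles
    return graph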
class PlannerAgent(BaseAgent):
"""
Agent specialized in task decomposition and workflow planning.
Uses LangChain chains with qwen2.5:14b for complex reasoning.
"""
# Scenario templates for common VISTA workflows
SCENARIO_TEMPLATES = {
'patent_wakeup': {
'description': 'Analyze dormant patent and create valorization roadmap',
'stages': [
{
'name': 'document_analysis',
'agent': 'DocumentAnalysisAgent',
'description': 'Extract and analyze patent content',
'dependencies': [],
},
{
'name': 'market_analysis',
'agent': 'MarketAnalysisAgent',
'description': 'Identify market opportunities for patent',
'dependencies': ['document_analysis'],
},
{
'name': 'matchmaking',
'agent': 'MatchmakingAgent',
'description': 'Match patent with potential licensees',
'dependencies': ['document_analysis', 'market_analysis'],
},
{
'name': 'outreach',
'agent': 'OutreachAgent',
'description': 'Generate valorization brief and outreach materials',
'dependencies': ['matchmaking'],
},
],
},
'agreement_safety': {
'description': 'Review legal agreement for risks and compliance',
'stages': [
{
'name': 'document_parsing',
'agent': 'LegalAnalysisAgent',
'description': 'Parse agreement and extract clauses',
'dependencies': [],
},
{
'name': 'compliance_check',
'agent': 'ComplianceAgent',
'description': 'Check GDPR and Law 25 compliance',
'dependencies': ['document_parsing'],
},
{
'name': 'risk_assessment',
'agent': 'RiskAssessmentAgent',
'description': 'Identify problematic clauses and risks',
'dependencies': ['document_parsing'],
},
{
'name': 'recommendations',
'agent': 'RecommendationAgent',
'description': 'Generate improvement suggestions',
'dependencies': ['compliance_check', 'risk_assessment'],
},
],
},
'partner_matching': {
'description': 'Match stakeholders based on complementary capabilities',
'stages': [
{
'name': 'profiling',
'agent': 'ProfilingAgent',
'description': 'Extract stakeholder capabilities and needs',
'dependencies': [],
},
{
'name': 'semantic_matching',
'agent': 'SemanticMatchingAgent',
'description': 'Find complementary partners using embeddings',
'dependencies': ['profiling'],
},
{
'name': 'network_analysis',
'agent': 'NetworkAnalysisAgent',
'description': 'Identify strategic network connections',
'dependencies': ['profiling'],
},
{
'name': 'facilitation',
'agent': 'ConnectionFacilitatorAgent',
'description': 'Generate introduction materials',
'dependencies': ['semantic_matching', 'network_analysis'],
},
],
},
}
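# Note: _plan_from_template() prefixes every stage name with the parent task id,
# so for a task with id "task_1" the 'patent_wakeup' template produces subtasks
# "task_1_document_analysis", "task_1_market_analysis", "task_1_matchmaking" and
# "task_1_outreach", with their dependency ids rewritten the same way.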
def __init__(
self,
llm_client: LangChainOllamaClient,
memory_agent: Optional['MemoryAgent'] = None,
temperature: float = 0.7,
):
"""
Initialize PlannerAgent with LangChain client.
Args:
llm_client: LangChain Ollama client
memory_agent: Optional memory agent for context
temperature: LLM temperature for planning
"""
self.llm_client = llm_client
self.memory_agent = memory_agent
self.temperature = temperature
# Create planning chains
self.planning_chain = self._create_planning_chain()
self.refinement_chain = self._create_refinement_chain()
# Store for backward compatibility
self.name = "PlannerAgent"
self.description = "Task decomposition and workflow planning"
logger.info(f"Initialized PlannerAgent with LangChain (complexity: complex)")
def _create_planning_chain(self):
"""
Create LangChain chain for task decomposition.
Returns:
Runnable chain: prompt | llm | parser
"""
system_template = """You are a strategic planning agent for research valorization tasks.
Your role is to:
1. Analyze complex tasks and break them into manageable subtasks
2. Identify dependencies between subtasks
3. Assign appropriate agents to each subtask
4. Estimate task complexity and duration
5. Create optimal execution plans
Available agent types:
- ExecutorAgent: General task execution
- DocumentAnalysisAgent: Analyze patents and documents
- MarketAnalysisAgent: Market research and opportunity identification
- MatchmakingAgent: Stakeholder matching and connections
- OutreachAgent: Generate outreach materials and briefs
- LegalAnalysisAgent: Legal document analysis
- ComplianceAgent: Compliance checking (GDPR, Law 25)
- RiskAssessmentAgent: Risk identification
- ProfilingAgent: Stakeholder profiling
- SemanticMatchingAgent: Semantic similarity matching
- NetworkAnalysisAgent: Network and relationship analysis
Output your plan as a structured JSON object with:
- subtasks: List of subtask objects with id, description, agent_type, dependencies, estimated_duration, priority
- reasoning: Your strategic reasoning for this decomposition
- estimated_total_duration: Total estimated time in seconds"""
human_template = """Given the following task, create a detailed execution plan:
Task: {task_description}
{context_section}
Break this down into specific subtasks. For each subtask:
- Give it a unique ID (use snake_case)
- Describe what needs to be done
- Specify which agent type should handle it
- List any dependencies (IDs of tasks that must complete first)
- Estimate duration in seconds
- Set priority (1=highest)
Think step-by-step about:
- What is the ultimate goal?
- What information is needed?
- What are the logical stages?
- Which subtasks can run in parallel?
- What are the critical dependencies?
Output JSON only."""
prompt = ChatPromptTemplate.from_messages([
("system", system_template),
("human", human_template)
])
# Use complex model for planning
llm = self.llm_client.get_llm(complexity="complex", temperature=self.temperature)
# JSON output parser: parsing returns a plain dict (consumed with .get() in
# _plan_with_langchain); the pydantic_object documents the expected schema,
# which the prompt above also spells out explicitly
parser = JsonOutputParser(pydantic_object=TaskDecomposition)
# Create chain
chain = prompt | llm | parser
return chain
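# Illustrative shape of the parsed planning output (not a real model response);
# this is the dict that _plan_with_langchain() consumes:
# {
#   "subtasks": [
#     {"id": "document_analysis", "description": "Extract patent claims",
#      "agent_type": "DocumentAnalysisAgent", "dependencies": [],
#      "estimated_duration": 60, "priority": 1},
#     {"id": "market_analysis", "description": "Scan licensing opportunities",
#      "agent_type": "MarketAnalysisAgent", "dependencies": ["document_analysis"],
#      "estimated_duration": 90, "priority": 2}
#   ],
#   "reasoning": "Analyze the document first, then build on its findings.",
#   "estimated_total_duration": 150
# }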
def _create_refinement_chain(self):
"""
Create LangChain chain for replanning based on feedback.
Returns:
Runnable chain for refinement
"""
system_template = """You are refining an existing task plan based on feedback.
Your role is to:
1. Review the original plan and feedback
2. Identify what went wrong or could be improved
3. Create an improved plan that addresses the issues
4. Maintain successful elements from the original plan
Be thoughtful about what to change and what to keep."""
human_template = """Refine the following plan based on feedback:
Original Task: {task_description}
Original Plan:
{original_plan}
Feedback from execution:
{feedback}
Issues encountered:
{issues}
Create an improved plan that addresses these issues while maintaining what worked well.
Output JSON in the same format as before."""
prompt = ChatPromptTemplate.from_messages([
("system", system_template),
("human", human_template)
])
llm = self.llm_client.get_llm(complexity="complex", temperature=self.temperature)
parser = JsonOutputParser(pydantic_object=TaskDecomposition)
chain = prompt | llm | parser
return chain
async def process_task(self, task: Task) -> Task:
"""
Process planning task by decomposing into workflow.
Args:
task: High-level task to plan
Returns:
Updated task with plan in result
"""
logger.info(f"PlannerAgent planning task: {task.id}")
task.status = "in_progress"
try:
# Check if this is a known scenario
scenario = task.metadata.get('scenario') if task.metadata else None
if scenario and scenario in self.SCENARIO_TEMPLATES:
# Use template-based planning
logger.info(f"Using template for scenario: {scenario}")
task_graph = await self._plan_from_template(task, scenario)
else:
# Use LangChain-based planning for custom tasks
logger.info("Using LangChain planning for custom task")
task_graph = await self._plan_with_langchain(task)
# Validate the graph
if not task_graph.validate():
raise ValueError("Generated task graph contains cycles!")
# Store plan in task result
task.result = {
'task_graph': task_graph,
'execution_order': task_graph.get_execution_order(),
'total_subtasks': len(task_graph.subtasks),
}
task.status = "completed"
logger.info(f"Planning completed: {len(task_graph.subtasks)} subtasks")
except Exception as e:
logger.error(f"Planning failed: {e}")
task.status = "failed"
task.error = str(e)
return task
async def _plan_from_template(self, task: Task, scenario: str) -> TaskGraph:
"""
Create task graph from scenario template.
Args:
task: Original task
scenario: Scenario identifier
Returns:
TaskGraph based on template
"""
template = self.SCENARIO_TEMPLATES[scenario]
task_graph = TaskGraph()
# Get task parameters
params = task.metadata.get('parameters', {}) if task.metadata else {}
# Create subtasks from template stages
for i, stage in enumerate(template['stages']):
subtask = SubTaskModel(
id=f"{task.id}_{stage['name']}",
description=stage['description'],
agent_type=stage['agent'],
dependencies=[f"{task.id}_{dep}" for dep in stage['dependencies']],
estimated_duration=30.0,
priority=i + 1,
parameters=params,
status=TaskStatus.PENDING
)
task_graph.add_subtask(subtask)
logger.debug(f"Created task graph with {len(task_graph.subtasks)} subtasks from template")
return task_graph
async def _plan_with_langchain(self, task: Task, context: Optional[List[Any]] = None) -> TaskGraph:
"""
Create task graph using LangChain planning chain.
Args:
task: Original task
context: Optional context from memory
Returns:
TaskGraph generated by LangChain
"""
# Prepare context section
context_section = ""
if context and len(context) > 0:
context_section = "Relevant past experiences:\n"
for i, ctx in enumerate(context[:3], 1): # Top 3 contexts
context_section += f"{i}. {ctx.page_content[:200]}...\n"
# Invoke planning chain
try:
result = await self.planning_chain.ainvoke({
"task_description": task.description,
"context_section": context_section
})
# Parse result into TaskGraph
task_graph = TaskGraph()
for subtask_data in result.get('subtasks', []):
subtask = SubTaskModel(
id=f"{task.id}_{subtask_data.get('id', f'subtask_{len(task_graph.subtasks)}')}",
description=subtask_data.get('description', ''),
agent_type=subtask_data.get('agent_type', 'ExecutorAgent'),
dependencies=[f"{task.id}_{dep}" for dep in subtask_data.get('dependencies', [])],
estimated_duration=subtask_data.get('estimated_duration', 30.0),
priority=subtask_data.get('priority', 0),
parameters=subtask_data.get('parameters', {}),
status=TaskStatus.PENDING
)
task_graph.add_subtask(subtask)
logger.debug(f"Created task graph with {len(task_graph.subtasks)} subtasks from LangChain")
return task_graph
except Exception as e:
logger.error(f"Failed to parse LangChain planning response: {e}")
raise ValueError(f"Failed to generate plan: {e}")
async def decompose_task(
self,
task_description: str,
scenario: Optional[str] = None,
context: Optional[List[Any]] = None
) -> TaskGraph:
"""
Decompose a high-level task into subtasks.
Args:
task_description: Natural language description
scenario: Optional scenario identifier
context: Optional context from memory
Returns:
TaskGraph with subtasks and dependencies
"""
# Create a task object
task = Task(
id=f"plan_{hash(task_description) % 10000}",
description=task_description,
metadata={'scenario': scenario} if scenario else {},
)
# Process with planning logic
result_task = await self.process_task(task)
if result_task.status == "completed" and result_task.result:
return result_task.result['task_graph']
else:
raise RuntimeError(f"Planning failed: {result_task.error}")
async def adapt_plan(
self,
task_graph: TaskGraph,
feedback: str,
issues: List[str]
) -> TaskGraph:
"""
Adapt an existing plan based on execution feedback.
Args:
task_graph: Original task graph
feedback: Feedback from execution
issues: List of issues encountered
Returns:
Updated task graph
"""
logger.info("Adapting plan based on feedback")
# Convert task graph to dict for refinement
original_plan = {
"subtasks": [
{
"id": st.id,
"description": st.description,
"agent_type": st.agent_type,
"dependencies": st.dependencies
}
for st in task_graph.subtasks.values()
]
}
try:
# Invoke refinement chain
result = await self.refinement_chain.ainvoke({
"task_description": "Refine task decomposition",
"original_plan": json.dumps(original_plan, indent=2),
"feedback": feedback,
"issues": "\n".join(f"- {issue}" for issue in issues)
})
# Create new task graph from refined plan
new_task_graph = TaskGraph()
for subtask_data in result.get('subtasks', []):
subtask = SubTaskModel(
id=subtask_data.get('id', f'subtask_{len(new_task_graph.subtasks)}'),
description=subtask_data.get('description', ''),
agent_type=subtask_data.get('agent_type', 'ExecutorAgent'),
dependencies=subtask_data.get('dependencies', []),
estimated_duration=subtask_data.get('estimated_duration', 30.0),
priority=subtask_data.get('priority', 0),
parameters=subtask_data.get('parameters', {}),
status=TaskStatus.PENDING
)
new_task_graph.add_subtask(subtask)
logger.info(f"Plan adapted: {len(new_task_graph.subtasks)} subtasks")
return new_task_graph
except Exception as e:
logger.error(f"Plan adaptation failed: {e}, returning original plan")
return task_graph
def get_parallel_tasks(self, task_graph: TaskGraph) -> List[List[SubTaskModel]]:
"""
Get tasks that can be executed in parallel.
Args:
task_graph: Task graph
Returns:
List of parallel task groups
"""
execution_order = task_graph.get_execution_order()
parallel_groups = []
for task_ids in execution_order:
group = [task_graph.subtasks[task_id] for task_id in task_ids]
parallel_groups.append(group)
return parallel_groups
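# ---------------------------------------------------------------------------
# Illustrative usage sketch (not executed on import). Assumes a configured
# LangChainOllamaClient with a reachable Ollama server; construct the client
# according to your deployment.
# ---------------------------------------------------------------------------
async def _example_plan_patent_wakeup(llm_client: LangChainOllamaClient) -> TaskGraph:
    """Sketch: plan the 'patent_wakeup' scenario and log its parallel stages."""
    planner = PlannerAgent(llm_client=llm_client)
    task_graph = await planner.decompose_task(
        task_description="Wake up a dormant patent and build a valorization roadmap",
        scenario="patent_wakeup",
    )
    # Each inner group can be dispatched concurrently once the previous stage finishes.
    for stage_number, group in enumerate(planner.get_parallel_tasks(task_graph), 1):
        logger.info(f"Stage {stage_number}: {[subtask.id for subtask in group]}")
    # If execution surfaces problems, the plan can be refined and reused, e.g.:
    # task_graph = await planner.adapt_plan(task_graph, feedback="...", issues=["..."])
    return task_graph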