# Final_Assignment_Template / agent_builder.py
# Author: 24Arys11
# Commit bea2d8c: "Trying something new - an agent builder..."
from pathlib import Path
import yaml
from typing import Dict, List, Any
class AgentBuilder:
    """Scaffold LangGraph agents from a YAML design file.

    Typical workflow:
        1. ``AgentBuilder.initialize(name)`` writes ``agents/<name>/design.yaml``
           containing the mandatory ``START`` node and one example node.
        2. Edit the design: every node has a ``name``, a ``description`` and a
           list of ``connections`` (names of other nodes, or ``END``). An
           optional ``status`` selects the node colour in the diagram.
        3. ``AgentBuilder.setup(name)`` validates the design and generates
           ``graph.py``, ``test.py`` and ``design.puml`` next to it.

    All paths are relative to the current working directory.
    """

    # Folder (relative to the CWD) holding one sub-directory per agent.
    AGENTS_DIR = Path("agents")
    # Mandatory entry node; emitted as LangGraph's ``START`` constant.
    START_NODE = "START"
    # Implicit final node; emitted as LangGraph's ``END`` constant.
    END_NODE = "END"
    # Allowed ``status`` values; each has a ``<STATUS>_NODE_COLOR`` define in design.puml.
    NODE_STATUSES = ("NOT_IMPLEMENTED", "IMPLEMENTED", "TESTED", "TERMINAL")

    @classmethod
    def initialize(cls, agent_name: str):
        """Create the agent folder and a starter ``design.yaml`` file.

        An existing ``design.yaml`` for this agent is overwritten.

        Args:
            agent_name: Name of the agent to create (used as the folder name).
        """
        agent_dir = cls._agent_dir(agent_name)
        # parents=True also creates the base "agents" directory when missing.
        agent_dir.mkdir(parents=True, exist_ok=True)

        # Initial YAML content: the mandatory START node wired to an example node.
        initial_content = {
            "nodes": [
                {
                    "name": "START",
                    "connections": ["example_node"],
                    "description": "This is the mandatory initial node `START` !"
                },
                {
                    "name": "example_node",
                    "connections": [],
                    "description": "This is an example node"
                }
            ]
        }

        with open(agent_dir / "design.yaml", "w", encoding="utf-8") as f:
            yaml.dump(initial_content, f, default_flow_style=False, sort_keys=False)

    @classmethod
    def setup(cls, agent_name: str):
        """Generate ``graph.py``, ``test.py`` and ``design.puml`` from the agent's design.

        Existing generated files are overwritten. The design is validated
        first, so nothing is written when it is invalid.

        Args:
            agent_name: Name of the agent to set up.

        Raises:
            ValueError: If ``design.yaml`` is missing or invalid.
        """
        design_data = cls._validate_design(agent_name)
        cls._create_graph_file(agent_name, design_data)
        cls._create_test_file(agent_name)
        cls._create_puml_file(agent_name, design_data)

    @classmethod
    def _agent_dir(cls, agent_name: str) -> Path:
        """Return the directory that holds all files of *agent_name*."""
        return cls.AGENTS_DIR / agent_name

    @classmethod
    def _validate_design(cls, agent_name: str) -> Dict[str, List[Dict[str, Any]]]:
        """Load and validate the agent's ``design.yaml``.

        Args:
            agent_name: Name of the agent to validate.

        Returns:
            The parsed design data if valid.

        Raises:
            ValueError: If the design file is missing or invalid.
        """
        yaml_path = cls._agent_dir(agent_name) / "design.yaml"
        if not yaml_path.exists():
            raise ValueError(f"Design file not found at {yaml_path}")

        with open(yaml_path, "r", encoding="utf-8") as f:
            design_data = yaml.safe_load(f)

        cls._validate_nodes(design_data)
        return design_data

    @classmethod
    def _validate_nodes(cls, design_data: Any) -> None:
        """Check that parsed design data can be turned into valid generated code.

        Rules:
            * ``design_data`` is a mapping with a ``nodes`` list.
            * Each node is a dict with a ``name``, a ``description`` and a
              ``connections`` list.
            * Names are unique Python identifiers (they become method names)
              and are not ``END``, which is reserved for the implicit final node.
            * A ``START`` node exists.
            * Every connection targets a declared node or ``END``, never ``START``.
            * An optional ``status`` is one of ``NODE_STATUSES``.

        Raises:
            ValueError: On the first rule that is violated.
        """
        if not isinstance(design_data, dict) or not isinstance(design_data.get("nodes"), list):
            raise ValueError("Design file must contain a 'nodes' list")
        nodes = design_data["nodes"]

        # Per-node structure comes first so later checks can rely on dict nodes.
        names = set()
        for i, node in enumerate(nodes):
            if not isinstance(node, dict):
                raise ValueError(f"Node at index {i} must be a dictionary")
            if "name" not in node:
                raise ValueError(f"Node at index {i} is missing a 'name' field")
            name = node["name"]
            if not isinstance(name, str) or not name.isidentifier():
                raise ValueError(f"Node name {name!r} at index {i} must be a valid Python identifier")
            if name == cls.END_NODE:
                raise ValueError(f"'{cls.END_NODE}' is reserved for the final node and cannot be declared")
            if name in names:
                raise ValueError(f"Node '{name}' is defined more than once")
            names.add(name)
            if "description" not in node:
                raise ValueError(f"Node '{name}' is missing a 'description' field")
            if not isinstance(node.get("connections"), list):
                raise ValueError(f"Node '{name}' must have a 'connections' list")
            status = node.get("status")
            if status is not None and status not in cls.NODE_STATUSES:
                raise ValueError(
                    f"Node '{name}' has unknown status {status!r}; expected one of {', '.join(cls.NODE_STATUSES)}"
                )

        if cls.START_NODE not in names:
            raise ValueError("Design file must contain a 'START' node")

        # Dangling targets would produce a graph referencing non-existent nodes.
        for node in nodes:
            for target in node["connections"]:
                if (
                    not isinstance(target, str)
                    or target == cls.START_NODE
                    or (target != cls.END_NODE and target not in names)
                ):
                    raise ValueError(
                        f"Node '{node['name']}' has an invalid connection {target!r}: "
                        f"connections must name a declared node or '{cls.END_NODE}'"
                    )

    @classmethod
    def _graph_key(cls, name: str) -> str:
        """Render a node reference for generated code.

        START/END become LangGraph's bare constants; other nodes become string literals.
        """
        return name if name in (cls.START_NODE, cls.END_NODE) else f'"{name}"'

    @staticmethod
    def _docstring(text: Any, indent: str) -> str:
        """Render *text* as a triple-quoted docstring indented by *indent*.

        Backslashes and ``\"\"\"`` are escaped so arbitrary descriptions
        cannot break the generated file.
        """
        escaped = str(text).replace("\\", "\\\\").replace('"""', '\\"\\"\\"')
        body = "\n".join(f"{indent}{line}".rstrip() for line in escaped.splitlines() or [""])
        return f'{indent}"""\n{body}\n{indent}"""\n'

    @classmethod
    def _generate_node_method(cls, node: Dict[str, Any]) -> str:
        """Return the source of the ``<name>_node`` stub for the ``Nodes`` class."""
        return (
            f"\n    def {node['name']}_node(self, state):\n"
            f"{cls._docstring(node['description'], ' ' * 8)}"
            "        # TODO: To implement...\n"
            "        pass\n"
        )

    @classmethod
    def _generate_edge_method(cls, node: Dict[str, Any]) -> str:
        """Return the source of the ``<name>_edge`` routing stub for the ``Edges`` class."""
        targets = ", ".join(cls._graph_key(target) for target in node["connections"])
        return (
            f"\n    def {node['name']}_edge(self, state):\n"
            '        """\n'
            f"        Conditional edge for {node['name']} node.\n"
            f"        Returns one of: {targets}\n"
            '        """\n'
            "        # TODO: To implement...\n"
            "        pass\n"
        )

    @classmethod
    def _create_graph_file(cls, agent_name: str, design_data: Dict[str, List[Dict[str, Any]]]):
        """Write ``graph.py`` with State/Nodes/Edges stubs and a GraphBuilder.

        START gets no node method (it is LangGraph's entry constant). Every
        node with more than one connection, START included, gets a
        conditional-edge method. An existing ``graph.py`` is overwritten.

        Args:
            agent_name: Name of the agent.
            design_data: The validated design data.
        """
        nodes = design_data["nodes"]
        node_methods = "".join(
            cls._generate_node_method(node) for node in nodes if node["name"] != cls.START_NODE
        )
        edge_methods = "".join(
            cls._generate_edge_method(node) for node in nodes if len(node["connections"]) > 1
        )
        add_nodes = cls._generate_add_nodes_code(nodes)
        regular_edges = cls._generate_regular_edges_code(nodes)
        conditional_edges = cls._generate_conditional_edges_code(nodes)

        file_content = f'''from typing import Dict, Any
from langgraph.graph import StateGraph, END, START
from langgraph.graph.state import CompiledStateGraph


class State:
    """
    State class for the agent graph.
    """
    # TODO: Define state structure
    pass


class Nodes:
    """
    Collection of node functions for the agent graph.
    """
{node_methods}

class Edges:
    """
    Collection of conditional edge functions for the agent graph.
    """
{edge_methods}

class GraphBuilder:
    def __init__(self):
        """
        Initializes the GraphBuilder.
        """
        self.nodes = Nodes()
        self.edges = Edges()
        # TODO: Implement the desired constructor.
        pass

    def build_agent_graph(self) -> CompiledStateGraph:
        """Build and return the agent graph."""
        graph = StateGraph(State)

        # Add all nodes
{add_nodes}

        # Add edges
{regular_edges}
{conditional_edges}

        return graph.compile()
'''
        (cls._agent_dir(agent_name) / "graph.py").write_text(file_content, encoding="utf-8")

    @classmethod
    def _generate_add_nodes_code(cls, nodes):
        """Generate ``graph.add_node`` lines (method-body indentation) for every non-START node."""
        lines = []
        for node in nodes:
            name = node["name"]
            if name != cls.START_NODE:
                lines.append(f'        graph.add_node("{name}", self.nodes.{name}_node)')
        return "\n".join(lines)

    @classmethod
    def _generate_conditional_edges_code(cls, nodes):
        """Generate ``graph.add_conditional_edges`` calls for nodes with several connections.

        The source node and each path-map entry go through ``_graph_key``, so
        START/END are emitted as LangGraph constants rather than string names.
        """
        blocks = []
        for node in nodes:
            if len(node["connections"]) <= 1:
                continue
            destinations = "\n".join(
                f"                {cls._graph_key(target)}: {cls._graph_key(target)},"
                for target in node["connections"]
            )
            blocks.append(
                "        graph.add_conditional_edges(\n"
                f"            {cls._graph_key(node['name'])},\n"
                f"            self.edges.{node['name']}_edge,\n"
                "            {\n"
                f"{destinations}\n"
                "            },\n"
                "        )"
            )
        return "\n".join(blocks)

    @classmethod
    def _generate_regular_edges_code(cls, nodes):
        """Generate ``graph.add_edge`` lines for nodes with exactly one connection."""
        return "\n".join(
            f"        graph.add_edge({cls._graph_key(node['name'])}, {cls._graph_key(node['connections'][0])})"
            for node in nodes
            if len(node["connections"]) == 1
        )

    @classmethod
    def _create_test_file(cls, agent_name: str):
        """Write a runnable ``test.py`` skeleton that builds the agent graph.

        An existing ``test.py`` is overwritten.

        Args:
            agent_name: Name of the agent.
        """
        # Doubled braces survive the f-string as literal braces in the generated file.
        test_file_content = f'''# Test file for {agent_name} agent
from graph import GraphBuilder


def test_agent():
    """
    Test the {agent_name} agent functionality.
    """
    # Create the graph
    builder = GraphBuilder()
    graph = builder.build_agent_graph()

    # TODO: Add test code here
    print("Testing {agent_name} agent...")

    # Example test
    # result = graph.invoke({{"input": "Test input"}})
    # print(f"Result: {{result}}")


if __name__ == "__main__":
    test_agent()
'''
        (cls._agent_dir(agent_name) / "test.py").write_text(test_file_content, encoding="utf-8")

    @classmethod
    def _node_status(cls, node: Dict[str, Any]) -> str:
        """Return the node's diagram status.

        Defaults to TERMINAL for START and NOT_IMPLEMENTED for other nodes.
        """
        default = "TERMINAL" if node["name"] == cls.START_NODE else "NOT_IMPLEMENTED"
        return node.get("status") or default

    @classmethod
    def _create_puml_file(cls, agent_name: str, design_data: Dict[str, List[Dict[str, Any]]]):
        """Write ``design.puml``, a PlantUML diagram of the design.

        Nodes are coloured by their ``status`` via ``!define`` colour macros.
        An explicit END node is always added. An existing file is overwritten.

        Args:
            agent_name: Name of the agent.
            design_data: The validated design data.
        """
        nodes = design_data["nodes"]
        parts = [
            f"@startuml {agent_name}\n"
            "!define NOT_IMPLEMENTED_NODE_COLOR #IndianRed\n"
            "!define IMPLEMENTED_NODE_COLOR #Gold\n"
            "!define TESTED_NODE_COLOR #LawnGreen\n"
            "!define TERMINAL_NODE_COLOR #DodgerBlue\n"
            "\n"
        ]
        for node in nodes:
            parts.append(
                f"node {node['name']} {cls._node_status(node)}_NODE_COLOR[\n"
                f"    {node['description']}\n"
                "]\n\n"
            )
        parts.append(f"node {cls.END_NODE} TERMINAL_NODE_COLOR[\n    This is the final Node !\n]\n\n")
        parts.extend(
            f"{node['name']} --> {target}\n"
            for node in nodes
            for target in node["connections"]
        )
        parts.append("\n@enduml")
        (cls._agent_dir(agent_name) / "design.puml").write_text("".join(parts), encoding="utf-8")

    @staticmethod
    def validate(agent_name: str):
        """Placeholder for validating an agent; not implemented and currently a no-op.

        Args:
            agent_name: Name of the agent to validate.
        """
        # TODO: implement agent validation.
        pass
if __name__ == "__main__":
    # One-time scaffolding of the design file. Left disabled because it rewrites
    # agents/universal_solver/design.yaml from scratch:
    # AgentBuilder.initialize("universal_solver")

    # Regenerate graph.py, test.py and design.puml from the current design.
    target_agent = "universal_solver"
    AgentBuilder.setup(target_agent)