Spaces:
Sleeping
Sleeping
| from enum import Enum | |
| from typing import Any, Callable, Dict, List | |
| import networkx as nx | |
| from pydantic.v1 import BaseModel, Field, validator | |
| from swarms.structs.agent import Agent # noqa: F401 | |
| from swarms.utils.loguru_logger import initialize_logger | |
# Folder name passed as ``log_folder`` to initialize_logger for this module.
_LOG_FOLDER = "graph_workflow"
logger = initialize_logger(log_folder=_LOG_FOLDER)
class NodeType(str, Enum):
    """
    Kinds of node a graph workflow can contain.

    Because the enum also subclasses ``str``, members compare equal to their
    string values (``NodeType.TASK == "task"``).
    """

    # Members are deliberately unannotated: their values are plain strings, and
    # annotating enum members (previously ``AGENT: Agent``) is both incorrect
    # and rejected by type checkers.

    # Node whose work is delegated to an agent object's ``run`` method.
    AGENT = "agent"
    # Node whose work is a callable invoked with no arguments.
    TASK = "task"
class Node(BaseModel):
    """
    Represents a node in a graph workflow.

    Attributes:
        id (str): The unique identifier of the node.
        type (NodeType): The type of the node.
        callable (Callable, optional): The callable associated with the node.
            Required for task nodes; it is invoked with no arguments.
        agent (Any, optional): The agent associated with the node. For agent
            nodes it is expected to provide ``run(task, *args, **kwargs)``
            (this is not validated here).

    Raises:
        ValueError: If the node type is TASK and no callable is provided.
            pydantic reports this as a ``ValidationError``, which is a
            ``ValueError`` subclass.

    Examples:
        >>> node = Node(id="task1", type=NodeType.TASK, callable=sample_task)
        >>> node = Node(id="agent1", type=NodeType.AGENT, agent=agent1)
        >>> node = Node(id="agent2", type=NodeType.AGENT, agent=agent2)
    """

    id: str
    type: NodeType
    # pydantic v1 treats a ``None`` default as making the field optional.
    callable: Callable = None
    agent: Any = None

    # ``always=True`` is required: without it pydantic skips validators for
    # fields left at their default, so a TASK node that omits ``callable``
    # would never be checked.
    @validator("callable", always=True)
    def validate_callable(cls, value, values):
        """Reject TASK nodes that have no callable to execute."""
        # ``values`` lacks "type" if that field failed its own validation;
        # ``.get`` avoids masking that error with a KeyError.
        if values.get("type") == NodeType.TASK and value is None:
            raise ValueError("Task nodes must have a callable.")
        return value
class Edge(BaseModel):
    """
    A directed connection between two nodes of a graph workflow.

    Attributes:
        source (str): ID of the node the edge starts from.
        target (str): ID of the node the edge points to.
    """

    # Upstream node ID (the edge's tail).
    source: str
    # Downstream node ID (the edge's head).
    target: str
class GraphWorkflow(BaseModel):
    """
    Represents a workflow graph.

    Nodes and edges are kept both as pydantic fields (``nodes``/``edges``) and
    mirrored into a NetworkX ``DiGraph`` (``graph``). ``run`` uses the graph
    to compute a topological execution order.

    Attributes:
        nodes (Dict[str, Node]): Nodes in the graph, keyed by node ID.
        edges (List[Edge]): Directed edges between nodes.
        entry_points (List[str]): Node IDs that serve as entry points to the graph.
        end_points (List[str]): Node IDs that serve as end points of the graph.
        graph (nx.DiGraph): NetworkX mirror of ``nodes``/``edges``; excluded
            from serialization.
        max_loops (int): How many times ``run`` executes the whole graph.
    """

    nodes: Dict[str, Node] = Field(default_factory=dict)
    edges: List[Edge] = Field(default_factory=list)
    entry_points: List[str] = Field(default_factory=list)
    end_points: List[str] = Field(default_factory=list)
    graph: nx.DiGraph = Field(
        default_factory=nx.DiGraph, exclude=True
    )
    max_loops: int = 1

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, **data: Any):
        """
        Build the workflow and sync constructor-supplied nodes and edges into ``graph``.

        Raises:
            ValueError: If a constructor-supplied edge references a node ID
                that is not present in ``nodes``.
        """
        super().__init__(**data)
        # Nodes and edges passed straight to the constructor bypass
        # add_node/add_edge. Without this sync, run() would topologically sort
        # an empty DiGraph and silently execute nothing.
        for node_id, node in self.nodes.items():
            self._add_graph_node(node_id, node)
        for edge in self.edges:
            self._require_edge_endpoints(edge)
            self.graph.add_edge(edge.source, edge.target)

    def _add_graph_node(self, node_id: str, node: Node) -> None:
        """Mirror ``node`` into the NetworkX graph under ``node_id``."""
        self.graph.add_node(
            node_id,
            type=node.type,
            callable=node.callable,
            agent=node.agent,
        )

    def _require_edge_endpoints(self, edge: Edge) -> None:
        """Raise ValueError unless both ends of ``edge`` are known nodes."""
        if (
            edge.source not in self.nodes
            or edge.target not in self.nodes
        ):
            raise ValueError(
                "Both source and target nodes must exist before adding an edge."
            )

    def _require_known_nodes(self, node_ids: List[str]) -> None:
        """Raise ValueError for the first ID in ``node_ids`` not in the graph."""
        for node_id in node_ids:
            if node_id not in self.nodes:
                raise ValueError(
                    f"Node with id {node_id} does not exist."
                )

    def add_node(self, node: Node):
        """
        Adds a node to the workflow graph.

        Args:
            node (Node): The node object to be added.

        Raises:
            ValueError: If a node with the same ID already exists in the graph.
        """
        try:
            if node.id in self.nodes:
                raise ValueError(
                    f"Node with id {node.id} already exists."
                )
            self.nodes[node.id] = node
            self._add_graph_node(node.id, node)
        except Exception as e:
            logger.error(f"Error in adding node to the workflow: {e}")
            # Bare ``raise`` preserves the original traceback.
            raise

    def add_edge(self, edge: Edge):
        """
        Adds an edge to the workflow graph.

        Args:
            edge (Edge): The edge object to be added.

        Raises:
            ValueError: If either the source or target node of the edge does
                not exist in the graph.
        """
        self._require_edge_endpoints(edge)
        self.edges.append(edge)
        self.graph.add_edge(edge.source, edge.target)

    def set_entry_points(self, entry_points: List[str]):
        """
        Sets the entry points of the workflow graph.

        Args:
            entry_points (List[str]): A list of node IDs to be set as entry points.

        Raises:
            ValueError: If any of the specified node IDs do not exist in the graph.
        """
        self._require_known_nodes(entry_points)
        # Store a copy so later mutation of the caller's list has no effect.
        self.entry_points = list(entry_points)

    def set_end_points(self, end_points: List[str]):
        """
        Sets the end points of the workflow graph.

        Args:
            end_points (List[str]): A list of node IDs to be set as end points.

        Raises:
            ValueError: If any of the specified node IDs do not exist in the graph.
        """
        self._require_known_nodes(end_points)
        # Store a copy so later mutation of the caller's list has no effect.
        self.end_points = list(end_points)

    def visualize(self) -> str:
        """
        Generates a string representation of the workflow graph in Mermaid syntax.

        Returns:
            str: The Mermaid string representation of the workflow graph,
            with one line per node followed by one line per edge.
        """
        parts = ["graph TD\n"]
        parts.extend(f" {node_id}[{node_id}]\n" for node_id in self.nodes)
        parts.extend(
            f" {edge.source} --> {edge.target}\n" for edge in self.edges
        )
        return "".join(parts)

    def _execute_node(self, node_id: str, task: str, args: tuple, kwargs: dict) -> Any:
        """
        Run a single node and return its result.

        ``args``/``kwargs`` are passed as containers rather than unpacked, so
        that caller kwargs can never collide with this helper's parameters.

        Raises:
            ValueError: If the node has an unsupported type.
        """
        node = self.nodes[node_id]
        if node.type == NodeType.TASK:
            print(f"Executing task: {node_id}")
            # Task callables are invoked without the task or any arguments.
            return node.callable()
        if node.type == NodeType.AGENT:
            print(f"Executing agent: {node_id}")
            return node.agent.run(task, *args, **kwargs)
        raise ValueError(f"Unsupported node type: {node.type}")

    def run(
        self, task: str = None, *args, **kwargs
    ) -> Dict[str, Any]:
        """
        Function to run the workflow graph.

        Every node is executed in topological order, ``max_loops`` times. Only
        the results of the final pass are returned. If ``max_loops`` is less
        than 1, an empty dict is returned after validation.

        NOTE(review): ``entry_points``/``end_points`` are only checked for
        being non-empty. Every node in the graph runs regardless of
        reachability from the entry points. Node outputs are not forwarded to
        downstream nodes; every agent receives the same ``task``.

        Args:
            task (str): The task to be executed by the workflow.
            *args: Variable length argument list forwarded to agents' ``run``.
            **kwargs: Arbitrary keyword arguments forwarded to agents' ``run``.

        Returns:
            Dict[str, Any]: A dictionary mapping node IDs to their results.

        Raises:
            ValueError: If no entry points or end points are defined in the graph.
        """
        try:
            # Loop-invariant checks and ordering are done once, not per pass.
            if not self.entry_points:
                raise ValueError(
                    "At least one entry point must be defined."
                )
            if not self.end_points:
                raise ValueError(
                    "At least one end point must be defined."
                )
            # Raises networkx's NetworkXUnfeasible if the graph has a cycle.
            sorted_nodes = list(nx.topological_sort(self.graph))

            # Initialized up front so a max_loops < 1 returns {} instead of
            # failing on an unbound name.
            execution_results: Dict[str, Any] = {}
            for _ in range(self.max_loops):
                execution_results = {}
                for node_id in sorted_nodes:
                    execution_results[node_id] = self._execute_node(
                        node_id, task, args, kwargs
                    )
            return execution_results
        except Exception as e:
            logger.error(f"Error in running the workflow: {e}")
            # Bare ``raise`` preserves the original traceback.
            raise
| # # Example usage | |
| # if __name__ == "__main__": | |
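# from swarms import Agent
# NOTE: OpenAIChat (used below) must also be imported. In recent swarms
# releases it lives in a separate package (e.g. `from swarm_models import
# OpenAIChat`); verify the import path for your installed version.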
| # import os | |
| # from dotenv import load_dotenv | |
| # load_dotenv() | |
| # api_key = os.environ.get("OPENAI_API_KEY") | |
| # llm = OpenAIChat( | |
| # temperature=0.5, openai_api_key=api_key, max_tokens=4000 | |
| # ) | |
| # agent1 = Agent(llm=llm, max_loops=1, autosave=True, dashboard=True) | |
| # agent2 = Agent(llm=llm, max_loops=1, autosave=True, dashboard=True) | |
| # def sample_task(): | |
| # print("Running sample task") | |
| # return "Task completed" | |
| # wf_graph = GraphWorkflow() | |
| # wf_graph.add_node(Node(id="agent1", type=NodeType.AGENT, agent=agent1)) | |
| # wf_graph.add_node(Node(id="agent2", type=NodeType.AGENT, agent=agent2)) | |
| # wf_graph.add_node( | |
| # Node(id="task1", type=NodeType.TASK, callable=sample_task) | |
| # ) | |
| # wf_graph.add_edge(Edge(source="agent1", target="task1")) | |
| # wf_graph.add_edge(Edge(source="agent2", target="task1")) | |
| # wf_graph.set_entry_points(["agent1", "agent2"]) | |
| # wf_graph.set_end_points(["task1"]) | |
| # print(wf_graph.visualize()) | |
| # # Run the workflow | |
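# # Every agent node receives this task; with no argument it would get None.
# results = wf_graph.run("Summarize the benefits of graph-based workflows.")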
| # print("Execution results:", results) | |