# Spaces: Sleeping
# (Hosting-page status banner captured together with this file; it is not part of the module.)
from enum import Enum | |
from typing import Any, Callable, Dict, List | |
import networkx as nx | |
from pydantic.v1 import BaseModel, Field, validator | |
from swarms.structs.agent import Agent # noqa: F401 | |
from swarms.utils.loguru_logger import initialize_logger | |
logger = initialize_logger(log_folder="graph_workflow") | |
class NodeType(str, Enum):
    """
    Enumerates the kinds of node a graph workflow can contain.

    The class mixes in ``str``, so every member is itself a string and
    compares equal to its raw value (``NodeType.TASK == "task"``).

    Attributes:
        AGENT: Node whose work is done by an agent object.
        TASK: Node whose work is done by a plain callable.
    """

    # No per-member annotations: the members are string values, so labelling
    # AGENT with the Agent class described the wrong type and forced that
    # name to be resolvable just to build the enum.
    AGENT = "agent"
    TASK = "task"
class Node(BaseModel):
    """
    Represents a node in a graph workflow.

    Attributes:
        id (str): The unique identifier of the node.
        type (NodeType): The type of the node.
        callable (Callable, optional): The callable associated with the node. Required for task nodes.
        agent (Any, optional): The agent associated with the node.

    Raises:
        ValueError: If the node type is TASK and no callable is provided.

    Examples:
        >>> node = Node(id="task1", type=NodeType.TASK, callable=sample_task)
        >>> node = Node(id="agent1", type=NodeType.AGENT, agent=agent1)
        >>> node = Node(id="agent2", type=NodeType.AGENT, agent=agent2)
    """

    id: str
    type: NodeType
    callable: Callable = None
    agent: Any = None

    # The decorator is what registers this method with pydantic; without it
    # the check below is never executed. ``always=True`` makes it run even
    # when ``callable`` is omitted and falls back to its default of None,
    # which is exactly the case that has to be rejected for task nodes.
    @validator("callable", always=True)
    def validate_callable(cls, value, values):
        """
        Ensure that task nodes carry a callable.

        Args:
            value: The candidate value for the ``callable`` field.
            values (dict): Fields validated before ``callable``.

        Returns:
            The unchanged ``value`` when it is acceptable.

        Raises:
            ValueError: If the node type is TASK and ``value`` is None.
        """
        # ``type`` is missing from ``values`` when its own validation failed;
        # .get() lets that original error be reported instead of a KeyError.
        if values.get("type") == NodeType.TASK and value is None:
            raise ValueError("Task nodes must have a callable.")
        return value
class Edge(BaseModel):
    """
    Represents a directed link between two nodes of a graph workflow.

    Attributes:
        source (str): The ID of the node the edge starts from.
        target (str): The ID of the node the edge points to.
    """

    source: str
    target: str
class GraphWorkflow(BaseModel):
    """
    Represents a workflow graph.

    Attributes:
        nodes (Dict[str, Node]): A dictionary of nodes in the graph, where the key is the node ID and the value is the Node object.
        edges (List[Edge]): A list of edges in the graph, where each edge is represented by an Edge object.
        entry_points (List[str]): A list of node IDs that serve as entry points to the graph.
        end_points (List[str]): A list of node IDs that serve as end points of the graph.
        graph (nx.DiGraph): A directed graph object from the NetworkX library representing the workflow graph. Excluded from model export.
        max_loops (int): The number of passes over the graph that ``run`` performs. Defaults to 1.
    """

    nodes: Dict[str, Node] = Field(default_factory=dict)
    edges: List[Edge] = Field(default_factory=list)
    entry_points: List[str] = Field(default_factory=list)
    end_points: List[str] = Field(default_factory=list)
    graph: nx.DiGraph = Field(
        default_factory=nx.DiGraph, exclude=True
    )
    max_loops: int = 1

    class Config:
        # nx.DiGraph is not a pydantic-aware type.
        arbitrary_types_allowed = True

    def add_node(self, node: Node):
        """
        Adds a node to the workflow graph.

        The node is stored in ``nodes`` and mirrored into the NetworkX graph
        with its type, callable and agent as node attributes.

        Args:
            node (Node): The node object to be added.

        Raises:
            ValueError: If a node with the same ID already exists in the graph.
        """
        try:
            if node.id in self.nodes:
                raise ValueError(
                    f"Node with id {node.id} already exists."
                )
            self.nodes[node.id] = node
            self.graph.add_node(
                node.id,
                type=node.type,
                callable=node.callable,
                agent=node.agent,
            )
        except Exception as e:
            # Failures are errors, not informational events.
            logger.error(f"Error in adding node to the workflow: {e}")
            raise

    def add_edge(self, edge: Edge):
        """
        Adds an edge to the workflow graph.

        Args:
            edge (Edge): The edge object to be added.

        Raises:
            ValueError: If either the source or target node of the edge does not exist in the graph.
        """
        if (
            edge.source not in self.nodes
            or edge.target not in self.nodes
        ):
            raise ValueError(
                "Both source and target nodes must exist before adding an edge."
            )
        self.edges.append(edge)
        self.graph.add_edge(edge.source, edge.target)

    def _ensure_nodes_exist(self, node_ids: List[str]) -> None:
        """Raise ValueError for the first ID in ``node_ids`` that is not a known node."""
        for node_id in node_ids:
            if node_id not in self.nodes:
                raise ValueError(
                    f"Node with id {node_id} does not exist."
                )

    def set_entry_points(self, entry_points: List[str]):
        """
        Sets the entry points of the workflow graph.

        Args:
            entry_points (List[str]): A list of node IDs to be set as entry points.

        Raises:
            ValueError: If any of the specified node IDs do not exist in the graph.
        """
        self._ensure_nodes_exist(entry_points)
        # Store a copy so later changes to the caller's list cannot alter
        # the workflow behind its back.
        self.entry_points = list(entry_points)

    def set_end_points(self, end_points: List[str]):
        """
        Sets the end points of the workflow graph.

        Args:
            end_points (List[str]): A list of node IDs to be set as end points.

        Raises:
            ValueError: If any of the specified node IDs do not exist in the graph.
        """
        self._ensure_nodes_exist(end_points)
        # Store a copy so later changes to the caller's list cannot alter
        # the workflow behind its back.
        self.end_points = list(end_points)

    def visualize(self) -> str:
        """
        Generates a string representation of the workflow graph in the Mermaid syntax.

        Returns:
            str: The Mermaid string representation of the workflow graph.
        """
        lines = ["graph TD"]
        lines.extend(
            f" {node_id}[{node_id}]" for node_id in self.nodes
        )
        lines.extend(
            f" {edge.source} --> {edge.target}" for edge in self.edges
        )
        # Every line, including the last one, ends with a newline.
        return "\n".join(lines) + "\n"

    def run(
        self, task: str = None, *args, **kwargs
    ) -> Dict[str, Any]:
        """
        Function to run the workflow graph.

        Nodes are executed in topological order. Task nodes are called with
        no arguments; agent nodes receive ``task`` plus any extra arguments
        through their ``run`` method. The whole graph is executed
        ``max_loops`` times and the results of the last pass are returned.

        Args:
            task (str): The task to be executed by the workflow.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            Dict[str, Any]: A dictionary mapping each node ID to the result of
            its execution. Empty when ``max_loops`` is less than 1.

        Raises:
            ValueError: If no entry points or end points are defined in the graph,
                or if a node has an unsupported type.
        """
        try:
            # Validate once, up front: the checks do not depend on the loop
            # and must also apply when no pass is executed at all.
            if not self.entry_points:
                raise ValueError(
                    "At least one entry point must be defined."
                )
            if not self.end_points:
                raise ValueError(
                    "At least one end point must be defined."
                )

            # Defined before the loop so that returning is safe even when
            # max_loops is below 1 and the loop body never runs.
            execution_results: Dict[str, Any] = {}

            for _ in range(self.max_loops):
                # Perform a topological sort of the graph to ensure proper execution order
                sorted_nodes = list(nx.topological_sort(self.graph))

                # Each pass starts from a clean slate; only the last pass is returned.
                execution_results = {}

                for node_id in sorted_nodes:
                    node = self.nodes[node_id]
                    if node.type == NodeType.TASK:
                        print(f"Executing task: {node_id}")
                        result = node.callable()
                    elif node.type == NodeType.AGENT:
                        print(f"Executing agent: {node_id}")
                        result = node.agent.run(task, *args, **kwargs)
                    else:
                        # Without this branch the previous node's result
                        # would be recorded for this node.
                        raise ValueError(
                            f"Unsupported node type: {node.type}"
                        )
                    execution_results[node_id] = result

            return execution_results
        except Exception as e:
            # Failures are errors, not informational events.
            logger.error(f"Error in running the workflow: {e}")
            raise
# # Example usage
# # NOTE: OpenAIChat is not imported anywhere in this example; import it before running.
# if __name__ == "__main__":
#     from swarms import Agent
#     import os
#     from dotenv import load_dotenv
#     load_dotenv()
#     api_key = os.environ.get("OPENAI_API_KEY")
#     llm = OpenAIChat(
#         temperature=0.5, openai_api_key=api_key, max_tokens=4000
#     )
#     agent1 = Agent(llm=llm, max_loops=1, autosave=True, dashboard=True)
#     agent2 = Agent(llm=llm, max_loops=1, autosave=True, dashboard=True)
#     def sample_task():
#         print("Running sample task")
#         return "Task completed"
#     wf_graph = GraphWorkflow()
#     wf_graph.add_node(Node(id="agent1", type=NodeType.AGENT, agent=agent1))
#     wf_graph.add_node(Node(id="agent2", type=NodeType.AGENT, agent=agent2))
#     wf_graph.add_node(
#         Node(id="task1", type=NodeType.TASK, callable=sample_task)
#     )
#     wf_graph.add_edge(Edge(source="agent1", target="task1"))
#     wf_graph.add_edge(Edge(source="agent2", target="task1"))
#     wf_graph.set_entry_points(["agent1", "agent2"])
#     wf_graph.set_end_points(["task1"])
#     print(wf_graph.visualize())
#     # Run the workflow
#     results = wf_graph.run()
#     print("Execution results:", results)