Spaces:
Sleeping
Sleeping
import string
import uuid
from collections import Counter
from datetime import datetime, timezone
from typing import Any, List, Optional

from pydantic import BaseModel, Field

from swarms.structs.agent import Agent
from swarms.utils.loguru_logger import initialize_logger
from swarms.utils.auto_download_check_packages import (
    auto_check_and_download_package,
)
from swarms.structs.conversation import Conversation
# Module-level logger shared by every class below. initialize_logger is given
# log_folder="tree_swarm" (presumably where its log files are written — confirm
# in swarms.utils.loguru_logger).
logger = initialize_logger(log_folder="tree_swarm")
# Pydantic Models for Logging
class AgentLogInput(BaseModel):
    """Log record describing a task handed to an agent."""

    # Random UUID4 string; the field's alias is "id".
    log_id: str = Field(
        default_factory=lambda: str(uuid.uuid4()), alias="id"
    )
    agent_name: str
    task: str
    # Timezone-aware UTC. datetime.utcnow() is deprecated since Python 3.12
    # and returns a naive value that is easily mistaken for local time.
    timestamp: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc)
    )
class AgentLogOutput(BaseModel):
    """Log record describing the result an agent produced."""

    # Random UUID4 string; the field's alias is "id".
    log_id: str = Field(
        default_factory=lambda: str(uuid.uuid4()), alias="id"
    )
    agent_name: str
    result: Any
    # Timezone-aware UTC. datetime.utcnow() is deprecated since Python 3.12
    # and returns a naive value that is easily mistaken for local time.
    timestamp: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc)
    )
class TreeLog(BaseModel):
    """Log record describing which agent of a tree handled a task."""

    # Random UUID4 string; the field's alias is "id".
    log_id: str = Field(
        default_factory=lambda: str(uuid.uuid4()), alias="id"
    )
    tree_name: str
    task: str
    selected_agent: str
    # Timezone-aware UTC. datetime.utcnow() is deprecated since Python 3.12
    # and returns a naive value that is easily mistaken for local time.
    timestamp: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc)
    )
    result: Any
def extract_keywords(prompt: str, top_n: int = 5) -> List[str]:
    """
    Return the ``top_n`` most frequent words in ``prompt``.

    A simplified keyword extraction that splits on whitespace instead of
    using NLTK tokenization. Each word is lower-cased and stripped of
    surrounding ASCII punctuation, so "analysis," counts as "analysis".
    Tokens that are still not purely alphanumeric afterwards (e.g. "don't")
    are discarded. No stop-word filtering is applied. Words with equal
    counts keep the order in which they first appear.

    Args:
        prompt (str): Text to extract keywords from; ``None`` or "" yields [].
        top_n (int): Maximum number of keywords to return.

    Returns:
        List[str]: Keywords ordered from most to least frequent.
    """
    if not prompt:
        return []
    # Strip punctuation before the alnum filter; previously "agent." or
    # "analysis," failed isalnum() and were silently dropped.
    stripped = (
        word.strip(string.punctuation) for word in prompt.lower().split()
    )
    word_counts = Counter(word for word in stripped if word.isalnum())
    return [word for word, _ in word_counts.most_common(top_n)]
class TreeAgent(Agent):
    """
    A specialized Agent that embeds its system prompt and extracts keywords
    from it. This lets a Tree measure distances between agents and decide
    whether an agent is relevant for a task.
    """

    # Name of the sentence-transformers model used for prompt embeddings.
    _EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
    # Loaded SentenceTransformer models keyed by model name, shared by all
    # TreeAgents so the model is loaded once instead of once per agent.
    _embedding_model_cache: dict = {}

    def __init__(
        self,
        name: str = None,
        description: str = None,
        system_prompt: str = None,
        model_name: str = "gpt-4o",
        agent_name: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """
        Create the agent and precompute its prompt embedding and keywords.

        Args:
            name (str): Forwarded to ``Agent``; also used as ``agent_name``
                when ``agent_name`` is not given.
            description (str): Forwarded to ``Agent``.
            system_prompt (str): Forwarded to ``Agent`` and used to build the
                embedding and keywords (``None`` is treated as "").
            model_name (str): Forwarded to ``Agent``.
            agent_name (Optional[str]): Forwarded to ``Agent``.
            *args, **kwargs: Forwarded to ``Agent``.
        """
        # The log models below require a str agent_name, so fall back to
        # `name` instead of forwarding None (assumes Agent exposes this as
        # self.agent_name, which run_task reads — confirm in Agent).
        if agent_name is None:
            agent_name = name
        super().__init__(
            name=name,
            description=description,
            system_prompt=system_prompt,
            model_name=model_name,
            agent_name=agent_name,
            *args,
            **kwargs,
        )
        try:
            import sentence_transformers
        except ImportError:
            auto_check_and_download_package(
                "sentence-transformers", package_manager="pip"
            )
            import sentence_transformers
        self.sentence_transformers = sentence_transformers

        # Pretrained model for embeddings, loaded once and then reused.
        cache = TreeAgent._embedding_model_cache
        model_key = self._EMBEDDING_MODEL_NAME
        if model_key not in cache:
            cache[model_key] = sentence_transformers.SentenceTransformer(
                model_key
            )
        self.embedding_model = cache[model_key]

        # encode(None) and extract_keywords(None) both fail, so embed an
        # empty prompt instead when no system prompt was given.
        prompt_text = system_prompt or ""
        self.system_prompt_embedding = self.embedding_model.encode(
            prompt_text, convert_to_tensor=True
        )
        # Automatically extract keywords from system prompt
        self.relevant_keywords = extract_keywords(prompt_text)
        # Set later by Tree.calculate_agent_distances().
        self.distance = None

    def calculate_distance(self, other_agent: "TreeAgent") -> float:
        """
        Calculate the distance between this agent and another agent using
        embedding similarity.

        Args:
            other_agent (TreeAgent): Another agent in the tree.

        Returns:
            float: ``1 - cosine_similarity`` of the two system-prompt
            embeddings. It is 0 for identical directions and can reach 2 for
            opposite ones, because cosine similarity lies in [-1, 1].
        """
        similarity = self.sentence_transformers.util.pytorch_cos_sim(
            self.system_prompt_embedding,
            other_agent.system_prompt_embedding,
        ).item()
        # Closer agents have a smaller distance.
        return 1 - similarity

    def run_task(
        self, task: str, img: str = None, *args, **kwargs
    ) -> Any:
        """
        Run ``task`` through ``self.run`` and log the input and the result.

        Args:
            task (str): The task to execute.
            img (str, optional): Forwarded to ``self.run`` as ``img``.
            *args, **kwargs: Forwarded to ``self.run``.

        Returns:
            Any: Whatever ``self.run`` returns.
        """
        # Let the log models stamp their own UTC timestamp. Passing naive
        # local datetime.now() disagreed with the models' UTC default.
        input_log = AgentLogInput(
            agent_name=self.agent_name,
            task=task,
        )
        logger.info(f"Running task on {self.agent_name}: {task}")
        logger.debug(f"Input Log: {input_log.json()}")
        result = self.run(task=task, img=img, *args, **kwargs)
        output_log = AgentLogOutput(
            agent_name=self.agent_name,
            result=result,
        )
        logger.info(f"Task result from {self.agent_name}: {result}")
        logger.debug(f"Output Log: {output_log.json()}")
        return result

    def is_relevant_for_task(
        self, task: str, threshold: float = 0.7
    ) -> bool:
        """
        Check whether this agent is relevant for ``task``.

        A keyword match short-circuits to True. Otherwise the cosine
        similarity between the task embedding and the system-prompt embedding
        is compared against ``threshold``.

        Args:
            task (str): The task to be executed.
            threshold (float): Minimum cosine similarity for a semantic match.

        Returns:
            bool: True if the agent is relevant, False otherwise.
        """
        # NOTE(review): this is a substring test. It lets "tax" match "taxes",
        # which routing may rely on. Because extract_keywords does no
        # stop-word filtering, a short common keyword such as "a" would also
        # match almost any task and skip the semantic check.
        keyword_match = any(
            keyword.lower() in task.lower()
            for keyword in self.relevant_keywords
        )
        if keyword_match:
            return True

        # No keyword matched: fall back to embedding similarity.
        task_embedding = self.embedding_model.encode(
            task, convert_to_tensor=True
        )
        similarity = self.sentence_transformers.util.pytorch_cos_sim(
            self.system_prompt_embedding, task_embedding
        ).item()
        logger.info(
            f"Semantic similarity between task and {self.agent_name}: {similarity:.2f}"
        )
        return similarity >= threshold
class Tree:
    """A named group of TreeAgents, ordered by prompt-embedding distance."""

    def __init__(self, tree_name: str, agents: List[TreeAgent]):
        """
        Initializes a tree of agents.

        Args:
            tree_name (str): The name of the tree.
            agents (List[TreeAgent]): The agents in the tree. The list is
                copied, so the caller's list is not reordered.
        """
        self.tree_name = tree_name
        # Copy: calculate_agent_distances() sorts self.agents in place, which
        # previously reordered the list the caller passed in.
        self.agents = list(agents)
        self.calculate_agent_distances()

    def calculate_agent_distances(self):
        """
        Assign each agent a ``distance`` and sort the agents by it.

        Each agent's distance is measured to the agent immediately before it
        in the current list order; the first agent gets 0. The list is then
        sorted in place by ascending distance.
        """
        logger.info(
            f"Calculating distances between agents in tree '{self.tree_name}'"
        )
        for i, agent in enumerate(self.agents):
            if i > 0:
                agent.distance = agent.calculate_distance(
                    self.agents[i - 1]
                )
            else:
                agent.distance = 0  # First agent is the reference point
        # Sort agents by distance after calculation
        self.agents.sort(key=lambda agent: agent.distance)

    def find_relevant_agent(self, task: str) -> Optional[TreeAgent]:
        """
        Return the first agent (in sorted order) relevant for ``task``.

        Relevance is decided by ``TreeAgent.is_relevant_for_task``, which uses
        keyword and semantic similarity matching.

        Args:
            task (str): The task or query for which we need a relevant agent.

        Returns:
            Optional[TreeAgent]: The first relevant agent, or None if none match.
        """
        logger.info(
            f"Searching relevant agent in tree '{self.tree_name}' for task: {task}"
        )
        for agent in self.agents:
            if agent.is_relevant_for_task(task):
                return agent
        logger.warning(
            f"No relevant agent found in tree '{self.tree_name}' for task: {task}"
        )
        return None

    def log_tree_execution(
        self, task: str, selected_agent: TreeAgent, result: Any
    ) -> None:
        """
        Log which agent of this tree executed ``task`` and what it returned.

        Args:
            task (str): The executed task.
            selected_agent (TreeAgent): The agent that ran the task.
            result (Any): The agent's result.
        """
        # Let TreeLog stamp its own UTC timestamp rather than naive local time.
        tree_log = TreeLog(
            tree_name=self.tree_name,
            task=task,
            selected_agent=selected_agent.agent_name,
            result=result,
        )
        logger.info(
            f"Tree '{self.tree_name}' executed task with agent '{selected_agent.agent_name}'"
        )
        logger.debug(f"Tree Log: {tree_log.json()}")
class ForestSwarm:
    """
    Routes a task to the first tree that contains a relevant TreeAgent and
    runs the task on that agent.
    """

    def __init__(
        self,
        name: str = "default-forest-swarm",
        description: str = "Standard forest swarm",
        trees: Optional[List[Tree]] = None,
        shared_memory: Any = None,
        rules: str = None,
        *args,
        **kwargs,
    ):
        """
        Initializes the structure with multiple trees of agents.

        Args:
            name (str): Name of the swarm.
            description (str): Description of the swarm.
            trees (Optional[List[Tree]]): Trees to route tasks across.
                Defaults to a new empty list for each instance.
            shared_memory (Any): Stored as ``self.shared_memory``.
            rules (str): Forwarded to the Conversation as ``rules``.
            *args, **kwargs: Accepted but currently unused.
        """
        self.name = name
        self.description = description
        # A fresh list per instance; the old `trees=[]` default was a single
        # list shared by every ForestSwarm created without `trees`.
        self.trees = trees if trees is not None else []
        self.shared_memory = shared_memory
        # Unique file name per swarm, handed to Conversation with auto_save.
        self.save_file_path = f"forest_swarm_{uuid.uuid4().hex}.json"
        self.conversation = Conversation(
            time_enabled=True,
            auto_save=True,
            save_filepath=self.save_file_path,
            rules=rules,
        )

    def _find_tree_and_agent(self, task: str) -> tuple:
        """Return ``(tree, agent)`` for the first tree with a relevant agent, else ``(None, None)``."""
        logger.info(
            f"Searching for the most relevant tree for task: {task}"
        )
        for tree in self.trees:
            agent = tree.find_relevant_agent(task)
            if agent is not None:
                return tree, agent
        logger.warning(f"No relevant tree found for task: {task}")
        return None, None

    def find_relevant_tree(self, task: str) -> Optional[Tree]:
        """
        Finds the first tree containing an agent relevant to the given task.

        Args:
            task (str): The task or query for which we need a relevant tree.

        Returns:
            Optional[Tree]: The first matching tree, or None if no match found.
        """
        tree, _ = self._find_tree_and_agent(task)
        return tree

    def run(self, task: str, img: str = None, *args, **kwargs) -> Any:
        """
        Execute ``task`` on the first relevant agent found across the trees.

        Args:
            task (str): The task or query to be executed.
            img (str, optional): Forwarded to ``TreeAgent.run_task``.
            *args, **kwargs: Forwarded to ``TreeAgent.run_task``.

        Returns:
            Any: The agent's result. If no agent matches, returns the string
            "No relevant agent found to handle this task.". If an exception
            occurs, it is logged and None is returned.
        """
        try:
            logger.info(
                f"Running task across MultiAgentTreeStructure: {task}"
            )
            # Single pass: keep the agent found while searching trees instead
            # of re-running find_relevant_agent (and every embedding) on the
            # winning tree. This also closes the path that returned None.
            relevant_tree, agent = self._find_tree_and_agent(task)
            if relevant_tree is not None and agent is not None:
                # img is positional to match run_task(task, img, *args); with
                # img=img, any extra positional args collided with `img`.
                result = agent.run_task(task, img, *args, **kwargs)
                relevant_tree.log_tree_execution(task, agent, result)
                return result
            logger.error(
                "Task could not be completed: No relevant agent or tree found."
            )
            return "No relevant agent found to handle this task."
        except Exception as error:
            # Best-effort boundary: log with traceback, return None as before.
            logger.exception(
                f"Error detected in the ForestSwarm, check your inputs and try again ;) {error}"
            )
            return None
# # Example Usage: | |
# # Create agents with varying system prompts and dynamically generated distances/keywords | |
# agents_tree1 = [ | |
# TreeAgent( | |
# system_prompt="Stock Analysis Agent", | |
# agent_name="Stock Analysis Agent", | |
# ), | |
# TreeAgent( | |
# system_prompt="Financial Planning Agent", | |
# agent_name="Financial Planning Agent", | |
# ), | |
# TreeAgent( | |
# agent_name="Retirement Strategy Agent", | |
# system_prompt="Retirement Strategy Agent", | |
# ), | |
# ] | |
# agents_tree2 = [ | |
# TreeAgent( | |
# system_prompt="Tax Filing Agent", | |
# agent_name="Tax Filing Agent", | |
# ), | |
# TreeAgent( | |
# system_prompt="Investment Strategy Agent", | |
# agent_name="Investment Strategy Agent", | |
# ), | |
# TreeAgent( | |
# system_prompt="ROTH IRA Agent", agent_name="ROTH IRA Agent" | |
# ), | |
# ] | |
# # Create trees | |
# tree1 = Tree(tree_name="Financial Tree", agents=agents_tree1) | |
# tree2 = Tree(tree_name="Investment Tree", agents=agents_tree2) | |
# # Create the ForestSwarm | |
# multi_agent_structure = ForestSwarm(trees=[tree1, tree2]) | |
# # Run a task | |
# task = "Our company is incorporated in delaware, how do we do our taxes for free?" | |
# output = multi_agent_structure.run(task) | |
# print(output) | |