Spaces:
Sleeping
Sleeping
File size: 12,499 Bytes
d8d14f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 |
import uuid
from collections import Counter
from datetime import datetime
from typing import Any, List, Optional
from pydantic import BaseModel, Field
from swarms.structs.agent import Agent
from swarms.utils.loguru_logger import initialize_logger
from swarms.utils.auto_download_check_packages import (
auto_check_and_download_package,
)
from swarms.structs.conversation import Conversation
# Module-level logger used by every class in this file; created with log_folder="tree_swarm".
logger = initialize_logger(log_folder="tree_swarm")
# Pydantic Models for Logging
class AgentLogInput(BaseModel):
    """Log record describing a task that is about to be handed to an agent."""

    # Unique identifier; defaults to a freshly generated UUID4 string.
    # The field is declared with the alias "id".
    log_id: str = Field(
        default_factory=lambda: str(uuid.uuid4()), alias="id"
    )
    # Name of the agent that receives the task.
    agent_name: str
    # Text of the task given to the agent.
    task: str
    # Creation time; defaults to datetime.utcnow() (a naive UTC value) when the
    # caller does not supply one.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
class AgentLogOutput(BaseModel):
    """Log record describing the result an agent produced for a task."""

    # Unique identifier; defaults to a freshly generated UUID4 string.
    # The field is declared with the alias "id".
    log_id: str = Field(
        default_factory=lambda: str(uuid.uuid4()), alias="id"
    )
    # Name of the agent that produced the result.
    agent_name: str
    # Whatever the agent returned; intentionally untyped (Any).
    result: Any
    # Creation time; defaults to datetime.utcnow() (a naive UTC value) when the
    # caller does not supply one.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
class TreeLog(BaseModel):
    """Log record describing one task execution handled by a tree of agents."""

    # Unique identifier; defaults to a freshly generated UUID4 string.
    # The field is declared with the alias "id".
    log_id: str = Field(
        default_factory=lambda: str(uuid.uuid4()), alias="id"
    )
    # Name of the tree that handled the task.
    tree_name: str
    # Text of the task that was executed.
    task: str
    # Name of the agent that was chosen to run the task.
    selected_agent: str
    # Creation time; defaults to datetime.utcnow() (a naive UTC value) when the
    # caller does not supply one.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    # Whatever the selected agent returned; intentionally untyped (Any).
    result: Any
def extract_keywords(prompt: Optional[str], top_n: int = 5) -> List[str]:
    """
    Extract the most frequent alphanumeric words from a prompt.

    The prompt is lower-cased and split on whitespace. Leading and trailing
    ASCII punctuation is stripped from every token so that words such as
    "analysis," or "agent." are counted instead of being discarded. Tokens
    that still contain non-alphanumeric characters (e.g. "gpt-4o") are ignored.

    Args:
        prompt (Optional[str]): Text to extract keywords from. ``None`` or an
            empty string yields an empty list.
        top_n (int): Maximum number of keywords to return.

    Returns:
        List[str]: Up to ``top_n`` words ordered by descending frequency;
        ties keep the order of first appearance.
    """
    # Local import keeps this function self-contained without touching the
    # file-level import block.
    import string

    if not prompt:
        return []

    stripped_tokens = (
        word.strip(string.punctuation) for word in prompt.lower().split()
    )
    # "".isalnum() is False, so tokens made only of punctuation drop out here.
    word_counts = Counter(
        token for token in stripped_tokens if token.isalnum()
    )
    return [word for word, _ in word_counts.most_common(top_n)]
class TreeAgent(Agent):
    """
    A specialized Agent that embeds its system prompt so that agents can be
    compared with each other and matched against tasks inside a tree.

    Attributes:
        sentence_transformers: The imported ``sentence_transformers`` module.
        embedding_model: ``SentenceTransformer("all-MiniLM-L6-v2")`` instance
            used to embed prompts and tasks.
        system_prompt_embedding: Embedding of the agent's system prompt.
        relevant_keywords (List[str]): Keywords extracted from the system prompt.
        distance (Optional[float]): ``None`` after construction; assigned later
            by the owning tree.
    """

    def __init__(
        self,
        name: str = None,
        description: str = None,
        system_prompt: str = None,
        model_name: str = "gpt-4o",
        agent_name: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """
        Initialize the underlying Agent and precompute prompt features.

        Args:
            name (str): Forwarded to ``Agent.__init__``.
            description (str): Forwarded to ``Agent.__init__``.
            system_prompt (str): Forwarded to ``Agent.__init__``; also embedded
                and mined for keywords. ``None`` is treated as an empty prompt
                for the embedding and keyword steps.
            model_name (str): Forwarded to ``Agent.__init__``.
            agent_name (Optional[str]): Forwarded to ``Agent.__init__``.
            *args: Extra positional arguments forwarded to ``Agent.__init__``.
            **kwargs: Extra keyword arguments forwarded to ``Agent.__init__``.
        """
        super().__init__(
            name=name,
            description=description,
            system_prompt=system_prompt,
            model_name=model_name,
            agent_name=agent_name,
            *args,
            **kwargs,
        )

        # sentence-transformers is imported lazily and installed on demand.
        try:
            import sentence_transformers
        except ImportError:
            auto_check_and_download_package(
                "sentence-transformers", package_manager="pip"
            )
            import sentence_transformers

        self.sentence_transformers = sentence_transformers

        # Pretrained model for embeddings.
        # NOTE(review): every TreeAgent constructs its own model instance;
        # consider sharing one instance if construction cost matters.
        self.embedding_model = sentence_transformers.SentenceTransformer(
            "all-MiniLM-L6-v2"
        )

        # system_prompt defaults to None; fall back to an empty string so the
        # embedding and keyword extraction below do not fail on None.
        prompt_text = system_prompt or ""

        self.system_prompt_embedding = self.embedding_model.encode(
            prompt_text, convert_to_tensor=True
        )

        # Automatically extract keywords from the system prompt.
        self.relevant_keywords = extract_keywords(prompt_text)

        # Filled in later from the similarity between agents' prompts.
        self.distance = None

    def calculate_distance(self, other_agent: "TreeAgent") -> float:
        """
        Calculate the distance between this agent and another agent using
        the cosine similarity of their system-prompt embeddings.

        Args:
            other_agent (TreeAgent): Another agent in the tree.

        Returns:
            float: ``1 - cosine_similarity``. 0 means the embeddings point in
            the same direction; because cosine similarity ranges over
            [-1, 1], the value can be as large as 2.
        """
        similarity = self.sentence_transformers.util.pytorch_cos_sim(
            self.system_prompt_embedding,
            other_agent.system_prompt_embedding,
        ).item()
        # Closer agents have a smaller distance.
        return 1 - similarity

    def run_task(
        self, task: str, img: str = None, *args, **kwargs
    ) -> Any:
        """
        Run a task through ``self.run`` and log its input and output.

        Args:
            task (str): The task to execute.
            img (str): Optional image argument forwarded to ``self.run``.
            *args: Extra positional arguments forwarded to ``self.run``.
            **kwargs: Extra keyword arguments forwarded to ``self.run``.

        Returns:
            Any: Whatever ``self.run`` returns.
        """
        input_log = AgentLogInput(
            agent_name=self.agent_name,
            task=task,
            timestamp=datetime.now(),
        )
        logger.info(f"Running task on {self.agent_name}: {task}")
        logger.debug(f"Input Log: {input_log.json()}")

        result = self.run(task=task, img=img, *args, **kwargs)

        output_log = AgentLogOutput(
            agent_name=self.agent_name,
            result=result,
            timestamp=datetime.now(),
        )
        logger.info(f"Task result from {self.agent_name}: {result}")
        logger.debug(f"Output Log: {output_log.json()}")
        return result

    def is_relevant_for_task(
        self, task: str, threshold: float = 0.7
    ) -> bool:
        """
        Check whether the agent is relevant for the given task using keyword
        matching first and embedding similarity as a fallback.

        Args:
            task (str): The task to be executed.
            threshold (float): Minimum cosine similarity for the
                embedding-based match.

        Returns:
            bool: True if any keyword occurs in the task (case-insensitive
            substring match) or the similarity reaches ``threshold``.
        """
        # Lower-case the task once instead of once per keyword.
        task_lower = task.lower()
        if any(
            keyword.lower() in task_lower
            for keyword in self.relevant_keywords
        ):
            return True

        # No keyword hit: fall back to semantic similarity.
        task_embedding = self.embedding_model.encode(
            task, convert_to_tensor=True
        )
        similarity = self.sentence_transformers.util.pytorch_cos_sim(
            self.system_prompt_embedding, task_embedding
        ).item()
        logger.info(
            f"Semantic similarity between task and {self.agent_name}: {similarity:.2f}"
        )
        return similarity >= threshold
class Tree:
    """A named collection of TreeAgent instances that can pick an agent for a task."""

    def __init__(self, tree_name: str, agents: List[TreeAgent]):
        """
        Build a tree of agents and compute their distances right away.

        Args:
            tree_name (str): The name of the tree.
            agents (List[TreeAgent]): The agents in the tree. The list is
                stored as given and sorted in place.
        """
        self.agents = agents
        self.tree_name = tree_name
        self.calculate_agent_distances()

    def calculate_agent_distances(self):
        """
        Assign each agent a distance and sort the agents by it.

        The first agent gets distance 0; every other agent gets its distance
        to the agent immediately before it in the current list order.
        """
        logger.info(
            f"Calculating distances between agents in tree '{self.tree_name}'"
        )

        predecessor = None
        for member in self.agents:
            if predecessor is None:
                # First agent is closest.
                member.distance = 0
            else:
                member.distance = member.calculate_distance(predecessor)
            predecessor = member

        # Order the agents by the distances just assigned.
        self.agents.sort(key=lambda member: member.distance)

    def find_relevant_agent(self, task: str) -> Optional[TreeAgent]:
        """
        Return the first agent, in the tree's current order, that reports
        itself relevant for the task.

        Args:
            task (str): The task or query that needs an agent.

        Returns:
            Optional[TreeAgent]: The first relevant agent, or None if no
            agent matches.
        """
        logger.info(
            f"Searching relevant agent in tree '{self.tree_name}' for task: {task}"
        )

        match = next(
            (
                candidate
                for candidate in self.agents
                if candidate.is_relevant_for_task(task)
            ),
            None,
        )
        if match is None:
            logger.warning(
                f"No relevant agent found in tree '{self.tree_name}' for task: {task}"
            )
        return match

    def log_tree_execution(
        self, task: str, selected_agent: TreeAgent, result: Any
    ) -> None:
        """
        Log one execution of the tree: the task, the chosen agent and its result.

        Args:
            task (str): The task that was executed.
            selected_agent (TreeAgent): The agent that ran the task.
            result (Any): The value the agent returned.
        """
        record = TreeLog(
            tree_name=self.tree_name,
            task=task,
            selected_agent=selected_agent.agent_name,
            timestamp=datetime.now(),
            result=result,
        )
        logger.info(
            f"Tree '{self.tree_name}' executed task with agent '{selected_agent.agent_name}'"
        )
        logger.debug(f"Tree Log: {record.json()}")
class ForestSwarm:
    """A collection of agent trees that routes each task to one matching agent."""

    def __init__(
        self,
        name: str = "default-forest-swarm",
        description: str = "Standard forest swarm",
        trees: Optional[List[Tree]] = None,
        shared_memory: Any = None,
        rules: str = None,
        *args,
        **kwargs,
    ):
        """
        Initialize the structure with multiple trees of agents.

        Args:
            name (str): Name of the swarm.
            description (str): Description of the swarm.
            trees (Optional[List[Tree]]): Trees in the structure. Defaults to
                a new empty list per instance.
            shared_memory (Any): Stored on the instance as given.
            rules (str): Passed to the Conversation as ``rules``.
            *args: Accepted for compatibility; not used here.
            **kwargs: Accepted for compatibility; not used here.
        """
        self.name = name
        self.description = description
        # A fresh list per instance avoids sharing one default list object
        # between every ForestSwarm created without explicit trees.
        self.trees = trees if trees is not None else []
        self.shared_memory = shared_memory
        self.save_file_path = f"forest_swarm_{uuid.uuid4().hex}.json"
        self.conversation = Conversation(
            time_enabled=True,
            auto_save=True,
            save_filepath=self.save_file_path,
            rules=rules,
        )

    def _find_tree_and_agent(self, task: str):
        """Return the first (tree, agent) pair relevant for the task, or (None, None)."""
        logger.info(
            f"Searching for the most relevant tree for task: {task}"
        )
        for tree in self.trees:
            agent = tree.find_relevant_agent(task)
            if agent:
                return tree, agent
        logger.warning(f"No relevant tree found for task: {task}")
        return None, None

    def find_relevant_tree(self, task: str) -> Optional[Tree]:
        """
        Find the first tree that contains an agent relevant for the task.

        Args:
            task (str): The task or query for which a tree is needed.

        Returns:
            Optional[Tree]: The first matching tree, or None if no tree has a
            relevant agent.
        """
        tree, _ = self._find_tree_and_agent(task)
        return tree

    def run(self, task: str, img: str = None, *args, **kwargs) -> Any:
        """
        Execute the task with the first relevant agent found across the trees.

        Args:
            task (str): The task or query to be executed.
            img (str): Optional image argument forwarded to the agent.
            *args: Extra positional arguments forwarded to the agent.
            **kwargs: Extra keyword arguments forwarded to the agent.

        Returns:
            Any: The agent's result; a fixed message string when no agent is
            relevant; None when an exception was raised (it is logged, not
            re-raised).
        """
        try:
            logger.info(
                f"Running task across MultiAgentTreeStructure: {task}"
            )
            # Single lookup: the tree and its agent are resolved together so
            # the relevance check is not run a second time.
            relevant_tree, agent = self._find_tree_and_agent(task)
            if agent is None:
                logger.error(
                    "Task could not be completed: No relevant agent or tree found."
                )
                return "No relevant agent found to handle this task."

            result = agent.run_task(task, img=img, *args, **kwargs)
            relevant_tree.log_tree_execution(task, agent, result)
            return result
        except Exception as error:
            # Best-effort behaviour is kept: failures are logged and None is
            # returned instead of propagating.
            logger.error(
                f"Error detected in the ForestSwarm, check your inputs and try again ;) {error}"
            )
            return None
# # Example Usage:
# # Create agents with varying system prompts and dynamically generated distances/keywords
# agents_tree1 = [
# TreeAgent(
# system_prompt="Stock Analysis Agent",
# agent_name="Stock Analysis Agent",
# ),
# TreeAgent(
# system_prompt="Financial Planning Agent",
# agent_name="Financial Planning Agent",
# ),
# TreeAgent(
# agent_name="Retirement Strategy Agent",
# system_prompt="Retirement Strategy Agent",
# ),
# ]
# agents_tree2 = [
# TreeAgent(
# system_prompt="Tax Filing Agent",
# agent_name="Tax Filing Agent",
# ),
# TreeAgent(
# system_prompt="Investment Strategy Agent",
# agent_name="Investment Strategy Agent",
# ),
# TreeAgent(
# system_prompt="ROTH IRA Agent", agent_name="ROTH IRA Agent"
# ),
# ]
# # Create trees
# tree1 = Tree(tree_name="Financial Tree", agents=agents_tree1)
# tree2 = Tree(tree_name="Investment Tree", agents=agents_tree2)
# # Create the ForestSwarm
# multi_agent_structure = ForestSwarm(trees=[tree1, tree2])
# # Run a task
# task = "Our company is incorporated in delaware, how do we do our taxes for free?"
# output = multi_agent_structure.run(task)
# print(output)
|