# omni_bot/swarms/swarms/scable_groupchat.py
# Zack Zitting Bradshaw
# Upload folder using huggingface_hub
# 4962437
import logging
import uuid
from enum import Enum
from typing import Any, Optional

import chromadb
from chromadb.utils import embedding_functions

from swarms.workers.worker import Worker
class TaskStatus(Enum):
    """Enumeration of the states a task can be tagged with.

    The integer values are spelled out explicitly (1 through 4) rather than
    generated, so they do not shift if members are added or reordered.
    """

    QUEUED = 1
    RUNNING = 2
    COMPLETED = 3
    FAILED = 4
class ScalableGroupChat:
    """Group chat between several ``Worker`` agents backed by a Chroma collection.

    Intended flow::

        Worker -> ScalableGroupChat(Worker * N)
            -> messages/results are embedded and stored in a Chroma collection
            -> stored entries are retrieved by querying that collection
            -> any worker may address any other worker by its integer ID

    Attributes:
        workers: The ``Worker`` instances, named ``Worker-0`` .. ``Worker-(N-1)``.
        worker_count: Number of workers that were created.
        collection_name: Name of the Chroma collection used as shared storage.
        api_key: OpenAI API key handed to every worker and to ``embed``.
        chroma_client: The Chroma client that owns ``collection``.
        collection: The Chroma collection all read/write methods operate on.
    """

    def __init__(
        self,
        worker_count: int = 5,
        collection_name: str = "swarm",
        api_key: Optional[str] = None,
    ):
        """Create the workers and the shared Chroma collection.

        Args:
            worker_count: How many ``Worker`` instances to create.
            collection_name: Name of the Chroma collection to open or create.
            api_key: OpenAI API key passed to each worker and used by ``embed``.
        """
        self.worker_count = worker_count
        self.collection_name = collection_name
        self.api_key = api_key

        # One uniquely named worker per slot; the list index is the agent ID
        # that chat() validates against.
        self.workers = [
            Worker(openai_api_key=api_key, ai_name=f"Worker-{index}")
            for index in range(worker_count)
        ]

        # Every other method reads self.collection, so it must exist as soon
        # as the instance does. get_or_create avoids failing when a collection
        # with this name already exists on the client.
        # NOTE(review): documents added without explicit embeddings are
        # embedded by the collection's own embedding function - confirm its
        # dimensionality matches the vectors passed to chat() and
        # update_vector_db().
        self.chroma_client = chromadb.Client()
        self.collection = self.chroma_client.get_or_create_collection(
            name=collection_name
        )

    def embed(self, input, model_name: str = "text-embedding-ada-002"):
        """Embed ``input`` with Chroma's ``OpenAIEmbeddingFunction``.

        Args:
            input: Handed unchanged to the embedding function. Chroma's
                embedding functions take a list of documents, so wrap a single
                text as ``[text]``.
            model_name: OpenAI embedding model to use.

        Returns:
            Whatever the embedding function returns for ``input``.
        """
        embedder = embedding_functions.OpenAIEmbeddingFunction(
            api_key=self.api_key,
            model_name=model_name,
        )
        return embedder(input)

    def retrieve_results(self, agent_id: int) -> Any:
        """Query the collection using the agent ID as the query text.

        Args:
            agent_id: ID of the agent; its string form is the query text.

        Returns:
            The result of ``collection.query`` limited to 10 entries.

        Raises:
            Exception: Any error from the query is logged and re-raised.
        """
        try:
            return self.collection.query(
                query_texts=[str(agent_id)],
                n_results=10,
            )
        except Exception as e:
            logging.error(f"Failed to retrieve results from agent {agent_id}. Error {e}")
            raise

    def update_vector_db(self, data) -> None:
        """Store one pre-computed vector in the collection.

        Args:
            data: Mapping with a ``"vector"`` entry (the embedding) and a
                ``"task_id"`` entry; the stringified task ID is used both as
                the document text and as the record ID.

        Raises:
            Exception: Any error (including a missing key) is logged and
                re-raised.
        """
        try:
            task_key = str(data["task_id"])
            self.collection.add(
                embeddings=[data["vector"]],
                documents=[task_key],
                ids=[task_key],
            )
        except Exception as e:
            logging.error(f"Failed to update the vector database. Error: {e}")
            raise

    def get_vector_db(self):
        """Return the Chroma collection used by this group chat."""
        return self.collection

    def append_to_db(self, result: str):
        """Append a result string to the collection as a new document.

        Args:
            result: The text to store.

        Raises:
            Exception: Any error from the collection is logged and re-raised.
        """
        try:
            self.collection.add(
                documents=[result],
                # id(result) is a memory address: it is reused once objects are
                # freed and is shared by identical interned strings, so it is
                # not a safe record ID. A random UUID cannot collide like that.
                ids=[uuid.uuid4().hex],
            )
        except Exception as e:
            logging.error(f"Failed to append the agent output to database. Error: {e}")
            raise

    def chat(self, sender_id: int, receiver_id: int, message: str):
        """Send ``message`` from one agent to another through the collection.

        Example::

            # Instantiate the ScalableGroupChat with 10 workers
            orchestrator = ScalableGroupChat(worker_count=10, api_key="sk-...")

            # Agent 1 sends a message to Agent 2
            orchestrator.chat(sender_id=1, receiver_id=2, message="Hello, Agent 2!")

        Args:
            sender_id: Index of the sending worker, ``0 <= id < worker_count``.
            receiver_id: Index of the receiving worker, same range.
            message: The text to deliver.

        Raises:
            ValueError: If either ID is outside ``[0, worker_count)``.
        """
        sender_ok = 0 <= sender_id < self.worker_count
        receiver_ok = 0 <= receiver_id < self.worker_count
        if not (sender_ok and receiver_ok):
            raise ValueError("Invalid sender or receiver ID")

        # The embedding function works on a list of documents; embed a
        # one-element list and take the single resulting vector.
        message_vector = self.embed([message])[0]

        # Store the message. The UUID suffix keeps repeated messages between
        # the same pair from colliding on the record ID.
        self.collection.add(
            embeddings=[message_vector],
            documents=[message],
            ids=[f"{sender_id}_to_{receiver_id}_{uuid.uuid4().hex}"],
        )

        # NOTE(review): `run` is not defined in the visible part of this
        # class - confirm it exists (later in the file or in a subclass) or
        # route this to the receiving worker instead.
        self.run(
            objective=f"chat with agent {receiver_id} about {message}"
        )