Spaces:
Sleeping
Sleeping
import datetime | |
import json | |
from typing import Any, Optional | |
import yaml | |
from swarms.structs.base_structure import BaseStructure | |
from typing import TYPE_CHECKING | |
from swarms.utils.formatter import formatter | |
if TYPE_CHECKING: | |
from swarms.structs.agent import ( | |
Agent, | |
) # Only imported during type checking | |
class Conversation(BaseStructure): | |
""" | |
A class structure to represent a conversation in a chatbot. This class is used to store the conversation history. | |
And, it can be used to save the conversation history to a file, load the conversation history from a file, and | |
display the conversation history. We can also use this class to add the conversation history to a database, query | |
the conversation history from a database, delete the conversation history from a database, update the conversation | |
history from a database, and get the conversation history from a database. | |
Args: | |
time_enabled (bool): Whether to enable timestamps for the conversation history. Default is False. | |
database (AbstractDatabase): The database to use for storing the conversation history. Default is None. | |
autosave (bool): Whether to autosave the conversation history to a file. Default is None. | |
save_filepath (str): The filepath to save the conversation history to. Default is None. | |
Methods: | |
add(role: str, content: str): Add a message to the conversation history. | |
delete(index: str): Delete a message from the conversation history. | |
update(index: str, role, content): Update a message in the conversation history. | |
query(index: str): Query a message in the conversation history. | |
search(keyword: str): Search for a message in the conversation history. | |
display_conversation(detailed: bool = False): Display the conversation history. | |
export_conversation(filename: str): Export the conversation history to a file. | |
import_conversation(filename: str): Import a conversation history from a file. | |
count_messages_by_role(): Count the number of messages by role. | |
return_history_as_string(): Return the conversation history as a string. | |
save_as_json(filename: str): Save the conversation history as a JSON file. | |
load_from_json(filename: str): Load the conversation history from a JSON file. | |
search_keyword_in_conversation(keyword: str): Search for a keyword in the conversation history. | |
pretty_print_conversation(messages): Pretty print the conversation history. | |
add_to_database(): Add the conversation history to the database. | |
query_from_database(query): Query the conversation history from the database. | |
delete_from_database(): Delete the conversation history from the database. | |
update_from_database(): Update the conversation history from the database. | |
get_from_database(): Get the conversation history from the database. | |
execute_query_from_database(query): Execute a query on the database. | |
fetch_all_from_database(): Fetch all from the database. | |
fetch_one_from_database(): Fetch one from the database. | |
Examples: | |
>>> from swarms import Conversation | |
>>> conversation = Conversation() | |
>>> conversation.add("user", "Hello, how are you?") | |
>>> conversation.add("assistant", "I am doing well, thanks.") | |
>>> conversation.display_conversation() | |
user: Hello, how are you? | |
assistant: I am doing well, thanks. | |
""" | |
def __init__( | |
self, | |
system_prompt: Optional[str] = None, | |
time_enabled: bool = False, | |
autosave: bool = False, | |
save_filepath: str = None, | |
tokenizer: Any = None, | |
context_length: int = 8192, | |
rules: str = None, | |
custom_rules_prompt: str = None, | |
user: str = "User:", | |
auto_save: bool = True, | |
save_as_yaml: bool = True, | |
save_as_json_bool: bool = False, | |
*args, | |
**kwargs, | |
): | |
super().__init__() | |
self.system_prompt = system_prompt | |
self.time_enabled = time_enabled | |
self.autosave = autosave | |
self.save_filepath = save_filepath | |
self.conversation_history = [] | |
self.tokenizer = tokenizer | |
self.context_length = context_length | |
self.rules = rules | |
self.custom_rules_prompt = custom_rules_prompt | |
self.user = user | |
self.auto_save = auto_save | |
self.save_as_yaml = save_as_yaml | |
self.save_as_json_bool = save_as_json_bool | |
# If system prompt is not None, add it to the conversation history | |
if self.system_prompt is not None: | |
self.add("System: ", self.system_prompt) | |
if self.rules is not None: | |
self.add("User", rules) | |
if custom_rules_prompt is not None: | |
self.add(user or "User", custom_rules_prompt) | |
# If tokenizer then truncate | |
if tokenizer is not None: | |
self.truncate_memory_with_tokenizer() | |
def add(self, role: str, content: str, *args, **kwargs): | |
"""Add a message to the conversation history | |
Args: | |
role (str): The role of the speaker | |
content (str): The content of the message | |
""" | |
if self.time_enabled: | |
now = datetime.datetime.now() | |
timestamp = now.strftime("%Y-%m-%d %H:%M:%S") | |
message = { | |
"role": role, | |
"content": content, | |
"timestamp": timestamp, | |
} | |
else: | |
message = { | |
"role": role, | |
"content": content, | |
} | |
self.conversation_history.append(message) | |
if self.autosave: | |
self.save_as_json(self.save_filepath) | |
def delete(self, index: str): | |
"""Delete a message from the conversation history | |
Args: | |
index (str): index of the message to delete | |
""" | |
self.conversation_history.pop(index) | |
def update(self, index: str, role, content): | |
"""Update a message in the conversation history | |
Args: | |
index (str): index of the message to update | |
role (_type_): role of the speaker | |
content (_type_): content of the message | |
""" | |
self.conversation_history[index] = { | |
"role": role, | |
"content": content, | |
} | |
def query(self, index: str): | |
"""Query a message in the conversation history | |
Args: | |
index (str): index of the message to query | |
Returns: | |
str: the message | |
""" | |
return self.conversation_history[index] | |
def search(self, keyword: str): | |
"""Search for a message in the conversation history | |
Args: | |
keyword (str): Keyword to search for | |
Returns: | |
str: description | |
""" | |
return [ | |
msg | |
for msg in self.conversation_history | |
if keyword in msg["content"] | |
] | |
def display_conversation(self, detailed: bool = False): | |
"""Display the conversation history | |
Args: | |
detailed (bool, optional): detailed. Defaults to False. | |
""" | |
for message in self.conversation_history: | |
formatter.print_panel( | |
f"{message['role']}: {message['content']}\n\n" | |
) | |
def export_conversation(self, filename: str, *args, **kwargs): | |
"""Export the conversation history to a file | |
Args: | |
filename (str): filename to export to | |
""" | |
with open(filename, "w") as f: | |
for message in self.conversation_history: | |
f.write(f"{message['role']}: {message['content']}\n") | |
def import_conversation(self, filename: str): | |
"""Import a conversation history from a file | |
Args: | |
filename (str): filename to import from | |
""" | |
with open(filename) as f: | |
for line in f: | |
role, content = line.split(": ", 1) | |
self.add(role, content.strip()) | |
def count_messages_by_role(self): | |
"""Count the number of messages by role""" | |
counts = { | |
"system": 0, | |
"user": 0, | |
"assistant": 0, | |
"function": 0, | |
} | |
for message in self.conversation_history: | |
counts[message["role"]] += 1 | |
return counts | |
def return_history_as_string(self): | |
"""Return the conversation history as a string | |
Returns: | |
str: the conversation history | |
""" | |
return "\n".join( | |
[ | |
f"{message['role']}: {message['content']}\n\n" | |
for message in self.conversation_history | |
] | |
) | |
def get_str(self): | |
return self.return_history_as_string() | |
def save_as_json(self, filename: str = None): | |
"""Save the conversation history as a JSON file | |
Args: | |
filename (str): Save the conversation history as a JSON file | |
""" | |
# Create the directory if it does not exist | |
# os.makedirs(os.path.dirname(filename), exist_ok=True) | |
if filename is not None: | |
with open(filename, "w") as f: | |
json.dump(self.conversation_history, f) | |
def load_from_json(self, filename: str): | |
"""Load the conversation history from a JSON file | |
Args: | |
filename (str): filename to load from | |
""" | |
# Load the conversation history from a JSON file | |
if filename is not None: | |
with open(filename) as f: | |
self.conversation_history = json.load(f) | |
def search_keyword_in_conversation(self, keyword: str): | |
"""Search for a keyword in the conversation history | |
Args: | |
keyword (str): keyword to search for | |
Returns: | |
str: description | |
""" | |
return [ | |
msg | |
for msg in self.conversation_history | |
if keyword in msg["content"] | |
] | |
def pretty_print_conversation(self, messages): | |
"""Pretty print the conversation history | |
Args: | |
messages (str): messages to print | |
""" | |
role_to_color = { | |
"system": "red", | |
"user": "green", | |
"assistant": "blue", | |
"tool": "magenta", | |
} | |
for message in messages: | |
if message["role"] == "system": | |
formatter.print_panel( | |
f"system: {message['content']}\n", | |
role_to_color[message["role"]], | |
) | |
elif message["role"] == "user": | |
formatter.print_panel( | |
f"user: {message['content']}\n", | |
role_to_color[message["role"]], | |
) | |
elif message["role"] == "assistant" and message.get( | |
"function_call" | |
): | |
formatter.print_panel( | |
f"assistant: {message['function_call']}\n", | |
role_to_color[message["role"]], | |
) | |
elif message["role"] == "assistant" and not message.get( | |
"function_call" | |
): | |
formatter.print_panel( | |
f"assistant: {message['content']}\n", | |
role_to_color[message["role"]], | |
) | |
elif message["role"] == "tool": | |
formatter.print_panel( | |
( | |
f"function ({message['name']}):" | |
f" {message['content']}\n" | |
), | |
role_to_color[message["role"]], | |
) | |
def truncate_memory_with_tokenizer(self): | |
""" | |
Truncates the conversation history based on the total number of tokens using a tokenizer. | |
Returns: | |
None | |
""" | |
total_tokens = 0 | |
truncated_history = [] | |
for message in self.conversation_history: | |
role = message.get("role") | |
content = message.get("content") | |
tokens = self.tokenizer.count_tokens( | |
text=content | |
) # Count the number of tokens | |
count = tokens # Assign the token count | |
total_tokens += count | |
if total_tokens <= self.context_length: | |
truncated_history.append(message) | |
else: | |
remaining_tokens = self.context_length - ( | |
total_tokens - count | |
) | |
truncated_content = content[ | |
:remaining_tokens | |
] # Truncate the content based on the remaining tokens | |
truncated_message = { | |
"role": role, | |
"content": truncated_content, | |
} | |
truncated_history.append(truncated_message) | |
break | |
self.conversation_history = truncated_history | |
def clear(self): | |
self.conversation_history = [] | |
def to_json(self): | |
return json.dumps(self.conversation_history) | |
def to_dict(self): | |
return self.conversation_history | |
def to_yaml(self): | |
return yaml.dump(self.conversation_history) | |
def get_visible_messages(self, agent: "Agent", turn: int): | |
""" | |
Get the visible messages for a given agent and turn. | |
Args: | |
agent (Agent): The agent. | |
turn (int): The turn number. | |
Returns: | |
List[Dict]: The list of visible messages. | |
""" | |
# Get the messages before the current turn | |
prev_messages = [ | |
message | |
for message in self.conversation_history | |
if message["turn"] < turn | |
] | |
visible_messages = [] | |
for message in prev_messages: | |
if ( | |
message["visible_to"] == "all" | |
or agent.agent_name in message["visible_to"] | |
): | |
visible_messages.append(message) | |
return visible_messages | |
# # Example usage | |
# conversation = Conversation() | |
# conversation.add("user", "Hello, how are you?") | |
# conversation.add("assistant", "I am doing well, thanks.") | |
# # print(conversation.to_json()) | |
# print(type(conversation.to_dict())) | |
# # print(conversation.to_yaml()) | |