Spaces:
Sleeping
Sleeping
import os | |
import re | |
import json | |
import time | |
from dotenv import load_dotenv | |
from mistralai import Mistral | |
from src.utils.tooling import generate_tools_json | |
from src.tools import ( | |
web_search, | |
visit_webpage, | |
retrieve_knowledge, | |
#load_file, | |
reverse_text, | |
analyze_chess, | |
#analyze_document, | |
classify_foods, | |
transcribe_audio, | |
execute_code, | |
analyze_excel, | |
analyze_youtube_video, | |
calculate_sum, | |
) | |
load_dotenv() | |
class Agent:
    """Tool-calling agent that drives a Mistral agent until it emits a final answer.

    The agent sends the task to the Mistral agents endpoint, executes any tool
    calls the model requests through ``names_to_functions``, feeds the results
    back, and stops when a reply contains the ``FINAL ANSWER:`` marker.
    """

    def __init__(self):
        """Read credentials from the environment and build the client and tool registry."""
        self.api_key = os.getenv("MISTRAL_API_KEY")
        self.agent_id = os.getenv("AGENT_ID")
        self.client = Mistral(api_key=self.api_key)
        self.model = "codestral-latest"
        # Filled in by make_initial_request() from ./prompt.md.
        self.prompt = None
        # Maps the tool name the model emits to the Python callable to run.
        self.names_to_functions = {
            "web_search": web_search,
            "visit_webpage": visit_webpage,
            "retrieve_knowledge": retrieve_knowledge,
            # "load_file": load_file,
            "reverse_text": reverse_text,
            "analyze_chess": analyze_chess,
            # "analyze_document": analyze_document,
            "classify_foods": classify_foods,
            "transcribe_audio": transcribe_audio,
            "execute_code": execute_code,
            "analyze_excel": analyze_excel,
            "analyze_youtube_video": analyze_youtube_video,
            "calculate_sum": calculate_sum,
        }
        self.log = []
        self.first_tools = self.get_tools(first=True)
        self.all_tools = self.get_tools(first=False)

    @staticmethod
    def save_log(messages, task_id, truth, final_answer=None):
        """Save the conversation log to ``./logs/<task_id>.json``.

        The written JSON is ``messages`` followed by one extra entry holding
        the expected answer (``truth``) and the agent's ``final_answer``.
        An existing file for the same ``task_id`` is overwritten.
        """
        # Declared static: the method never used instance state, and without
        # the decorator ``self.save_log(...)`` raised TypeError.
        os.makedirs("./logs", exist_ok=True)
        filename = f"./logs/{task_id}.json"
        with open(filename, 'w', encoding='utf-8') as file:
            json.dump(
                messages + [{"Correct Answer": truth, "Final Answer": final_answer}],
                file, ensure_ascii=False, indent=4
            )

    @staticmethod
    def get_tools(first=None):
        """Return the tool schemas to advertise to the model.

        When ``first`` is truthy only ``retrieve_knowledge`` is exposed;
        otherwise every enabled tool is. The schemas are whatever
        ``generate_tools_json(...)`` stores under its ``'tools'`` key.
        """
        # Declared static for the same reason as save_log: it is called as
        # ``self.get_tools(first=...)`` but never took ``self``.
        if first:
            return generate_tools_json(
                [retrieve_knowledge]
            ).get('tools')
        return generate_tools_json(
            [
                web_search,
                visit_webpage,
                retrieve_knowledge,
                # load_file,
                reverse_text,
                analyze_chess,
                # analyze_document,
                classify_foods,
                transcribe_audio,
                execute_code,
                analyze_excel,
                analyze_youtube_video,
                calculate_sum,
            ]
        ).get('tools')

    def make_initial_request(self, input):
        """Send the first request for a task.

        Loads the system prompt from ``./prompt.md``, builds the opening
        conversation (system, user, and an assistant prefix message) and calls
        the agents completion endpoint with all tools enabled.

        Returns:
            A ``(response, messages)`` tuple: the raw completion response and
            the message list that was sent.
        """
        with open("./prompt.md", 'r', encoding='utf-8') as file:
            self.prompt = file.read()
        messages = [
            {"role": "system", "content": self.prompt},
            {"role": "user", "content": input},
            {
                # Prefix message: the model continues from this text.
                "role": "assistant",
                "content": "Let's tackle this problem, ",
                "prefix": True,
            },
        ]
        payload = {
            "agent_id": self.agent_id,
            "messages": messages,
            "max_tokens": None,
            "stream": False,
            "stop": None,
            "random_seed": None,
            "response_format": None,
            "tools": self.all_tools,
            "tool_choice": 'auto',
            "presence_penalty": 0,
            "frequency_penalty": 0,
            "n": 1,
            "prediction": None,
            "parallel_tool_calls": None
        }
        return self.client.agents.complete(**payload), messages

    def _execute_tool_calls(self, tool_calls):
        """Run each requested tool; return ``(call_id, name, arguments_json, result)`` tuples.

        ``result`` is ``None`` when the arguments cannot be parsed, the tool is
        unknown, or the tool itself raises.
        """
        outcomes = []
        for tool_call in tool_calls:
            name = tool_call.function.name
            raw_arguments = tool_call.function.arguments
            # Arguments may arrive as a JSON string or as an already-decoded
            # mapping; keep a JSON string per call for echoing back.
            if isinstance(raw_arguments, str):
                arguments_json = raw_arguments
            else:
                arguments_json = json.dumps(raw_arguments)
            try:
                if isinstance(raw_arguments, str):
                    params = json.loads(raw_arguments)
                else:
                    params = raw_arguments
                arguments_json = json.dumps(params)
                result = self.names_to_functions[name](**params)
            except Exception as exc:
                # Tools run arbitrary code; a failure is reported to the model
                # as an error message instead of aborting the whole run.
                print(f"Tool {name!r} failed: {exc!r}")
                result = None
            outcomes.append((tool_call.id, name, arguments_json, result))
        return outcomes

    @staticmethod
    def _strip_prefix_flags(messages):
        """Remove the ``prefix`` flag from every message in place."""
        for message in messages:
            message.pop("prefix", None)

    def run(self, input, task_id, truth):
        """Run the agent on one task and return its final answer.

        Args:
            input: The task/question text sent as the user message.
            task_id: Identifier used for the log file name.
            truth: Expected answer, stored in the log next to the agent's answer.

        Returns:
            The stripped text following the first ``FINAL ANSWER:`` marker in
            the model's reply.

        NOTE(review): the loop has no iteration cap; it only ends when a reply
        contains ``FINAL ANSWER:``.
        """
        print("\n===== Asking the agent =====\n")
        response, messages = self.make_initial_request(input)
        first_iteration = True

        while True:
            time.sleep(1)

            if hasattr(response, 'choices') and response.choices:
                choice = response.choices[0]

                if first_iteration:
                    # Replace the canned opening prefix with the model's own
                    # first reply, kept as a prefix so it continues from it.
                    messages = [message for message in messages if not message.get("prefix")]
                    messages.append(
                        {
                            "role": "assistant",
                            "content": choice.message.content,
                            "prefix": True,
                        },
                    )
                    first_iteration = False
                elif choice.message.tool_calls:
                    outcomes = self._execute_tool_calls(choice.message.tool_calls)
                    for tool_call_id, function_name, arguments_json, function_result in outcomes:
                        # Echo the assistant's tool call with ITS OWN arguments,
                        # followed by the matching tool result.
                        messages.append({
                            "role": "assistant",
                            "tool_calls": [
                                {
                                    "id": tool_call_id,
                                    "type": "function",
                                    "function": {
                                        "name": function_name,
                                        "arguments": arguments_json,
                                    }
                                }
                            ]
                        })
                        messages.append(
                            {
                                "role": "tool",
                                "content": function_result if function_result is not None else f"Error occurred: {function_name} failed to execute",
                                "tool_call_id": tool_call_id,
                            },
                        )
                    # Only the last message may carry the prefix flag.
                    self._strip_prefix_flags(messages)
                    messages.append(
                        {
                            "role": "assistant",
                            "content": "Based on the results, ",
                            "prefix": True,
                        }
                    )
                else:
                    self._strip_prefix_flags(messages)
                    content = choice.message.content
                    messages.append(
                        {
                            "role": "assistant",
                            "content": content,
                        }
                    )
                    # Content can be None; treat that as "no final answer yet".
                    if 'FINAL ANSWER:' in (content or ""):
                        print("\n===== END OF REQUEST =====\n", json.dumps(messages, indent=2))
                        ans = content.split('FINAL ANSWER:')[1].strip()
                        self.save_log(messages, task_id, truth, final_answer=ans)
                        return ans

            print("\n===== MESSAGES BEFORE API CALL =====\n", json.dumps(messages, indent=2))
            time.sleep(1)
            # Persist progress before every API call so a crash keeps the trace.
            self.save_log(messages, task_id, truth, final_answer=None)
            response = self.client.agents.complete(
                agent_id=self.agent_id,
                messages=messages,
                tools=self.all_tools,
                tool_choice='auto',
            )