A-Mistral-Agent / src /inference.py
mriusero
core: clear status
cdb8feb
import os
import re
import json
import time
from dotenv import load_dotenv
from mistralai import Mistral
from src.utils.tooling import generate_tools_json
from src.tools import (
web_search,
visit_webpage,
retrieve_knowledge,
#load_file,
reverse_text,
analyze_chess,
#analyze_document,
classify_foods,
transcribe_audio,
execute_code,
analyze_excel,
analyze_youtube_video,
calculate_sum,
)
# Load variables from a local .env file into the process environment, so that
# Agent.__init__ can read MISTRAL_API_KEY and AGENT_ID through os.getenv.
load_dotenv()
class Agent:
    """Mistral-agent wrapper that answers a task, calling local Python tools on demand.

    ``run`` drives the conversation: it repeatedly calls
    ``self.client.agents.complete`` and stops once an assistant reply contains
    the ``FINAL ANSWER:`` marker.
    """

    # Seconds slept at the top of each loop iteration and again before each
    # follow-up API call in ``run`` (formerly a hard-coded 1).
    # NOTE(review): presumably rate limiting — confirm against API quotas.
    request_delay = 1

    def __init__(self):
        """Build the Mistral client and the tool registries from environment settings.

        ``MISTRAL_API_KEY`` and ``AGENT_ID`` are read with ``os.getenv``; either
        is None when unset.
        """
        self.api_key = os.getenv("MISTRAL_API_KEY")
        self.agent_id = os.getenv("AGENT_ID")
        self.client = Mistral(api_key=self.api_key)
        self.model = "codestral-latest"  # not referenced by the methods of this class
        self.prompt = None  # set to the contents of ./prompt.md by make_initial_request
        # Tool name (as the model calls it) -> Python callable that executes it.
        self.names_to_functions = self._tool_registry()
        self.log = []  # not referenced by the methods of this class
        self.first_tools = self.get_tools(first=True)
        self.all_tools = self.get_tools(first=False)

    @staticmethod
    def _tool_registry():
        """Return a new ``{tool name: callable}`` dict of every enabled tool.

        This single list feeds both tool-call dispatch (``names_to_functions``)
        and the schemas advertised to the model (``get_tools``), so the two
        cannot drift apart.
        """
        return {
            "web_search": web_search,
            "visit_webpage": visit_webpage,
            "retrieve_knowledge": retrieve_knowledge,
            # "load_file": load_file,  # disabled
            "reverse_text": reverse_text,
            "analyze_chess": analyze_chess,
            # "analyze_document": analyze_document,  # disabled
            "classify_foods": classify_foods,
            "transcribe_audio": transcribe_audio,
            "execute_code": execute_code,
            "analyze_excel": analyze_excel,
            "analyze_youtube_video": analyze_youtube_video,
            "calculate_sum": calculate_sum,
        }

    @staticmethod
    def save_log(messages, task_id, truth, final_answer=None):
        """Write the conversation to ``./logs/<task_id>.json``.

        Args:
            messages: List of conversation message dicts (must be JSON-serialisable).
            task_id: Used as the file name; an existing file is overwritten.
            truth: Expected answer, stored under ``"Correct Answer"``.
            final_answer: The agent's answer (None while still running), stored
                under ``"Final Answer"``.
        """
        filename = f"./logs/{task_id}.json"
        # The logs directory may not exist yet (e.g. on a fresh checkout).
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        with open(filename, 'w', encoding='utf-8') as file:
            json.dump(
                messages + [{"Correct Answer": truth, "Final Answer": final_answer}],
                file, ensure_ascii=False, indent=4,
            )

    @staticmethod
    def get_tools(first=None):
        """Return the tool schemas to advertise to the model.

        Args:
            first: When truthy, only ``retrieve_knowledge`` is exposed;
                otherwise every tool in the registry is.

        Returns:
            The ``'tools'`` entry of ``generate_tools_json(...)`` (None if absent).
        """
        functions = [retrieve_knowledge] if first else list(Agent._tool_registry().values())
        return generate_tools_json(functions).get('tools')

    def make_initial_request(self, input):
        """Send the first request for ``input``; return ``(response, messages)``.

        The system prompt is read from ``./prompt.md`` and stored on
        ``self.prompt``. The last message is an assistant message flagged
        ``prefix`` (Mistral's prefix feature), seeding the reply with
        "Let's tackle this problem, ".
        """
        with open("./prompt.md", 'r', encoding='utf-8') as file:
            self.prompt = file.read()

        messages = [
            {"role": "system", "content": self.prompt},
            {"role": "user", "content": input},
            {
                "role": "assistant",
                "content": "Let's tackle this problem, ",
                "prefix": True,
            },
        ]
        payload = {
            "agent_id": self.agent_id,
            "messages": messages,
            "max_tokens": None,
            "stream": False,
            "stop": None,
            "random_seed": None,
            "response_format": None,
            "tools": self.all_tools,
            "tool_choice": 'auto',
            "presence_penalty": 0,
            "frequency_penalty": 0,
            "n": 1,
            "prediction": None,
            "parallel_tool_calls": None,
        }
        return self.client.agents.complete(**payload), messages

    @staticmethod
    def _content_text(message):
        """Return a response message's content as plain text ('' when absent)."""
        content = message.content
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        # NOTE(review): presumably a list of SDK content chunks; keep their text parts.
        return "".join(getattr(chunk, "text", "") or "" for chunk in content)

    @staticmethod
    def _strip_prefixes(messages):
        """Remove the ``prefix`` flag from every message, in place."""
        for message in messages:
            message.pop("prefix", None)

    @staticmethod
    def _parse_arguments(raw_args):
        """Return tool-call arguments as a dict.

        Accepts either a JSON string or an already-decoded dict; empty or None
        means no arguments. Raises ``ValueError`` on malformed JSON.
        """
        if isinstance(raw_args, dict):
            return raw_args
        return json.loads(raw_args) if raw_args else {}

    @staticmethod
    def _tool_content(function_name, result):
        """Render a tool's return value as the text content of a ``tool`` message."""
        if result is None:
            # Kept from the original: a None result is reported as a failure.
            return f"Error occurred: {function_name} failed to execute"
        if isinstance(result, str):
            return result
        try:
            return json.dumps(result, ensure_ascii=False)
        except (TypeError, ValueError):
            return str(result)

    def _run_tool_calls(self, tool_calls):
        """Execute each tool call and return the messages to append, in order.

        For every call this yields an assistant message echoing that call's own
        arguments, followed by a ``tool`` message with its result or error.
        """
        new_messages = []
        for tool_call in tool_calls:
            function_name = tool_call.function.name
            raw_args = tool_call.function.arguments
            try:
                function_params = self._parse_arguments(raw_args)
                function = self.names_to_functions.get(function_name)
                if function is None:
                    raise LookupError(f"unknown tool {function_name!r}")
                result = function(**function_params)
            except Exception as error:  # tool boundary: report to the model instead of crashing
                print(f"\n===== TOOL ERROR: {function_name} =====\n", repr(error))
                content = f"Error occurred: {function_name} failed to execute: {error}"
            else:
                content = self._tool_content(function_name, result)

            # Echo this call's own arguments (the original re-used the last call's).
            if isinstance(raw_args, str):
                arguments_json = raw_args
            else:
                arguments_json = json.dumps(raw_args if raw_args is not None else {})
            new_messages.append({
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": tool_call.id,
                        "type": "function",
                        "function": {"name": function_name, "arguments": arguments_json},
                    }
                ],
            })
            new_messages.append({
                "role": "tool",
                "content": content,
                "tool_call_id": tool_call.id,
            })
        return new_messages

    def run(self, input, task_id, truth, max_iterations=None):
        """Drive the conversation for ``input`` until a final answer is produced.

        Each response is handled as follows:
          * first response: its text replaces the seed prefix and is kept as an
            assistant prefix for the next request (tool calls in it are ignored);
          * response with tool calls: every tool runs, its result is added as a
            ``tool`` message, then a "Based on the results, " prefix is added;
          * plain text: appended; if it contains ``FINAL ANSWER:``, the text
            after the last such marker is returned.
        ``save_log`` is called before every follow-up request and once more with
        the final answer.

        Args:
            input: The task/question sent as the user message.
            task_id: Identifier used for the log file name.
            truth: Expected answer; only written to the log.
            max_iterations: Maximum number of follow-up API requests. None (the
                default) keeps the loop unbounded.

        Returns:
            The stripped text following the last ``FINAL ANSWER:`` marker.

        Raises:
            RuntimeError: A response had no choices (previously an endless
                sleep loop), or ``max_iterations`` follow-up requests were used
                up without a final answer.
        """
        print("\n===== Asking the agent =====\n")
        response, messages = self.make_initial_request(input)
        first_iteration = True
        requests_made = 0

        while True:
            time.sleep(self.request_delay)
            choices = getattr(response, 'choices', None)
            if not choices:
                # Nothing would change on the next pass, so fail loudly.
                self.save_log(messages, task_id, truth, final_answer=None)
                raise RuntimeError(f"Agent response for task {task_id!r} contained no choices")
            choice = choices[0]

            if first_iteration:
                # Swap the seed prefix for the model's continuation, still as a prefix.
                messages = [message for message in messages if not message.get("prefix")]
                messages.append(
                    {
                        "role": "assistant",
                        "content": self._content_text(choice.message),
                        "prefix": True,
                    },
                )
                first_iteration = False
            elif choice.message.tool_calls:
                messages.extend(self._run_tool_calls(choice.message.tool_calls))
                self._strip_prefixes(messages)
                messages.append(
                    {
                        "role": "assistant",
                        "content": "Based on the results, ",
                        "prefix": True,
                    }
                )
            else:
                content = self._content_text(choice.message)
                self._strip_prefixes(messages)
                messages.append({"role": "assistant", "content": content})
                if 'FINAL ANSWER:' in content:
                    print("\n===== END OF REQUEST =====\n", json.dumps(messages, indent=2))
                    # Take text after the LAST marker: split()[1] returned only the
                    # chunk between the first two markers if the model repeated it.
                    ans = content.rpartition('FINAL ANSWER:')[2].strip()
                    self.save_log(messages, task_id, truth, final_answer=ans)
                    return ans
                # NOTE(review): the conversation now ends with a non-prefix assistant
                # message; Mistral's chat API may reject that (it expects a final
                # user/tool message or an assistant prefix) — confirm.

            if max_iterations is not None and requests_made >= max_iterations:
                self.save_log(messages, task_id, truth, final_answer=None)
                raise RuntimeError(
                    f"No final answer for task {task_id!r} after {requests_made} follow-up requests"
                )

            print("\n===== MESSAGES BEFORE API CALL =====\n", json.dumps(messages, indent=2))
            time.sleep(self.request_delay)
            self.save_log(messages, task_id, truth, final_answer=None)
            response = self.client.agents.complete(
                agent_id=self.agent_id,
                messages=messages,
                tools=self.all_tools,
                tool_choice='auto',
            )
            requests_made += 1