Spaces:
Sleeping
Sleeping
from smolagents import ( | |
CodeAgent, | |
LiteLLMModel, | |
DuckDuckGoSearchTool, | |
PythonInterpreterTool, | |
VisitWebpageTool, | |
) | |
from src.tools import ( | |
transcribe_audio_file, | |
transcribe_from_youtube, | |
read_excel_file, | |
wiki_search, | |
multiply, | |
add, | |
subtract, | |
divide, | |
modulus, | |
) | |
import os | |
from typing import List | |
from PIL import Image | |
from dotenv import load_dotenv | |
# Load variables from a local .env file into os.environ at import time, so
# that os.getenv("GEMINI_API_KEY") in CustomAgent.__init__ can see values
# defined there rather than only those exported by the shell.
load_dotenv()
# Answer-format instructions appended to every question sent to the agent.
# The model is told to reason freely and then emit a single line beginning
# with "FINAL ANSWER: ", holding a number, a few words, or a comma-separated
# list, with no units, thousands separators, articles or abbreviations.
SYSTEM_PROMPT = (
    "\n"
    "You are a helpful assistant tasked with answering questions using a set of tools.\n"
    "Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:\n"
    "FINAL ANSWER: [YOUR FINAL ANSWER].\n"
    "YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. "
    "If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. "
    "If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. "
    "If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.\n"
    'Your answer should only start with "FINAL ANSWER: ", then follows with the answer.\n'
)
class CustomAgent:
    """Question-answering wrapper around a smolagents ``CodeAgent``.

    Builds a ``LiteLLMModel`` (API key from ``GEMINI_API_KEY``) and a fixed
    tool set, sends each question (optionally with an attached file) to the
    agent, and normalizes the returned answer for exact-match comparison.
    """

    # Module names that code generated by the agent may import. These must be
    # *import* names, not PyPI distribution names: an entry such as
    # "youtube-transcript-api" can never appear in an import statement, so it
    # could never be authorized.
    # NOTE(review): "os", "requests" and "urllib" give agent-written code
    # filesystem/network access when executor_type="local" — confirm intended.
    DEFAULT_IMPORTS = (
        "pandas",
        "numpy",
        "io",
        "datetime",
        "json",
        "re",
        "math",
        "os",
        "requests",
        "csv",
        "urllib",
        "youtube_transcript_api",
        "speech_recognition",
        "pydub",
    )

    # Attachments with these extensions are passed to the agent as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
    # Attachments with these extensions have their text inlined in the prompt.
    TEXT_EXTENSIONS = (".txt", ".py")

    # Lower-cased lead-ins that models prepend to answers. They are checked in
    # order, case-insensitively. "final answer:" is the marker SYSTEM_PROMPT
    # explicitly asks the model to emit, so it must be removed first.
    _ANSWER_PREFIXES = (
        "final answer:",
        "the answer is ",
        "answer:",
        "the result is ",
        "to answer this question:",
        "based on the information provided,",
        "according to the information:",
    )

    def __init__(
        self,
        model_id: str = "gemini/gemini-2.0-flash",
        additional_imports: List[str] = None,
        logging=False,
        max_steps=10,
        verbose: bool = False,
        executor_type: str = "local",
        timeout: int = 120,
    ):
        """Build the model, the tool set and the underlying ``CodeAgent``.

        Args:
            model_id: LiteLLM model identifier. The API key is read from the
                ``GEMINI_API_KEY`` environment variable.
            additional_imports: Extra module names the agent's code may
                import, on top of ``DEFAULT_IMPORTS``.
            logging: Stored as ``self.logging``; not otherwise used here.
            max_steps: Maximum number of agent steps per question.
            verbose: Print progress/errors and run the agent at verbosity 2
                (otherwise 0).
            executor_type: Code executor type forwarded to ``CodeAgent``.
            timeout: Forwarded to ``LiteLLMModel`` as ``timeout``.
        """
        self.logging = logging
        self.verbose = verbose

        # Fresh list per instance so extending it never touches the class tuple.
        self.imports = list(self.DEFAULT_IMPORTS)
        if additional_imports:
            self.imports.extend(additional_imports)

        self.tools = [
            DuckDuckGoSearchTool(),
            PythonInterpreterTool(),
            VisitWebpageTool(),
            wiki_search,
            transcribe_audio_file,
            transcribe_from_youtube,
            read_excel_file,
            multiply,
            add,
            subtract,
            divide,
            modulus,
        ]

        model = LiteLLMModel(
            model_id=model_id,
            api_key=os.getenv("GEMINI_API_KEY"),
            timeout=timeout,
        )

        self.agent = CodeAgent(
            model=model,
            tools=self.tools,
            additional_authorized_imports=self.imports,
            executor_type=executor_type,
            max_steps=max_steps,
            verbosity_level=2 if verbose else 0,
        )

        if self.verbose:
            print("CustomAgent initialized.")

    def forward(self, question: str, file_path=None) -> str:
        """Answer ``question``, optionally using the file at ``file_path``.

        Image attachments are sent to the agent as RGB images. Text/Python
        files are inlined into the prompt. Any other file is referenced by
        path so the agent's tools can open it.

        Returns:
            The cleaned answer string. If anything raises, returns an
            ``"Error answering question: ..."`` message instead of propagating
            the exception.
        """
        print(f"QUESTION: {question[:100]}...")
        try:
            prompt, images = self._build_prompt(question, file_path)
            if images:
                answer = self.agent.run(prompt, images=images)
            else:
                answer = self.agent.run(prompt)
            return self._clean_answer(answer)
        except Exception as e:
            # Deliberate boundary: a failure is reported as this question's
            # answer rather than raised to the caller.
            error_msg = f"Error answering question: {e}"
            if self.verbose:
                print(error_msg)
            return error_msg

    def _build_prompt(self, question, file_path):
        """Return ``(prompt, images)``; ``images`` is None unless an image is attached."""
        if not file_path:
            return f"Question: {question}\n{SYSTEM_PROMPT}", None

        ext = os.path.splitext(file_path)[1].lower()

        if ext in self.IMAGE_EXTENSIONS:
            # The context manager closes the file handle; convert() returns an
            # independent in-memory copy that stays valid afterwards.
            with Image.open(file_path) as img:
                image = img.convert("RGB")
            return f"Question: {question}\n{SYSTEM_PROMPT}", [image]

        if ext in self.TEXT_EXTENSIONS:
            # Explicit encoding: do not depend on the platform locale, and never
            # fail the whole question over a stray undecodable byte.
            with open(file_path, "r", encoding="utf-8", errors="replace") as f:
                content = f.read()
            prompt = (
                f"Question: {question}\n"
                f"File content: ```{content}```\n"
                f"{SYSTEM_PROMPT}"
            )
            return prompt, None

        prompt = f"Question: {question}\nFile path: {file_path}\n{SYSTEM_PROMPT}"
        return prompt, None

    def _clean_answer(self, answer: object) -> str:
        """Normalize a raw agent answer into a string for exact-match scoring.

        Non-string answers are converted directly. Integral floats drop their
        ``.0``, and every other value uses ``str()``. No thousands separators or
        currency symbols are added, because SYSTEM_PROMPT forbids both.

        For strings, surrounding whitespace is removed. Known lead-in phrases
        are then stripped case-insensitively, including the ``FINAL ANSWER:``
        marker SYSTEM_PROMPT asks for. Finally, one pair of matching wrapping
        quotes is removed.

        Args:
            answer: The raw value returned by the agent.

        Returns:
            The cleaned answer as a string.
        """
        if not isinstance(answer, str):
            if isinstance(answer, float) and answer.is_integer():
                return str(int(answer))
            return str(answer)

        answer = answer.strip()

        for prefix in self._ANSWER_PREFIXES:
            # Lower-case a slice of the original text rather than the whole
            # string, so case-folding can never shift the cut position.
            if answer[: len(prefix)].lower() == prefix:
                answer = answer[len(prefix):].strip()

        # Require two characters so a lone quote is not turned into "".
        if len(answer) >= 2 and answer[0] == answer[-1] and answer[0] in "\"'":
            answer = answer[1:-1].strip()

        return answer
def get_config():
    """Return the agent configuration, allowing environment-variable overrides.

    Each setting uses the default below unless its variable is set to a
    non-empty value:

        key            variable              default
        model_id       AGENT_MODEL_ID        "gemini/gemini-2.5-flash-preview-04-17"
        logging        AGENT_LOGGING         False
        max_steps      AGENT_MAX_STEPS       10
        verbose        AGENT_VERBOSE         False
        executor_type  AGENT_EXECUTOR_TYPE   "local"
        timeout        AGENT_TIMEOUT         120

    Boolean variables accept 1/true/yes/on and 0/false/no/off
    (case-insensitive).

    Returns:
        A new dict whose keys match the keyword arguments of
        ``CustomAgent.__init__``.

    Raises:
        ValueError: If an integer or boolean variable cannot be parsed.
    """

    def env_str(name, default):
        # Unset or blank falls back to the default.
        return os.getenv(name, "").strip() or default

    def env_int(name, default):
        raw = os.getenv(name, "").strip()
        if not raw:
            return default
        try:
            return int(raw)
        except ValueError:
            raise ValueError(f"{name} must be an integer, got {raw!r}") from None

    def env_bool(name, default):
        raw = os.getenv(name, "").strip().lower()
        if not raw:
            return default
        if raw in ("1", "true", "yes", "on"):
            return True
        if raw in ("0", "false", "no", "off"):
            return False
        raise ValueError(f"{name} must be a boolean (true/false), got {raw!r}")

    return {
        "model_id": env_str("AGENT_MODEL_ID", "gemini/gemini-2.5-flash-preview-04-17"),
        "logging": env_bool("AGENT_LOGGING", False),
        "max_steps": env_int("AGENT_MAX_STEPS", 10),
        "verbose": env_bool("AGENT_VERBOSE", False),
        "executor_type": env_str("AGENT_EXECUTOR_TYPE", "local"),
        "timeout": env_int("AGENT_TIMEOUT", 120),
    }