Spaces:
Running
Running
from smolagents import ( | |
CodeAgent, | |
DuckDuckGoSearchTool, | |
HfApiModel, | |
LiteLLMModel, | |
OpenAIServerModel, | |
PythonInterpreterTool, | |
tool, | |
InferenceClientModel, | |
) | |
from typing import List, Dict, Any, Optional | |
import os | |
import tempfile | |
import re | |
import json | |
import requests | |
from urllib.parse import urlparse | |
def save_and_read_file(content: str, filename: Optional[str] = None) -> str: | |
""" | |
Save content to a temporary file and return the path. | |
Useful for processing files from the GAIA API. | |
Args: | |
content: The content to save to the file | |
filename: Optional filename, will generate a random name if not provided | |
Returns: | |
Path to the saved file | |
""" | |
temp_dir = tempfile.gettempdir() | |
if filename is None: | |
temp_file = tempfile.NamedTemporaryFile(delete=False) | |
filepath = temp_file.name | |
else: | |
filepath = os.path.join(temp_dir, filename) | |
# Write content to the file | |
with open(filepath, "w") as f: | |
f.write(content) | |
return f"File saved to {filepath}. You can read this file to process its contents." | |
def download_file_from_url(url: str, filename: Optional[str] = None) -> str: | |
""" | |
Download a file from a URL and save it to a temporary location. | |
Args: | |
url: The URL to download from | |
filename: Optional filename, will generate one based on URL if not provided | |
Returns: | |
Path to the downloaded file | |
""" | |
try: | |
# Parse URL to get filename if not provided | |
if not filename: | |
path = urlparse(url).path | |
filename = os.path.basename(path) | |
if not filename: | |
# Generate a random name if we couldn't extract one | |
import uuid | |
filename = f"downloaded_{uuid.uuid4().hex[:8]}" | |
# Create temporary file | |
temp_dir = tempfile.gettempdir() | |
filepath = os.path.join(temp_dir, filename) | |
# Download the file | |
response = requests.get(url, stream=True) | |
response.raise_for_status() | |
# Save the file | |
with open(filepath, "wb") as f: | |
for chunk in response.iter_content(chunk_size=8192): | |
f.write(chunk) | |
return f"File downloaded to {filepath}. You can now process this file." | |
except Exception as e: | |
return f"Error downloading file: {str(e)}" | |
def extract_text_from_image(image_path: str) -> str: | |
""" | |
Extract text from an image using pytesseract (if available). | |
Args: | |
image_path: Path to the image file | |
Returns: | |
Extracted text or error message | |
""" | |
try: | |
# Try to import pytesseract | |
import pytesseract | |
from PIL import Image | |
# Open the image | |
image = Image.open(image_path) | |
# Extract text | |
text = pytesseract.image_to_string(image) | |
return f"Extracted text from image:\n\n{text}" | |
except ImportError: | |
return "Error: pytesseract is not installed. Please install it with 'pip install pytesseract' and ensure Tesseract OCR is installed on your system." | |
except Exception as e: | |
return f"Error extracting text from image: {str(e)}" | |
def analyze_csv_file(file_path: str, query: str) -> str: | |
""" | |
Analyze a CSV file using pandas and answer a question about it. | |
Args: | |
file_path: Path to the CSV file | |
query: Question about the data | |
Returns: | |
Analysis result or error message | |
""" | |
try: | |
import pandas as pd | |
# Read the CSV file | |
df = pd.read_csv(file_path) | |
# Run various analyses based on the query | |
result = f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n" | |
result += f"Columns: {', '.join(df.columns)}\n\n" | |
# Add summary statistics | |
result += "Summary statistics:\n" | |
result += str(df.describe()) | |
return result | |
except ImportError: | |
return "Error: pandas is not installed. Please install it with 'pip install pandas'." | |
except Exception as e: | |
return f"Error analyzing CSV file: {str(e)}" | |
def analyze_excel_file(file_path: str, query: str) -> str: | |
""" | |
Analyze an Excel file using pandas and answer a question about it. | |
Args: | |
file_path: Path to the Excel file | |
query: Question about the data | |
Returns: | |
Analysis result or error message | |
""" | |
try: | |
import pandas as pd | |
# Read the Excel file | |
df = pd.read_excel(file_path) | |
# Run various analyses based on the query | |
result = ( | |
f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n" | |
) | |
result += f"Columns: {', '.join(df.columns)}\n\n" | |
# Add summary statistics | |
result += "Summary statistics:\n" | |
result += str(df.describe()) | |
return result | |
except ImportError: | |
return "Error: pandas and openpyxl are not installed. Please install them with 'pip install pandas openpyxl'." | |
except Exception as e: | |
return f"Error analyzing Excel file: {str(e)}" | |
class GAIAAgent: | |
def __init__( | |
self, | |
model_type: str = "HfApiModel", | |
model_id: Optional[str] = None, | |
api_key: Optional[str] = None, | |
api_base: Optional[str] = None, | |
temperature: float = 0.2, | |
executor_type: str = "local", # Changed from use_e2b to executor_type | |
additional_imports: List[str] = None, | |
additional_tools: List[Any] = None, | |
system_prompt: Optional[ | |
str | |
] = None, # We'll still accept this parameter but not use it directly | |
verbose: bool = False, | |
provider: Optional[str] = None, # Add provider for InferenceClientModel | |
timeout: Optional[int] = None, # Add timeout for InferenceClientModel | |
): | |
""" | |
Initialize a GAIAAgent with specified configuration | |
Args: | |
model_type: Type of model to use (HfApiModel, LiteLLMModel, OpenAIServerModel, InferenceClientModel) | |
model_id: ID of the model to use | |
api_key: API key for the model provider | |
api_base: Base URL for API calls | |
temperature: Temperature for text generation | |
executor_type: Type of executor for code execution ('local' or 'e2b') | |
additional_imports: Additional Python modules to allow importing | |
additional_tools: Additional tools to provide to the agent | |
system_prompt: Custom system prompt to use (not directly used, kept for backward compatibility) | |
verbose: Enable verbose logging | |
provider: Provider for InferenceClientModel (e.g., "hf-inference") | |
timeout: Timeout in seconds for API calls | |
""" | |
# Set verbosity | |
self.verbose = verbose | |
self.system_prompt = system_prompt # Store for potential future use | |
# Initialize model based on configuration | |
if model_type == "HfApiModel": | |
if api_key is None: | |
api_key = os.getenv("HUGGINGFACEHUB_API_TOKEN") | |
if not api_key: | |
raise ValueError( | |
"No Hugging Face token provided. Please set HUGGINGFACEHUB_API_TOKEN environment variable or pass api_key parameter." | |
) | |
if self.verbose: | |
print(f"Using Hugging Face token: {api_key[:5]}...") | |
self.model = HfApiModel( | |
model_id=model_id or "meta-llama/Llama-3-70B-Instruct", | |
token=api_key, | |
temperature=temperature, | |
) | |
elif model_type == "InferenceClientModel": | |
if api_key is None: | |
api_key = os.getenv("HUGGINGFACEHUB_API_TOKEN") | |
if not api_key: | |
raise ValueError( | |
"No Hugging Face token provided. Please set HUGGINGFACEHUB_API_TOKEN environment variable or pass api_key parameter." | |
) | |
if self.verbose: | |
print(f"Using Hugging Face token: {api_key[:5]}...") | |
self.model = InferenceClientModel( | |
model_id=model_id or "meta-llama/Llama-3-70B-Instruct", | |
provider=provider or "hf-inference", | |
token=api_key, | |
timeout=timeout or 120, | |
temperature=temperature, | |
) | |
elif model_type == "LiteLLMModel": | |
from smolagents import LiteLLMModel | |
self.model = LiteLLMModel( | |
model_id=model_id or "gpt-4o", | |
api_key=api_key or os.getenv("OPENAI_API_KEY"), | |
temperature=temperature, | |
) | |
elif model_type == "OpenAIServerModel": | |
# Check for xAI API key and base URL first | |
xai_api_key = os.getenv("XAI_API_KEY") | |
xai_api_base = os.getenv("XAI_API_BASE") | |
# If xAI credentials are available, use them | |
if xai_api_key and api_key is None: | |
api_key = xai_api_key | |
if self.verbose: | |
print(f"Using xAI API key: {api_key[:5]}...") | |
# If no API key specified, fall back to OPENAI_API_KEY | |
if api_key is None: | |
api_key = os.getenv("OPENAI_API_KEY") | |
if not api_key: | |
raise ValueError( | |
"No OpenAI API key provided. Please set OPENAI_API_KEY or XAI_API_KEY environment variable or pass api_key parameter." | |
) | |
# If xAI API base is available and no api_base is provided, use it | |
if xai_api_base and api_base is None: | |
api_base = xai_api_base | |
if self.verbose: | |
print(f"Using xAI API base URL: {api_base}") | |
# If no API base specified but environment variable available, use it | |
if api_base is None: | |
api_base = os.getenv("AGENT_API_BASE") | |
if api_base and self.verbose: | |
print(f"Using API base from AGENT_API_BASE: {api_base}") | |
self.model = OpenAIServerModel( | |
model_id=model_id or "gpt-4o", | |
api_key=api_key, | |
api_base=api_base, | |
temperature=temperature, | |
) | |
else: | |
raise ValueError(f"Unknown model type: {model_type}") | |
if self.verbose: | |
print(f"Initialized model: {model_type} - {model_id}") | |
# Initialize default tools | |
self.tools = [ | |
DuckDuckGoSearchTool(), | |
PythonInterpreterTool(), | |
save_and_read_file, | |
download_file_from_url, | |
analyze_csv_file, | |
analyze_excel_file, | |
] | |
# Add extract_text_from_image if PIL and pytesseract are available | |
try: | |
import pytesseract | |
from PIL import Image | |
self.tools.append(extract_text_from_image) | |
if self.verbose: | |
print("Added image processing tool") | |
except ImportError: | |
if self.verbose: | |
print("Image processing libraries not available") | |
# Add any additional tools | |
if additional_tools: | |
self.tools.extend(additional_tools) | |
if self.verbose: | |
print(f"Initialized with {len(self.tools)} tools") | |
# Setup imports allowed | |
self.imports = [ | |
"pandas", | |
"numpy", | |
"datetime", | |
"json", | |
"re", | |
"math", | |
"os", | |
"requests", | |
"csv", | |
"urllib", | |
] | |
if additional_imports: | |
self.imports.extend(additional_imports) | |
# Initialize the CodeAgent | |
executor_kwargs = {} | |
if executor_type == "e2b": | |
try: | |
# Try to import e2b dependencies to check if they're available | |
from e2b_code_interpreter import Sandbox | |
if self.verbose: | |
print("Using e2b executor") | |
except ImportError: | |
if self.verbose: | |
print("e2b dependencies not found, falling back to local executor") | |
executor_type = "local" # Fallback to local if e2b is not available | |
self.agent = CodeAgent( | |
tools=self.tools, | |
model=self.model, | |
additional_authorized_imports=self.imports, | |
executor_type=executor_type, | |
executor_kwargs=executor_kwargs, | |
verbosity_level=2 if self.verbose else 0, | |
) | |
if self.verbose: | |
print("Agent initialized and ready") | |
def answer_question( | |
self, question: str, task_file_path: Optional[str] = None | |
) -> str: | |
""" | |
Process a GAIA benchmark question and return the answer | |
Args: | |
question: The question to answer | |
task_file_path: Optional path to a file associated with the question | |
Returns: | |
The answer to the question | |
""" | |
try: | |
if self.verbose: | |
print(f"Processing question: {question}") | |
if task_file_path: | |
print(f"With associated file: {task_file_path}") | |
# Create a context with file information if available | |
context = question | |
file_content = None | |
# If there's a file, read it and include its content in the context | |
if task_file_path: | |
try: | |
with open(task_file_path, "r") as f: | |
file_content = f.read() | |
# Determine file type from extension | |
import os | |
file_ext = os.path.splitext(task_file_path)[1].lower() | |
context = f""" | |
Question: {question} | |
This question has an associated file. Here is the file content: | |
```{file_ext} | |
{file_content} | |
``` | |
Analyze the file content above to answer the question. | |
""" | |
except Exception as file_e: | |
context = f""" | |
Question: {question} | |
This question has an associated file at path: {task_file_path} | |
However, there was an error reading the file: {file_e} | |
You can still try to answer the question based on the information provided. | |
""" | |
# Check for special cases that need specific formatting | |
# Reversed text questions | |
if question.startswith(".") or ".rewsna eht sa" in question: | |
context = f""" | |
This question appears to be in reversed text. Here's the reversed version: | |
{question[::-1]} | |
Now answer the question above. Remember to format your answer exactly as requested. | |
""" | |
# Add a prompt to ensure precise answers | |
full_prompt = f"""{context} | |
When answering, provide ONLY the precise answer requested. | |
Do not include explanations, steps, reasoning, or additional text. | |
Be direct and specific. GAIA benchmark requires exact matching answers. | |
For example, if asked "What is the capital of France?", respond simply with "Paris". | |
""" | |
# Run the agent with the question | |
answer = self.agent.run(full_prompt) | |
# Clean up the answer to ensure it's in the expected format | |
# Remove common prefixes that models often add | |
answer = self._clean_answer(answer) | |
if self.verbose: | |
print(f"Generated answer: {answer}") | |
return answer | |
except Exception as e: | |
error_msg = f"Error answering question: {e}" | |
if self.verbose: | |
print(error_msg) | |
return error_msg | |
def _clean_answer(self, answer: any) -> str: | |
""" | |
Clean up the answer to remove common prefixes and formatting | |
that models often add but that can cause exact match failures. | |
Args: | |
answer: The raw answer from the model | |
Returns: | |
The cleaned answer as a string | |
""" | |
# Convert non-string types to strings | |
if not isinstance(answer, str): | |
# Handle numeric types (float, int) | |
if isinstance(answer, float): | |
# Format floating point numbers properly | |
# Check if it's an integer value in float form (e.g., 12.0) | |
if answer.is_integer(): | |
formatted_answer = str(int(answer)) | |
else: | |
# For currency values that might need formatting | |
if abs(answer) >= 1000: | |
formatted_answer = f"${answer:,.2f}" | |
else: | |
formatted_answer = str(answer) | |
return formatted_answer | |
elif isinstance(answer, int): | |
return str(answer) | |
else: | |
# For any other type | |
return str(answer) | |
# Now we know answer is a string, so we can safely use string methods | |
# Normalize whitespace | |
answer = answer.strip() | |
# Remove common prefixes and formatting that models add | |
prefixes_to_remove = [ | |
"The answer is ", | |
"Answer: ", | |
"Final answer: ", | |
"The result is ", | |
"To answer this question: ", | |
"Based on the information provided, ", | |
"According to the information: ", | |
] | |
for prefix in prefixes_to_remove: | |
if answer.startswith(prefix): | |
answer = answer[len(prefix) :].strip() | |
# Remove quotes if they wrap the entire answer | |
if (answer.startswith('"') and answer.endswith('"')) or ( | |
answer.startswith("'") and answer.endswith("'") | |
): | |
answer = answer[1:-1].strip() | |
return answer | |