# Page header captured with this file: author "Mehedi2",
# commit message "Update agent.py", revision d2a08ac (verified).
import os
import time
import re
from pathlib import Path
from typing import Optional, Union
import pandas as pd
from dotenv import load_dotenv
from tabulate import tabulate
from smolagents import (
CodeAgent,
DuckDuckGoSearchTool,
FinalAnswerTool,
LiteLLMModel,
PythonInterpreterTool,
WikipediaSearchTool,
)
from smolagents.tools import Tool
# Load environment variables via python-dotenv at import time, before
# GaiaAgent.__init__ reads GROQ_API_KEY with os.getenv.
load_dotenv()
class ExcelToTextTool(Tool):
    """Render an Excel worksheet as a Markdown table.

    The tool never raises to the caller: a missing file or any failure while
    reading is reported as an ``"Error..."`` string so the calling agent can
    see the problem in its tool output.
    """

    name = "excel_to_text"
    description = "Read an Excel file and return a Markdown table of the requested sheet."
    inputs = {
        "excel_path": {"type": "string", "description": "Path to the Excel file."},
        "sheet_name": {
            "type": "string",
            "description": "Worksheet name or index. Optional.",
            "nullable": True,
        },
    }
    output_type = "string"

    def forward(self, excel_path: str, sheet_name: Optional[str] = None) -> str:
        """Return the requested worksheet rendered as a Markdown table.

        Args:
            excel_path: Path to the workbook; ``~`` is expanded and the path
                is resolved before use.
            sheet_name: Worksheet name, or a zero-based index given either as
                a digit string (surrounding whitespace is ignored) or as a
                real ``int``. ``None`` or ``""`` selects the first sheet.

        Returns:
            The Markdown table, or an error message string if the file does
            not exist or cannot be read.
        """
        file_path = Path(excel_path).expanduser().resolve()
        if not file_path.is_file():
            return f"Error: Excel file not found at {file_path}"

        try:
            sheet: Union[str, int]
            if sheet_name is None or sheet_name == "":
                # No sheet requested: fall back to the first worksheet.
                sheet = 0
            elif isinstance(sheet_name, int):
                # Callers may pass the index as an actual integer even though
                # the declared input type is "string"; calling .isdigit() on
                # it would raise AttributeError.
                sheet = sheet_name
            else:
                candidate = str(sheet_name).strip()
                # Only the digit check ignores padding; a non-numeric name is
                # passed through untouched so sheet names keep exact spelling.
                sheet = int(candidate) if candidate.isdigit() else sheet_name

            df = pd.read_excel(file_path, sheet_name=sheet)
            if hasattr(df, "to_markdown"):
                return df.to_markdown(index=False)
            # Fallback for DataFrame objects without to_markdown().
            return tabulate(df, headers="keys", tablefmt="github", showindex=False)
        except Exception as e:
            # Deliberate best-effort: surface the failure as tool output
            # instead of crashing the agent run.
            return f"Error reading Excel file: {e}"
class GaiaAgent:
    """
    Single-model agent using Llama 4 Scout exclusively.

    Why Llama 4 Scout:
    - 30K TPM (highest available - 5x more than llama-3.1-8b)
    - 500K context window
    - Multimodal support (images, chess)
    - 1K RPM

    This avoids the 6K TPM bottleneck of llama-3.1-8b-instant.

    Calling the instance with ``(task_id, question)`` runs the question
    through a smolagents ``CodeAgent``. Calls are spaced at least
    ``min_delay`` seconds apart and failed attempts are retried up to
    ``max_retries`` times. The call always returns a string: either the
    agent's answer or a short warning message describing the failure.
    """

    # Safety margin (seconds) added on top of the wait the provider suggests.
    _WAIT_BUFFER = 5
    # Wait (seconds) used when the error message carries no usable hint.
    _DEFAULT_WAIT = 30
    # Length of one countdown step / fixed pause between retries (seconds).
    _STEP_SECONDS = 5

    # Seconds per unit for duration strings such as "1m26.4s" or "834ms".
    _UNIT_SECONDS = {
        "ms": 0.001, "msec": 0.001, "msecs": 0.001,
        "millisecond": 0.001, "milliseconds": 0.001,
        "s": 1.0, "sec": 1.0, "secs": 1.0, "second": 1.0, "seconds": 1.0,
        "m": 60.0, "min": 60.0, "mins": 60.0, "minute": 60.0, "minutes": 60.0,
        "h": 3600.0, "hr": 3600.0, "hrs": 3600.0, "hour": 3600.0, "hours": 3600.0,
    }
    # "try again in 1m26.4s" / "retry in 10 seconds": captures the duration run.
    _RETRY_HINT_RE = re.compile(
        r"(?:try again|retry)\s+(?:in|after)\s+((?:\d+(?:\.\d*)?\s*[a-z]+\s*)+)",
        re.IGNORECASE,
    )
    # One "<number><unit>" component inside a captured duration run.
    _DURATION_TOKEN_RE = re.compile(r"(\d+(?:\.\d*)?)\s*([a-z]+)", re.IGNORECASE)
    # Last-resort match for a bare seconds value. The trailing word boundary
    # stops text like "429 status" from being read as 429 seconds.
    _BARE_SECONDS_RE = re.compile(
        r"(?<![\w.])(\d+(?:\.\d*)?)\s*(?:s|secs?|seconds?)\b",
        re.IGNORECASE,
    )

    def __init__(self):
        """Build the model, tools and agent, and reset the counters.

        Raises:
            ValueError: If ``GROQ_API_KEY`` is not set in the environment.
        """
        # Validate credentials first so the banner is not printed for an
        # agent that cannot be constructed.
        self.api_key = os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError("GROQ_API_KEY not found in environment variables")

        print("=" * 70)
        print("✅ GaiaAgent initialized with Llama 4 Scout (30K TPM)")
        print("=" * 70)

        # Single model configuration - Llama 4 Scout for all tasks
        self.model_id = "groq/meta-llama/llama-4-scout-17b-16e-instruct"
        print(f"🤖 Model: {self.model_id}")
        print("📊 Limits: 30K TPM | 1K RPM | 500K context | Multimodal")
        print("=" * 70 + "\n")

        # Initialize model
        self.model = LiteLLMModel(
            model_id=self.model_id,
            api_key=self.api_key,
        )

        # Tools
        self.tools = [
            DuckDuckGoSearchTool(),
            WikipediaSearchTool(),
            ExcelToTextTool(),
            PythonInterpreterTool(),
            FinalAnswerTool(),
        ]

        # Create agent
        # NOTE(review): authorizing "subprocess" lets model-written code run
        # arbitrary shell commands on the host; confirm this is intended and
        # that the agent runs in a sandbox.
        self.agent = CodeAgent(
            model=self.model,
            tools=self.tools,
            add_base_tools=True,
            additional_authorized_imports=["pandas", "numpy", "csv", "subprocess", "PIL", "requests"],
        )

        # Rate limiting - 30K TPM is generous but agents make multiple calls
        self.last_call_time = 0
        self.min_delay = 10  # 10s between tasks (reasonable with 30K TPM)
        self.max_retries = 3  # More retries since we have higher TPM

        # Stats
        self.total_tasks = 0
        self.successful_tasks = 0
        self.failed_tasks = 0
        self.rate_limit_hits = 0

    def _parse_duration(self, text: str):
        """Convert a duration run such as ``"1m26.4s"`` to seconds.

        Returns ``None`` when the text has no components or uses a unit that
        is not listed in ``_UNIT_SECONDS``.
        """
        total = 0.0
        found = False
        for amount, unit in self._DURATION_TOKEN_RE.findall(text):
            scale = self._UNIT_SECONDS.get(unit.lower())
            if scale is None:
                return None
            total += float(amount) * scale
            found = True
        return total if found else None

    def _extract_wait_time(self, error_str: str) -> float:
        """Extract wait time from rate limit error message.

        An explicit "try again in ..." / "retry in ..." hint is preferred and
        may combine hours, minutes, seconds and milliseconds. Otherwise the
        first standalone seconds value is used. A 5 s buffer is added to any
        parsed value; 30 s is returned when nothing can be parsed.
        """
        hint = self._RETRY_HINT_RE.search(error_str)
        if hint:
            seconds = self._parse_duration(hint.group(1))
            if seconds is not None:
                return seconds + self._WAIT_BUFFER

        bare = self._BARE_SECONDS_RE.search(error_str)
        if bare:
            return float(bare.group(1)) + self._WAIT_BUFFER

        return float(self._DEFAULT_WAIT)  # Default fallback

    @staticmethod
    def _has_content(answer) -> bool:
        """Return True when *answer* is not None and not blank as text.

        Falsy but meaningful results such as ``0`` or ``False`` count as
        content; only ``None`` and empty/whitespace-only text do not.
        """
        return answer is not None and bool(str(answer).strip())

    def _throttle(self) -> None:
        """Sleep until at least ``min_delay`` seconds passed since the last task."""
        elapsed = time.time() - self.last_call_time
        if elapsed < self.min_delay:
            wait_time = self.min_delay - elapsed
            print(f"⏳ Rate limit: waiting {wait_time:.1f}s...")
            time.sleep(wait_time)

    def _sleep_with_countdown(self, wait_time: float) -> None:
        """Sleep for *wait_time* seconds, printing progress for long waits."""
        if wait_time <= 10:
            time.sleep(wait_time)
            return
        remaining = wait_time
        while remaining > 0:
            print(f" ⏱️ {max(1, int(remaining))}s remaining...", flush=True)
            # The final step is shortened so the total equals wait_time.
            step = min(self._STEP_SECONDS, remaining)
            time.sleep(step)
            remaining -= step

    def __call__(self, task_id: str, question: str) -> str:
        """Process a task with automatic rate limiting and retry.

        Args:
            task_id: Identifier used only for log output.
            question: Prompt passed to the underlying agent.

        Returns:
            The agent's answer converted to ``str``, or a warning message
            when authentication fails, the rate limit persists, every
            attempt errors out, or only empty answers were produced.
        """
        self.total_tasks += 1
        self._throttle()

        print(f"\n{'=' * 70}")
        print(f"📋 Task #{self.total_tasks} | ID: {task_id}")
        print(f"❓ Question: {question[:150]}{'...' if len(question) > 150 else ''}")
        print(f"{'=' * 70}\n")

        answer = None
        succeeded = False
        attempts = self.max_retries + 1

        # Retry loop: fixed 5s pause between attempts, or the
        # provider-suggested wait after a rate-limit error.
        for attempt in range(attempts):
            last_attempt = attempt >= self.max_retries
            print(f"🚀 Attempt {attempt + 1}/{attempts}")
            try:
                result = self.agent.run(question)
            except Exception as e:
                error_str = str(e)
                lowered = error_str.lower()

                # Show condensed error
                if len(error_str) > 300:
                    print(f"❌ Error: {error_str[:300]}...")
                else:
                    print(f"❌ Error: {error_str}")

                if "rate limit" in lowered or "rate_limit" in lowered:
                    self.rate_limit_hits += 1
                    if last_attempt:
                        answer = "⚠️ Rate limit exceeded after all retries."
                        break
                    wait_time = self._extract_wait_time(error_str)
                    print(f"⏳ Rate limit hit. Waiting {wait_time:.1f}s before retry...")
                    self._sleep_with_countdown(wait_time)
                    print("🔄 Retrying...")
                elif "authentication" in lowered or "api key" in lowered:
                    # Retrying cannot fix bad credentials.
                    answer = "⚠️ Authentication failed. Check your GROQ_API_KEY."
                    break
                elif last_attempt:
                    answer = f"⚠️ Failed after {attempts} attempts."
                    break
                else:
                    print("🔄 Retrying in 5s...")
                    time.sleep(self._STEP_SECONDS)
                continue

            if self._has_content(result):
                answer = result
                succeeded = True
                print("✅ Success!")
                break

            print("⚠️ Empty answer received")
            if not last_attempt:
                time.sleep(self._STEP_SECONDS)

        # Every task is counted exactly once, as success or failure.
        if succeeded:
            self.successful_tasks += 1
        else:
            self.failed_tasks += 1
            if answer is None:
                # Only empty answers were produced.
                answer = "⚠️ Could not generate a valid response."

        # Update timing
        self.last_call_time = time.time()

        # Print result
        answer_text = str(answer)
        print(f"\n{'=' * 70}")
        answer_preview = answer_text[:250] + ('...' if len(answer_text) > 250 else '')
        print(f"✍️ Answer: {answer_preview}")
        print(f"{'=' * 70}\n")
        return answer_text

    def get_stats(self) -> dict:
        """Get agent performance statistics.

        Returns:
            A dict with the task counters, the rate-limit hit count and
            ``success_rate`` formatted as a percentage string.
        """
        success_rate = (self.successful_tasks / self.total_tasks * 100) if self.total_tasks > 0 else 0
        return {
            "total_tasks": self.total_tasks,
            "successful_tasks": self.successful_tasks,
            "failed_tasks": self.failed_tasks,
            "success_rate": f"{success_rate:.1f}%",
            "rate_limit_hits": self.rate_limit_hits,
        }

    def print_stats(self):
        """Print agent performance statistics."""
        stats = self.get_stats()
        print(f"\n{'=' * 70}")
        print("📊 AGENT STATISTICS")
        print(f"{'=' * 70}")
        print(f"Total Tasks: {stats['total_tasks']}")
        print(f"Successful: {stats['successful_tasks']} ✅")
        print(f"Failed: {stats['failed_tasks']} ❌")
        print(f"Success Rate: {stats['success_rate']}")
        print(f"Rate Limit Hits: {stats['rate_limit_hits']} 🚫")
        print(f"{'=' * 70}\n")
# Smoke test: when executed as a script, run one trivial question end to end.
if __name__ == "__main__":
    demo_agent = GaiaAgent()

    # The call logs its own progress and result; the return value is kept
    # only for interactive inspection.
    demo_question = "What is 2+2? Show your calculation."
    demo_answer = demo_agent(task_id="test-001", question=demo_question)

    demo_agent.print_stats()