Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import re | |
| from pathlib import Path | |
| from typing import Optional, Union | |
| import pandas as pd | |
| from dotenv import load_dotenv | |
| from tabulate import tabulate | |
| from smolagents import ( | |
| CodeAgent, | |
| DuckDuckGoSearchTool, | |
| FinalAnswerTool, | |
| LiteLLMModel, | |
| PythonInterpreterTool, | |
| WikipediaSearchTool, | |
| ) | |
| from smolagents.tools import Tool | |
# Load environment variables (e.g. GROQ_API_KEY) from a local .env file, if present.
load_dotenv()
class ExcelToTextTool(Tool):
    """Render an Excel worksheet as a Markdown table."""

    name = "excel_to_text"
    description = "Read an Excel file and return a Markdown table of the requested sheet."
    inputs = {
        "excel_path": {"type": "string", "description": "Path to the Excel file."},
        "sheet_name": {"type": "string", "description": "Worksheet name or index. Optional.", "nullable": True},
    }
    output_type = "string"

    def forward(self, excel_path: str, sheet_name: Optional[str] = None) -> str:
        """Load the requested sheet and return it formatted as a Markdown table.

        Any failure (missing file, unreadable workbook, bad sheet name) is
        reported as an error string rather than raised, so the agent can
        recover gracefully.
        """
        workbook_path = Path(excel_path).expanduser().resolve()
        if not workbook_path.is_file():
            return f"Error: Excel file not found at {workbook_path}"

        try:
            # A purely numeric sheet_name selects by index; otherwise select
            # by name, defaulting to the first sheet when none is given.
            target: Union[str, int]
            if sheet_name and sheet_name.isdigit():
                target = int(sheet_name)
            else:
                target = sheet_name or 0

            frame = pd.read_excel(workbook_path, sheet_name=target)

            # Prefer pandas' native Markdown renderer; fall back to tabulate
            # for frames that lack to_markdown.
            if hasattr(frame, "to_markdown"):
                return frame.to_markdown(index=False)
            return tabulate(frame, headers="keys", tablefmt="github", showindex=False)
        except Exception as e:
            return f"Error reading Excel file: {e}"
class GaiaAgent:
    """
    Single-model agent using Llama 4 Scout exclusively.

    Why Llama 4 Scout:
    - 30K TPM (highest available - 5x more than llama-3.1-8b)
    - 500K context window
    - Multimodal support (images, chess)
    - 1K RPM

    This avoids the 6K TPM bottleneck of llama-3.1-8b-instant.
    """

    def __init__(self):
        """Build the model, tool set, and CodeAgent; requires GROQ_API_KEY in the environment.

        Raises:
            ValueError: if GROQ_API_KEY is not set.
        """
        print("="*70)
        print("β GaiaAgent initialized with Llama 4 Scout (30K TPM)")
        print("="*70)

        self.api_key = os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError("GROQ_API_KEY not found in environment variables")

        # Single model configuration - Llama 4 Scout for all tasks
        self.model_id = "groq/meta-llama/llama-4-scout-17b-16e-instruct"
        print(f"π€ Model: {self.model_id}")
        print("π Limits: 30K TPM | 1K RPM | 500K context | Multimodal")
        print("="*70 + "\n")

        # Initialize model
        self.model = LiteLLMModel(
            model_id=self.model_id,
            api_key=self.api_key,
        )

        # Tools
        self.tools = [
            DuckDuckGoSearchTool(),
            WikipediaSearchTool(),
            ExcelToTextTool(),
            PythonInterpreterTool(),
            FinalAnswerTool(),
        ]

        # Create agent
        self.agent = CodeAgent(
            model=self.model,
            tools=self.tools,
            add_base_tools=True,
            additional_authorized_imports=["pandas", "numpy", "csv", "subprocess", "PIL", "requests"],
        )

        # Rate limiting - 30K TPM is generous but agents make multiple calls
        self.last_call_time = 0      # timestamp of the last completed task
        self.min_delay = 10          # 10s between tasks (reasonable with 30K TPM)
        self.max_retries = 3         # More retries since we have higher TPM

        # Stats
        self.total_tasks = 0
        self.successful_tasks = 0
        self.failed_tasks = 0
        self.rate_limit_hits = 0

    def _extract_wait_time(self, error_str: str) -> float:
        """Extract a wait time (seconds) from a rate limit error message.

        Returns the parsed number plus a 5s safety buffer, or 30 when no
        duration can be found in the message.
        """
        patterns = [
            r'try again in (\d+\.?\d*)\s*s',
            r'retry in (\d+\.?\d*)\s*s',
            r'(\d+\.?\d*)\s*s',
        ]
        for pattern in patterns:
            match = re.search(pattern, error_str)
            if match:
                return float(match.group(1)) + 5  # Add 5s buffer
        return 30  # Default fallback

    def __call__(self, task_id: str, question: str) -> str:
        """Process a task with automatic rate limiting and retry.

        Args:
            task_id: Identifier used only for logging.
            question: The task prompt passed to the underlying agent.

        Returns:
            The agent's answer as a string (or a warning string on failure).
        """
        self.total_tasks += 1

        # Rate limiting: enforce a minimum delay between consecutive tasks.
        elapsed = time.time() - self.last_call_time
        if elapsed < self.min_delay:
            wait_time = self.min_delay - elapsed
            print(f"β³ Rate limit: waiting {wait_time:.1f}s...")
            time.sleep(wait_time)

        print(f"\n{'='*70}")
        print(f"π Task #{self.total_tasks} | ID: {task_id}")
        print(f"β Question: {question[:150]}{'...' if len(question) > 150 else ''}")
        print(f"{'='*70}\n")

        answer = None

        # Retry loop with exponential backoff
        for attempt in range(self.max_retries + 1):
            try:
                print(f"π Attempt {attempt + 1}/{self.max_retries + 1}")
                answer = self.agent.run(question)

                # BUGFIX: check for None/blank explicitly instead of plain
                # truthiness, so falsy-but-valid answers (0, False, 0.0) are
                # not misclassified as empty and needlessly retried/failed.
                if answer is not None and str(answer).strip():
                    self.successful_tasks += 1
                    print("β Success!")
                    break
                else:
                    print("β οΈ Empty answer received")
                    if attempt < self.max_retries:
                        time.sleep(5)
                        continue
            except Exception as e:
                error_str = str(e)

                # Show condensed error
                if len(error_str) > 300:
                    print(f"β Error: {error_str[:300]}...")
                else:
                    print(f"β Error: {error_str}")

                # Check if it's a rate limit error
                if "rate limit" in error_str.lower() or "rate_limit" in error_str.lower():
                    self.rate_limit_hits += 1
                    wait_time = self._extract_wait_time(error_str)
                    if attempt < self.max_retries:
                        print(f"β³ Rate limit hit. Waiting {wait_time:.1f}s before retry...")
                        # Show countdown for long waits
                        if wait_time > 10:
                            for remaining in range(int(wait_time), 0, -5):
                                print(f" β±οΈ {remaining}s remaining...", flush=True)
                                time.sleep(5)
                        else:
                            time.sleep(wait_time)
                        print("π Retrying...")
                        continue
                    else:
                        answer = "β οΈ Rate limit exceeded after all retries."
                        self.failed_tasks += 1

                # Authentication error: not retryable, bail out immediately.
                elif "authentication" in error_str.lower() or "api key" in error_str.lower():
                    answer = "β οΈ Authentication failed. Check your GROQ_API_KEY."
                    self.failed_tasks += 1
                    break

                # Other errors
                else:
                    if attempt < self.max_retries:
                        print("π Retrying in 5s...")
                        time.sleep(5)
                        continue
                    else:
                        answer = f"β οΈ Failed after {self.max_retries + 1} attempts."
                        self.failed_tasks += 1

        # Fallback: the loop ended with no usable answer (None or blank).
        # Same None/blank check as above so a valid 0/False is preserved.
        if answer is None or not str(answer).strip():
            answer = "β οΈ Could not generate a valid response."
            self.failed_tasks += 1

        # Update timing
        self.last_call_time = time.time()

        # Print result
        print(f"\n{'='*70}")
        answer_preview = str(answer)[:250] + ('...' if len(str(answer)) > 250 else '')
        print(f"βοΈ Answer: {answer_preview}")
        print(f"{'='*70}\n")

        return str(answer)

    def get_stats(self) -> dict:
        """Get agent performance statistics as a plain dict."""
        success_rate = (self.successful_tasks / self.total_tasks * 100) if self.total_tasks > 0 else 0
        return {
            "total_tasks": self.total_tasks,
            "successful_tasks": self.successful_tasks,
            "failed_tasks": self.failed_tasks,
            "success_rate": f"{success_rate:.1f}%",
            "rate_limit_hits": self.rate_limit_hits,
        }

    def print_stats(self):
        """Print agent performance statistics."""
        stats = self.get_stats()
        print(f"\n{'='*70}")
        print("π AGENT STATISTICS")
        print(f"{'='*70}")
        print(f"Total Tasks: {stats['total_tasks']}")
        print(f"Successful: {stats['successful_tasks']} β ")
        print(f"Failed: {stats['failed_tasks']} β")
        print(f"Success Rate: {stats['success_rate']}")
        print(f"Rate Limit Hits: {stats['rate_limit_hits']} π«")
        print(f"{'='*70}\n")
# Example usage: run a single smoke-test question and print the run statistics.
if __name__ == "__main__":
    agent = GaiaAgent()

    # Simple arithmetic question used as a connectivity/config smoke test.
    answer = agent(
        task_id="test-001",
        question="What is 2+2? Show your calculation."
    )

    agent.print_stats()