Final_Assignment / gaia_web_loader.py
tonthatthienvu's picture
Clean repository without binary files
37cadfb
#!/usr/bin/env python3
"""
GAIA Question Loader - Web API version
Fetch questions directly from GAIA API instead of local files
"""
import functools
import json
import logging
import os
import random
import re
import time
from collections import Counter
from pathlib import Path
from typing import List, Dict, Optional

import requests
from dotenv import load_dotenv
# Load variables from a local .env file (if any) into the process environment
# so the os.getenv() lookups in GAIAQuestionLoaderWeb.__init__ can see them.
load_dotenv()
# Module-level logger. No handlers or levels are configured here; that is
# left to the application importing this module.
logger = logging.getLogger(__name__)
def retry_with_backoff(max_retries: int = 3, initial_delay: float = 1.0, backoff_factor: float = 2.0):
    """Build a decorator that retries a function with exponential backoff.

    Only transient failures are retried: ``requests`` timeouts, connection
    errors, and HTTP errors whose status is 500, 502, 503 or 504. Any other
    exception (including other HTTP status codes) propagates immediately.

    Args:
        max_retries: Total number of attempts made before giving up. When it
            is zero or negative the function is simply called once with no
            retry handling.
        initial_delay: Seconds to sleep before the second attempt.
        backoff_factor: Multiplier applied to the delay after each retry.

    Returns:
        A decorator preserving the wrapped function's name and docstring.

    Raises:
        The last transient exception once ``max_retries`` attempts have
        failed.
    """
    retryable_statuses = (500, 502, 503, 504)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            retries = 0
            delay = initial_delay
            while retries < max_retries:
                try:
                    return func(*args, **kwargs)
                except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
                    last_exception = e
                    reason = type(e).__name__
                except requests.exceptions.HTTPError as e:
                    response = e.response
                    # Compare against None explicitly: Response.__bool__ is
                    # False for every 4xx/5xx response, so a truthiness test
                    # would make this retry branch unreachable.
                    if response is None or response.status_code not in retryable_statuses:
                        raise
                    last_exception = e
                    reason = f"HTTP {response.status_code}"

                # Reached only after a retryable failure.
                retries += 1
                if retries >= max_retries:
                    logger.error(f"Max retries reached for {func.__name__}")
                    raise last_exception
                logger.warning(f"Retry {retries}/{max_retries} for {func.__name__} due to {reason}. Delaying {delay:.2f}s")
                time.sleep(delay)
                delay *= backoff_factor
            # Only reachable when max_retries <= 0: plain pass-through call.
            return func(*args, **kwargs)
        return wrapper
    return decorator
class GAIAQuestionLoaderWeb:
    """Load and manage GAIA questions from the web API.

    All questions are fetched once when the instance is created and cached in
    ``self.questions``; the lookup, filter and summary helpers work on that
    cache and perform no further network access.
    """

    # Matches a quoted filename, or an unquoted token that ends at ';' or
    # whitespace, so trailing header parameters are not captured.
    _FILENAME_RE = re.compile(r'filename\s*=\s*(?:"([^"]*)"|([^;\s]+))')

    def __init__(self, api_base: Optional[str] = None, username: Optional[str] = None):
        """Create the loader and immediately fetch the question list.

        Args:
            api_base: Base URL of the API. Falls back to the ``GAIA_API_BASE``
                environment variable, then to the built-in default.
            username: User name reported in ``summary()``. Falls back to the
                ``GAIA_USERNAME`` environment variable, then to the default.
        """
        self.api_base = api_base or os.getenv("GAIA_API_BASE", "https://agents-course-unit4-scoring.hf.space")
        self.username = username or os.getenv("GAIA_USERNAME", "tonthatthienvu")
        self.questions: List[Dict] = []
        self._load_questions()

    @retry_with_backoff()
    def _make_request(self, method: str, endpoint: str, params: Optional[Dict] = None,
                      payload: Optional[Dict] = None, timeout: int = 15) -> requests.Response:
        """Send one HTTP request to the API, with retry on transient errors.

        Args:
            method: HTTP method name (case-insensitive).
            endpoint: Path relative to ``self.api_base``.
            params: Optional query-string parameters.
            payload: Optional JSON body.
            timeout: Per-attempt timeout in seconds.

        Returns:
            The response, guaranteed to have a non-error status.

        Raises:
            requests.exceptions.HTTPError: On a 4xx/5xx status.
            requests.exceptions.Timeout: When the request times out.
            requests.exceptions.ConnectionError: When the host is unreachable.
        """
        # Strip slashes on both sides of the join so a trailing slash on the
        # configured base does not produce a "//" in the URL.
        url = f"{self.api_base.rstrip('/')}/{endpoint.lstrip('/')}"
        logger.info(f"Request: {method.upper()} {url}")
        try:
            response = requests.request(method, url, params=params, json=payload, timeout=timeout)
            response.raise_for_status()
            return response
        except requests.exceptions.HTTPError as e:
            # Response objects are falsy for error statuses, so the check
            # must be "is not None" for the body to ever be logged.
            if e.response is not None:
                logger.error(f"HTTPError: {e.response.status_code} for {method.upper()} {url}")
                logger.error(f"Response: {e.response.text[:200]}")
            else:
                logger.error(f"HTTPError: {e} for {method.upper()} {url}")
            raise
        except requests.exceptions.Timeout:
            logger.error(f"Timeout: Request to {url} timed out after {timeout}s")
            raise
        except requests.exceptions.ConnectionError as e:
            logger.error(f"ConnectionError: Could not connect to {url}. Details: {e}")
            raise

    def _load_questions(self):
        """Fetch all questions from the GAIA API into ``self.questions``.

        Never raises for request or parse failures: the problem is logged and
        printed, and ``self.questions`` is left as an empty list.
        """
        self.questions = []
        try:
            logger.info(f"Fetching questions from GAIA API: {self.api_base}/questions")
            response = self._make_request("get", "questions", timeout=15)
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to fetch questions from API: {e}")
            print(f"❌ Failed to load questions from web API: {e}")
            return

        # Parsing is handled separately from the request: newer versions of
        # requests raise a JSON error that is also a RequestException, which
        # would otherwise be reported as a fetch failure.
        try:
            data = response.json()
        except ValueError as e:
            logger.error(f"Failed to parse JSON response: {e}")
            print(f"❌ Failed to parse questions from web API: {e}")
            return

        # Every helper iterates a list of dicts; reject any other payload.
        if not isinstance(data, list):
            logger.error(f"Unexpected questions payload: expected a list, got {type(data).__name__}")
            print(f"❌ Unexpected questions payload from web API: expected a list, got {type(data).__name__}")
            return

        self.questions = data
        print(f"✅ Loaded {len(self.questions)} GAIA questions from web API")
        logger.info(f"Successfully retrieved {len(self.questions)} questions from API")

    def _random_cached_question(self) -> Optional[Dict]:
        """Return a random question from the local cache, or None if empty."""
        return random.choice(self.questions) if self.questions else None

    def get_random_question(self) -> Optional[Dict]:
        """Get a random question from the API.

        Falls back to a random choice from the cached questions when the
        request fails or the response cannot be used.

        Returns:
            A question dict, or None when the API is unusable and the cache
            is empty.
        """
        try:
            logger.info(f"Getting random question from: {self.api_base}/random-question")
            response = self._make_request("get", "random-question", timeout=15)
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to get random question: {e}")
            return self._random_cached_question()

        try:
            question = response.json()
        except ValueError as e:
            logger.error(f"Failed to parse random question response: {e}")
            return self._random_cached_question()

        if not isinstance(question, dict):
            logger.error(f"Unexpected random question payload: got {type(question).__name__}")
            return self._random_cached_question()

        task_id = question.get('task_id', 'Unknown')
        logger.info(f"Successfully retrieved random question: {task_id}")
        return question

    def get_question_by_id(self, task_id: str) -> Optional[Dict]:
        """Return the first cached question whose ``task_id`` matches, or None."""
        return next((q for q in self.questions if q.get('task_id') == task_id), None)

    def get_questions_by_level(self, level: str) -> List[Dict]:
        """Return all cached questions whose ``Level`` equals ``level``."""
        return [q for q in self.questions if q.get('Level') == level]

    def get_questions_with_files(self) -> List[Dict]:
        """Return all cached questions that have a non-empty ``file_name``."""
        return [q for q in self.questions if q.get('file_name')]

    def get_questions_without_files(self) -> List[Dict]:
        """Return all cached questions with a missing or empty ``file_name``."""
        return [q for q in self.questions if not q.get('file_name')]

    def count_by_level(self) -> Dict[str, int]:
        """Count cached questions per ``Level``; a missing level is 'Unknown'."""
        return dict(Counter(q.get('Level', 'Unknown') for q in self.questions))

    def summary(self) -> Dict:
        """Return totals, file/no-file counts, per-level counts and settings."""
        return {
            'total_questions': len(self.questions),
            'with_files': len(self.get_questions_with_files()),
            'without_files': len(self.get_questions_without_files()),
            'by_level': self.count_by_level(),
            'api_base': self.api_base,
            'username': self.username,
        }

    @staticmethod
    def _safe_filename(candidate: str) -> str:
        """Reduce ``candidate`` to a bare file name, or '' if none is usable.

        Directory components (with either slash style) are dropped so a
        server-supplied name cannot point outside the download directory.
        """
        base = os.path.basename(str(candidate).replace("\\", "/")).strip()
        return "" if base in ("", ".", "..") else base

    def download_file(self, task_id: str, save_dir: str = "./downloads") -> Optional[str]:
        """Download the file associated with a question.

        The file name is taken from the ``Content-Disposition`` header when
        present, otherwise from ``task_id``; in both cases it is stripped to
        its final path component before being joined to ``save_dir``.

        Args:
            task_id: Identifier of the question whose file is requested.
            save_dir: Directory to save into; created (with parents) if
                missing.

        Returns:
            The saved file's path as a string, or None on any failure.
        """
        try:
            target_dir = Path(save_dir)
            target_dir.mkdir(parents=True, exist_ok=True)

            logger.info(f"Downloading file for task: {task_id}")
            response = self._make_request("get", f"files/{task_id}", timeout=30)

            # Prefer the server-provided name, defaulting to the task id.
            filename = task_id
            disposition = response.headers.get('content-disposition')
            if disposition:
                match = self._FILENAME_RE.search(disposition)
                if match:
                    filename = match.group(1) or match.group(2) or task_id

            # The header is untrusted input: keep only a bare file name.
            safe_name = self._safe_filename(filename) or self._safe_filename(task_id) or "download"

            file_path = target_dir / safe_name
            file_path.write_bytes(response.content)

            logger.info(f"File downloaded successfully: {file_path}")
            return str(file_path)
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to download file for task {task_id}: {e}")
            return None
        except Exception as e:
            # Deliberate best-effort: any local failure yields None.
            logger.error(f"Error saving file for task {task_id}: {e}")
            return None

    def test_api_connection(self) -> bool:
        """Return True if a GET on the questions endpoint succeeds."""
        try:
            logger.info(f"Testing API connection to: {self.api_base}")
            self._make_request("get", "questions", timeout=10)
            logger.info("✅ API connection successful")
            return True
        except Exception as e:
            logger.error(f"❌ API connection failed: {e}")
            return False