|
""" |
|
File retrieval tool for accessing files from the GAIA dataset.

Handles multiple file formats including audio, text, PDFs, images, spreadsheets, and structured data.

Enhanced with content transformation capabilities for better LLM readability.

Required Dependencies:

    pip install smolagents datasets PyPDF2 openpyxl huggingface_hub pandas requests

For audio transcription, set the HF_TOKEN environment variable.
|
""" |
|
|
|
from smolagents import tool |
|
from datasets import load_dataset |
|
import os |
|
import json |
|
import csv |
|
import io |
|
import base64 |
|
from typing import Optional, Dict, Any |
|
import mimetypes |
|
|
|
|
|
import PyPDF2 |
|
import openpyxl |
|
import pandas as pd |
|
from huggingface_hub import InferenceClient |
|
import requests |
|
|
|
|
|
# Module-level cache so the dataset is loaded at most once per process.
_dataset = None


def get_dataset():
    """Return the GAIA dataset, loading it on the first call.

    The first call invokes ``load_dataset`` for the "2023_level1"
    configuration of "gaia-benchmark/GAIA" (cache directory "GAIA") and
    stores the result in the module-level ``_dataset``; later calls return
    that stored object.
    """
    global _dataset

    if _dataset is not None:
        return _dataset

    _dataset = load_dataset(
        "gaia-benchmark/GAIA",
        "2023_level1",
        trust_remote_code=True,
        cache_dir="GAIA",
    )
    return _dataset
|
|
|
@tool
def get_file(filename: str) -> str:
    """
    Retrieve file content by filename.

    Args:
        filename: The name of the file to retrieve from the GAIA dataset validation split.

    Returns:
        A string containing the file content information and metadata.
        For binary files, returns metadata and base64-encoded content when appropriate.
        Failures are reported as a descriptive error string instead of being raised.
    """
    try:
        dataset = get_dataset()

        # Locate the first validation record whose "file_name" matches.
        try:
            file_item = next(
                (
                    item
                    for item in dataset["validation"]
                    if isinstance(item, dict) and item.get("file_name") == filename
                ),
                None,
            )
        except Exception as e:
            # Fall back to attribute access in case item lookup is unsupported
            # by the dataset object.
            try:
                file_item = next(
                    (
                        item
                        for item in dataset.validation
                        if isinstance(item, dict) and item.get("file_name") == filename
                    ),
                    None,
                )
            except Exception as e2:
                return f"Error accessing dataset: {str(e)} / {str(e2)}"

        if not file_item:
            return f"File '{filename}' not found in the GAIA dataset. Available files can be found by examining the dataset validation split."

        file_path = file_item.get("file_path")
        if not file_path:
            return f"File '{filename}' found in dataset but no file_path available."

        if not os.path.exists(file_path):
            return f"File '{filename}' not found at expected path: {file_path}"

        mime_type, _ = mimetypes.guess_type(filename)
        file_extension = os.path.splitext(filename)[1].lower()

        # Metadata header that precedes the rendered content.
        result = f"File: {filename}\n"
        result += f"MIME Type: {mime_type or 'unknown'}\n"
        result += f"Extension: {file_extension}\n"

        if "task_id" in file_item:
            result += f"Associated Task ID: {file_item['task_id']}\n"

        result += "\n" + "=" * 50 + "\n"
        result += "FILE CONTENT:\n"
        result += "=" * 50 + "\n\n"

        try:
            # CSV is tested before the generic text check: a ".csv" name is
            # normally guessed as "text/csv", which the text check accepts via
            # its "text/" prefix rule, so the CSV handler would never run.
            if _is_csv_file(filename, mime_type):
                result += _handle_csv_file(file_path, filename)

            elif _is_text_file(filename, mime_type):
                with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
                    content = f.read()
                # Keep plain-text output bounded.
                if len(content) > 10000:
                    content = content[:10000] + "\n\n... [Content truncated - showing first 10,000 characters]"
                result += content

            elif _is_pdf_file(filename, mime_type):
                result += _handle_pdf_file(file_path, filename)

            elif _is_excel_file(filename, mime_type):
                result += _handle_excel_file(file_path, filename)

            elif _is_audio_file(filename, mime_type):
                result += _handle_audio_file(file_path, filename)

            elif _is_image_file(filename, mime_type):
                with open(file_path, 'rb') as f:
                    file_content = f.read()
                result += _handle_image_file(file_content, filename)

            elif _is_structured_data_file(filename, mime_type):
                with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
                    content = f.read()
                result += _handle_structured_data(content, filename)

            else:
                # Unknown type: fall back to a base64 dump.
                with open(file_path, 'rb') as f:
                    file_content = f.read()
                result += _handle_binary_file(file_content, filename)

        except Exception as e:
            return f"Error reading file '{filename}': {str(e)}"

        return result

    except Exception as e:
        return f"Error retrieving file '{filename}': {str(e)}"
|
|
|
def _is_text_file(filename: str, mime_type: Optional[str]) -> bool:
    """Return True when the file should be read as plain text.

    A file qualifies if its lower-cased name ends with one of the known
    text/source-code suffixes, or if a MIME type is given and it starts
    with ``text/``.
    """
    known_suffixes = (
        '.txt', '.md', '.rtf', '.log', '.cfg', '.ini', '.conf',
        '.py', '.js', '.html', '.css', '.sql', '.sh', '.bat', '.r',
        '.cpp', '.c', '.java', '.php', '.rb', '.go', '.rs', '.ts',
        '.jsx', '.tsx', '.vue', '.svelte',
    )
    if filename.lower().endswith(known_suffixes):
        return True
    if mime_type is None:
        return False
    return mime_type.startswith('text/')
|
|
|
def _is_pdf_file(filename: str, mime_type: Optional[str]) -> bool:
    """Return True for a ``.pdf`` name (case-insensitive) or the PDF MIME type."""
    if filename.lower().endswith('.pdf'):
        return True
    return mime_type == 'application/pdf'
|
|
|
def _is_excel_file(filename: str, mime_type: Optional[str]) -> bool:
    """Return True when the lower-cased name ends with ``.xlsx`` or ``.xls``.

    ``mime_type`` is accepted for a uniform signature but is not consulted.
    """
    lowered = filename.lower()
    return lowered.endswith('.xlsx') or lowered.endswith('.xls')
|
|
|
def _is_csv_file(filename: str, mime_type: Optional[str]) -> bool:
    """Return True for a ``.csv`` name (case-insensitive) or the ``text/csv`` MIME type."""
    if filename.lower().endswith('.csv'):
        return True
    return mime_type == 'text/csv'
|
|
|
def _is_audio_file(filename: str, mime_type: Optional[str]) -> bool:
    """Return True for a known audio suffix or an ``audio/*`` MIME type."""
    audio_suffixes = ('.mp3', '.wav', '.m4a', '.aac', '.ogg', '.flac', '.wma')
    if filename.lower().endswith(audio_suffixes):
        return True
    if mime_type is None:
        return False
    return mime_type.startswith('audio/')
|
|
|
def _is_image_file(filename: str, mime_type: Optional[str]) -> bool:
    """Return True for a known image suffix or an ``image/*`` MIME type."""
    image_suffixes = (
        '.jpg', '.jpeg', '.png', '.gif', '.bmp',
        '.svg', '.webp', '.tiff', '.ico',
    )
    if filename.lower().endswith(image_suffixes):
        return True
    if mime_type is None:
        return False
    return mime_type.startswith('image/')
|
|
|
def _is_structured_data_file(filename: str, mime_type: Optional[str]) -> bool:
    """Return True when the lower-cased name ends with a JSON, XML or YAML suffix.

    ``mime_type`` is accepted for a uniform signature but is not consulted.
    """
    structured_suffixes = ('.json', '.xml', '.yaml', '.yml')
    lowered = filename.lower()
    return lowered.endswith(structured_suffixes)
|
|
|
def _handle_pdf_file(file_path: str, filename: str) -> str:
    """Render the text of a PDF as a string.

    Opens ``file_path`` with ``PyPDF2.PdfReader``, reports the total page
    count, and concatenates the extracted text of at most the first 10 pages
    (pages with no extractable text are skipped). The extracted text is cut
    to 15,000 characters. ``filename`` is not used.

    Returns:
        The formatted text, or an "Error extracting PDF text: ..." string if
        anything raises.
    """
    try:
        with open(file_path, 'rb') as stream:
            reader = PyPDF2.PdfReader(stream)
            total_pages = len(reader.pages)

            sections = []
            for index in range(min(10, total_pages)):
                extracted = reader.pages[index].extract_text()
                if not extracted:
                    continue
                sections.append(f"--- PAGE {index + 1} ---\n")
                sections.append(extracted + "\n\n")

        if total_pages > 10:
            sections.append(f"... [Showing first 10 pages out of {total_pages} total]\n")

        body = "".join(sections)
        if len(body) > 15000:
            body = body[:15000] + "\n\n... [Content truncated]"

        header = "PDF TEXT CONTENT:\n" + "=" * 50 + "\n"
        return header + f"Total pages: {total_pages}\n\n" + body
    except Exception as exc:
        return f"Error extracting PDF text: {str(exc)}"
|
|
|
def _handle_excel_file(file_path: str, filename: str) -> str:
    """Summarise an Excel workbook as text.

    Lists the sheet count and sheet names, then for at most the first three
    sheets reports the dimensions, column labels and first five rows.
    ``filename`` is not used.

    Args:
        file_path: Path of the workbook on disk.
        filename: Original file name (unused; kept for a uniform handler signature).

    Returns:
        The formatted summary, or an "Error reading Excel file: ..." string
        if anything raises.
    """
    try:
        parts = ["EXCEL CONTENT:\n", "=" * 50 + "\n"]

        # Open the workbook once and close it deterministically; the sheets
        # are read from the already-open workbook instead of re-opening the
        # file from its path for every sheet.
        with pd.ExcelFile(file_path) as excel_file:
            sheet_names = excel_file.sheet_names

            parts.append(f"Number of sheets: {len(sheet_names)}\n")
            parts.append(f"Sheet names: {', '.join(str(name) for name in sheet_names)}\n\n")

            for sheet_name in sheet_names[:3]:
                df = pd.read_excel(excel_file, sheet_name=sheet_name)
                parts.append(f"SHEET: {sheet_name}\n")
                parts.append("=" * 30 + "\n")
                parts.append(f"Dimensions: {df.shape[0]} rows × {df.shape[1]} columns\n")
                parts.append(f"Columns: {list(df.columns)}\n\n")
                parts.append("First 5 rows:\n")
                parts.append(df.head().to_string(index=True) + "\n\n")

        if len(sheet_names) > 3:
            parts.append(f"... and {len(sheet_names) - 3} more sheets\n")

        return "".join(parts)
    except Exception as e:
        return f"Error reading Excel file: {str(e)}"
|
|
|
def _handle_csv_file(file_path: str, filename: str) -> str:
    """Summarise a CSV file as text.

    Loads ``file_path`` with ``pd.read_csv`` and reports the dimensions, the
    column labels and the first ten rows. ``filename`` is not used.

    Returns:
        The formatted summary, or an "Error reading CSV file: ..." string if
        anything raises.
    """
    try:
        frame = pd.read_csv(file_path)
        row_count, column_count = frame.shape[0], frame.shape[1]

        pieces = [
            "CSV CONTENT:\n",
            "=" * 50 + "\n",
            f"Dimensions: {row_count} rows × {column_count} columns\n",
            f"Columns: {list(frame.columns)}\n\n",
            "First 10 rows:\n",
            frame.head(10).to_string(index=True) + "\n",
        ]
        return "".join(pieces)
    except Exception as exc:
        return f"Error reading CSV file: {str(exc)}"
|
|
|
def _handle_audio_file(file_path: str, filename: str) -> str:
    """Transcribe an audio file via the Hugging Face inference endpoint for Whisper Large v3.

    The raw file bytes are POSTed with a ``Content-Type`` chosen from the
    file extension (defaulting to ``audio/mpeg``) and a bearer token taken
    from the ``HF_TOKEN`` environment variable.

    Args:
        file_path: Path of the audio file on disk.
        filename: Original file name; only its extension is used.

    Returns:
        The formatted transcription. If ``HF_TOKEN`` is unset or empty, a
        message saying so; on a non-200 response, the status code and response
        text; on any exception (including a request timeout), an
        "Error transcribing audio: ..." string.
    """
    try:
        token = os.environ.get("HF_TOKEN")
        if not token:
            return "Audio transcription requires HF_TOKEN environment variable to be set."

        file_ext = os.path.splitext(filename)[1].lower()
        content_type_map = {
            '.mp3': 'audio/mpeg',
            '.wav': 'audio/wav',
            '.flac': 'audio/flac',
            '.m4a': 'audio/m4a',
            '.ogg': 'audio/ogg',
            '.webm': 'audio/webm',
        }
        content_type = content_type_map.get(file_ext, 'audio/mpeg')

        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": content_type,
        }

        with open(file_path, 'rb') as audio_file:
            audio_data = audio_file.read()

        api_url = "https://api-inference.huggingface.co/models/openai/whisper-large-v3"
        # Without a timeout, requests waits forever on a stalled endpoint.
        # (connect, read) in seconds; the read limit is generous because
        # transcription can be slow.
        response = requests.post(
            api_url,
            headers=headers,
            data=audio_data,
            timeout=(10, 300),
        )

        if response.status_code != 200:
            return f"Error from HuggingFace API: {response.status_code} - {response.text}"

        transcription_output = response.json()

        # Use the "text" field when the payload is a dict that has one;
        # otherwise fall back to the stringified payload.
        if isinstance(transcription_output, dict) and 'text' in transcription_output:
            transcription_text = transcription_output['text']
        else:
            transcription_text = str(transcription_output)

        result = "AUDIO TRANSCRIPTION:\n"
        result += "=" * 50 + "\n"
        result += transcription_text + "\n"
        result += "\n" + "=" * 50 + "\n"
        result += "Transcription completed using Whisper Large v3"
        return result
    except Exception as e:
        return f"Error transcribing audio: {str(e)}"
|
|
|
def _handle_image_file(file_content: bytes, filename: str) -> str:
    """Describe an image and embed its bytes as base64 text.

    The output lists the file name, the byte length, the upper-cased
    extension (without the dot) and the complete base64 encoding of
    ``file_content``.

    Returns:
        The formatted description, or an "Error handling image: ..." string
        if anything raises.
    """
    try:
        pieces = [
            "IMAGE CONTENT:\n",
            "=" * 50 + "\n",
            f"Image file: {filename}\n",
            f"File size: {len(file_content)} bytes\n",
            f"Format: {os.path.splitext(filename)[1].upper().lstrip('.')}\n\n",
        ]

        encoded = base64.b64encode(file_content).decode('utf-8')
        pieces.append("Base64 encoded content:\n")
        pieces.append(encoded + "\n\n")
        pieces.append("Note: This is the base64 encoded image data that can be decoded and analyzed.")

        return "".join(pieces)
    except Exception as exc:
        return f"Error handling image: {str(exc)}"
|
|
|
def _handle_binary_file(file_content: bytes, filename: str) -> str:
    """Describe an arbitrary binary file and embed its bytes as base64 text.

    The output lists the file name, the byte length, the extension exactly
    as it appears in ``filename`` and the complete base64 encoding of
    ``file_content``.

    Returns:
        The formatted description, or an "Error handling binary file: ..."
        string if anything raises.
    """
    try:
        pieces = [
            "BINARY FILE CONTENT:\n",
            "=" * 50 + "\n",
            f"Binary file: {filename}\n",
            f"File size: {len(file_content)} bytes\n",
            f"File extension: {os.path.splitext(filename)[1]}\n\n",
        ]

        encoded = base64.b64encode(file_content).decode('utf-8')
        pieces.append("Base64 encoded content:\n")
        pieces.append(encoded + "\n\n")
        pieces.append("Note: This is the base64 encoded binary data.")

        return "".join(pieces)
    except Exception as exc:
        return f"Error handling binary file: {str(exc)}"
|
|
|
def _handle_structured_data(content: str, filename: str) -> str:
    """Render structured-data text for display.

    For a ``.json`` name (case-insensitive), the content is parsed and
    re-serialised with two-space indentation and non-ASCII characters kept
    as-is; if it is not valid JSON the content is returned unchanged. Any
    other file type is returned unchanged.

    Returns:
        The rendered text, or an "Error handling structured data: ..."
        string if anything else raises.
    """
    try:
        if not filename.lower().endswith('.json'):
            return content

        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            # Malformed JSON: hand back the raw text.
            return content

        return json.dumps(parsed, indent=2, ensure_ascii=False)
    except Exception as exc:
        return f"Error handling structured data: {str(exc)}"