# Residual page-header text (not part of the module): "Spaces: Sleeping"
# Custom tools for a smolagents-based GAIA agent.
from __future__ import annotations

import ast
import contextlib
import io
import os
from typing import Any, Dict, List

from smolagents import Tool
# ---- 1. PythonRunTool ------------------------------------------------------
class PythonRunTool(Tool):
    """Run a snippet of *trusted* Python code and report what it produced.

    WARNING: the snippet runs via ``exec`` inside this process with the
    normal builtins available -- there is no sandboxing of any kind.
    """

    name = "python_run"
    description = (
        "Execute trusted Python code and return printed output "
        "+ repr() of the last expression (or _result variable)."
    )
    # smolagents expects JSON-schema type names ("string"), not Python types.
    inputs = {
        "code": {"type": "string", "description": "Python code to execute"},
    }
    output_type = "string"

    def forward(self, code: str) -> str:
        """Execute *code*; return captured stdout followed by a result repr.

        The result is the value of the final statement when that statement
        is a bare expression. If there is no trailing expression, or it
        evaluates to ``None`` (e.g. a final ``print(...)``), the ``_result``
        variable assigned by the snippet is used instead. Nothing is
        appended when both are ``None``.

        Raises:
            RuntimeError: wrapping any exception (including ``SyntaxError``)
                raised while parsing or running *code*.
        """
        buf = io.StringIO()
        # A single dict serves as both globals and locals. With separate
        # dicts, functions (and, before Python 3.12, comprehensions) defined
        # in the snippet cannot see the snippet's own top-level names.
        namespace = {}
        last = None
        try:
            tree = ast.parse(code, "<agent-python>", "exec")
            tail = None
            if tree.body and isinstance(tree.body[-1], ast.Expr):
                # Split off the trailing expression so its value can be captured.
                tail = ast.Expression(body=tree.body.pop().value)
            with contextlib.redirect_stdout(buf):
                exec(compile(tree, "<agent-python>", "exec"), namespace)
                if tail is not None:
                    last = eval(compile(tail, "<agent-python>", "eval"), namespace)
            if last is None:
                last = namespace.get("_result", None)
        except Exception as e:
            raise RuntimeError(f"PythonRunTool error: {e}") from e
        out = buf.getvalue()
        return (out + (repr(last) if last is not None else "")).strip()
# ---- 2. ExcelLoaderTool ----------------------------------------------------
class ExcelLoaderTool(Tool):
    """Load a CSV file or one Excel sheet as a list of row dictionaries."""

    name = "load_spreadsheet"
    description = (
        "Read .xlsx/.xls/.csv from disk and return "
        "rows as a list of dictionaries with string keys."
    )
    # smolagents expects JSON-schema type names, and every forward() parameter
    # that has a default must be marked "nullable".
    inputs = {
        "path": {"type": "string", "description": "Path to .csv/.xls/.xlsx file"},
        "sheet": {
            "type": "string",
            "description": (
                "Sheet name or 0-based index for Excel files "
                "(optional, defaults to the first sheet; ignored for CSV)"
            ),
            "nullable": True,
        },
    }
    output_type = "array"

    def forward(self, path: str, sheet: str | int | None = None) -> List[Dict[str, Any]]:
        """Read *path* and return its rows as ``{column_label: value}`` dicts.

        Args:
            path: File to read. A ``.csv`` extension (case-insensitive) is
                read as CSV; anything else is opened as an Excel workbook.
            sheet: Excel sheet name or 0-based index; ignored for CSV.
                ``None`` or ``""`` selects the first sheet. A digit string
                such as ``"1"`` is treated as an index unless a sheet with
                exactly that name exists.

        Returns:
            One dict per data row, with column labels converted to ``str``.

        Raises:
            FileNotFoundError: if *path* is not an existing file.
        """
        import pandas as pd

        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        ext = os.path.splitext(path)[1].lower()
        if ext == ".csv":
            df = pd.read_csv(path)
        else:
            with pd.ExcelFile(path) as workbook:
                if sheet is None or sheet == "":
                    # sheet_name=None would make pandas return a dict of ALL
                    # sheets, which has no .to_dict(orient=...); use the first.
                    target = 0
                elif (
                    isinstance(sheet, str)
                    and sheet not in workbook.sheet_names
                    and sheet.strip().isdigit()
                ):
                    # Tool inputs arrive as strings, so "1" means index 1 here.
                    target = int(sheet)
                else:
                    target = sheet
                df = workbook.parse(sheet_name=target)
        # Column labels may be non-strings (ints, dates); stringify them.
        return [{str(k): v for k, v in row.items()} for row in df.to_dict(orient="records")]
# ---- 3. YouTubeTranscriptTool ----------------------------------------------
class YouTubeTranscriptTool(Tool):
    """Fetch the subtitle text of a YouTube video as one string."""

    name = "youtube_transcript"
    description = "Return the subtitles of a YouTube URL using youtube-transcript-api."
    # smolagents expects JSON-schema type names; defaulted params are "nullable".
    inputs = {
        "url": {"type": "string", "description": "YouTube URL"},
        "lang": {
            "type": "string",
            "description": "Transcript language (default: en)",
            "nullable": True,
        },
    }
    output_type = "string"

    def forward(self, url: str, lang: str = "en") -> str:
        """Return the transcript of the video at *url* joined with spaces.

        Args:
            url: A ``watch?v=<id>`` URL, a URL whose last path segment is the
                id (``youtu.be/<id>``, ``/shorts/<id>``, ``/embed/<id>``), or
                a bare video id.
            lang: Preferred transcript language. English variants are tried
                after it as fallbacks.

        Raises:
            ValueError: if no video id can be extracted from *url*.
        """
        from urllib.parse import parse_qs, urlparse

        from youtube_transcript_api import YouTubeTranscriptApi

        parsed = urlparse(url.strip())
        # Take the id from the path, not the raw string, so query strings
        # (e.g. "?t=10") and trailing slashes are not glued onto it.
        vid = parse_qs(parsed.query).get("v", [None])[0] or parsed.path.rstrip("/").split("/")[-1]
        if not vid:
            raise ValueError(f"Could not extract a YouTube video id from {url!r}")
        # Preferred language first; dict.fromkeys drops duplicates in order.
        languages = list(dict.fromkeys([lang or "en", "en", "en-US", "en-GB"]))
        if hasattr(YouTubeTranscriptApi, "fetch"):
            # youtube-transcript-api >= 1.0: instance API returning a FetchedTranscript.
            data = YouTubeTranscriptApi().fetch(vid, languages=languages).to_raw_data()
        else:
            # Older releases: classmethod returning a list of {"text": ...} dicts.
            data = YouTubeTranscriptApi.get_transcript(vid, languages=languages)
        return " ".join(d["text"] for d in data).strip()
# ---- 4. AudioTranscriptionTool ---------------------------------------------
class AudioTranscriptionTool(Tool):
    """Transcribe speech in an audio file through the OpenAI API."""

    name = "transcribe_audio"
    description = "Transcribe an audio file with OpenAI Whisper, returns plain text."
    # smolagents expects JSON-schema type names; defaulted params are "nullable".
    inputs = {
        "path": {"type": "string", "description": "Path to audio file"},
        "model": {
            "type": "string",
            "description": "Model name for transcription (default: whisper-1)",
            "nullable": True,
        },
    }
    output_type = "string"

    def forward(self, path: str, model: str = "whisper-1") -> str:
        """Return the stripped transcription text of the audio file at *path*.

        The API key is read from the ``OPENAI_API_KEY`` environment variable.

        Raises:
            FileNotFoundError: if *path* is not an existing file.
            ImportError: if the installed ``openai`` package exposes neither
                the v1 client nor the legacy ``Audio`` API.
        """
        import openai

        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        api_key = os.getenv("OPENAI_API_KEY")
        model = model or "whisper-1"
        # Check for the v1 client FIRST: in openai>=1.0, ``openai.Audio`` still
        # exists, but only as a proxy that raises when called.
        if hasattr(openai, "OpenAI"):
            client = openai.OpenAI(api_key=api_key)
            with open(path, "rb") as fp:
                result = client.audio.transcriptions.create(model=model, file=fp)
            return result.text.strip()
        if hasattr(openai, "Audio"):
            # Legacy openai<1.0 module-level API.
            openai.api_key = api_key
            with open(path, "rb") as fp:
                return openai.Audio.transcribe(model=model, file=fp)["text"].strip()
        raise ImportError(
            "Your OpenAI package does not support Audio. "
            "Please upgrade it with: pip install --upgrade openai"
        )
# ---- 5. SimpleOCRTool ------------------------------------------------------
class SimpleOCRTool(Tool):
    """Extract printed text from an image with Tesseract (via pytesseract)."""

    name = "image_ocr"
    description = "Return any text spotted in an image via pytesseract OCR."
    # smolagents expects JSON-schema type names, not Python types.
    inputs = {
        "path": {"type": "string", "description": "Path to image file"},
    }
    output_type = "string"

    def forward(self, path: str) -> str:
        """Return the OCR'd text of the image at *path*, stripped of outer whitespace.

        Raises:
            FileNotFoundError: if *path* is not an existing file.
        """
        import pytesseract
        from PIL import Image

        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it once OCR is done.
        with Image.open(path) as img:
            return pytesseract.image_to_string(img).strip()
# ---------------------------------------------------------------------------
# Public API of this module: the tool classes, in definition order.
__all__ = [
    "PythonRunTool", "ExcelLoaderTool", "YouTubeTranscriptTool",
    "AudioTranscriptionTool", "SimpleOCRTool",
]