# Spaces status at capture time: Runtime error
# tools.py | |
import os | |
import json | |
import logging | |
import re | |
import requests | |
import shutil | |
import urllib.parse | |
import pandas as pd # For ExcelParsingTool | |
from board_to_fen.predict import get_fen_from_image_path # For ChessBoardFENTool | |
from google import genai | |
from google.genai import types | |
# from litellm import completion # Removed - no longer used for ConvertChessMoveTool | |
from smolagents import Tool | |
from settings import Settings | |
from models import GoogleModelID # Import GoogleModelID | |
# Module-level logger used by every tool class in this file.
logger = logging.getLogger(__name__)
class BaseCustomTool(Tool):
    """Common parent for the settings-aware tools defined in this module.

    Subclasses receive the application ``Settings`` object when they are
    constructed and read it later through ``self.settings``.
    """

    def __init__(self, settings: Settings):
        super().__init__()
        # Keep a handle to the shared configuration for use in ``forward``.
        self.settings = settings
class GetTaskFileTool(BaseCustomTool):
    name = "get_task_file_tool"
    description = """If a file_name is provided in the task, use this tool to download the file associated with a given task_id. Returns the absolute file path to the downloaded file. This path can then be used by other tools like AudioUnderstandingTool or ExcelParsingTool. Example: get_task_file_tool(task_id="1234", file_name="example.mp3")"""
    inputs = {
        "task_id": {"type": "string", "description": "Task ID (required)"},
        "file_name": {"type": "string", "description": "File name (required)"},
    }
    output_type = "string"

    def __init__(self, settings: Settings):
        super().__init__(settings)
        # Relative path, resolved against the current working directory.
        self.directory_name = "downloads"
        self.create_dir()

    def forward(self, task_id: str, file_name: str) -> str:
        """Download the file attached to ``task_id`` and return its absolute path.

        The file is requested from ``{settings.scoring_api_base_url}/files/{task_id}``.
        If that request fails, a file with the same name in the local ``files``
        directory is copied into the download directory instead.

        Args:
            task_id: Identifier of the task whose file should be fetched.
            file_name: Name to save the file under; only its final path
                component is used.

        Returns:
            The absolute path of the saved file, or a message starting with
            ``"Error:"`` if the file could not be obtained. Nothing is raised.
        """
        # Keep only the final path component so names like "../x" or "/etc/x"
        # cannot place the file outside the download directory.
        safe_name = os.path.basename(file_name or "")
        if not safe_name:
            logger.error(f"Invalid file name '{file_name}' for task_id {task_id}.")
            return f"Error: Invalid file name '{file_name}' for task_id {task_id}."
        try:
            # Use the scoring API base URL for file downloads
            response = requests.get(f"{self.settings.scoring_api_base_url}/files/{task_id}", timeout=15)
            response.raise_for_status()
            os.makedirs(self.directory_name, exist_ok=True)
            file_path = os.path.join(self.directory_name, safe_name)
            with open(file_path, 'wb') as file:
                file.write(response.content)
            absolute_file_path = os.path.abspath(file_path)
            logger.info(f"Successfully downloaded file '{file_name}' for task_id {task_id} to {absolute_file_path}")
            return absolute_file_path
        except requests.exceptions.RequestException as e:
            logger.error(f"Error downloading file for task_id {task_id} from API: {e}")
            return self._copy_local_fallback(task_id, safe_name, e)
        except Exception as e:
            logger.error(f"An unexpected error occurred in GetTaskFileTool: {e}")
            return f"Error: An unexpected error occurred while getting file '{file_name}'. {e}"

    def _copy_local_fallback(self, task_id: str, file_name: str, download_error: Exception) -> str:
        """Copy ``files/<file_name>`` into the download directory after a failed download.

        Returns the absolute destination path, or an ``"Error:"`` message when
        the local file is missing or cannot be copied.
        """
        local_file_path = os.path.join("files", file_name)
        if not os.path.exists(local_file_path):
            logger.error(f"Local fallback file '{local_file_path}' not found.")
            return f"Error: Could not download or find file '{file_name}' for task_id {task_id}. {download_error}"
        destination_path = os.path.join(self.directory_name, file_name)
        try:
            os.makedirs(self.directory_name, exist_ok=True)
            shutil.copy2(local_file_path, destination_path)
        except OSError as copy_error:
            # This runs from forward's RequestException handler. An exception
            # escaping here would not be caught by forward's other handler, so
            # turn it into an error message instead of crashing the tool.
            logger.error(f"Failed to copy local fallback file '{local_file_path}': {copy_error}")
            return f"Error: Could not copy local fallback file '{file_name}' for task_id {task_id}. {copy_error}"
        absolute_local_file_path = os.path.abspath(destination_path)
        logger.info(f"Copied local fallback file '{file_name}' to {absolute_local_file_path}")
        return absolute_local_file_path

    def create_dir(self):
        """Creates the download directory if it doesn't exist."""
        if not os.path.exists(self.directory_name):
            os.makedirs(self.directory_name)
            logger.info(f"Directory '{self.directory_name}' created successfully.")
        else:
            logger.debug(f"Directory '{self.directory_name}' already exists.")
class VideoUnderstandingTool(BaseCustomTool):
    name = "video_understanding_tool"
    description = "Analyzes a YouTube video given its URL and a specific prompt/question about its content. Returns a text description or answer from the video. Use this for tasks involving video content. Example: video_understanding_tool(youtube_url=\"https://www.youtube.com/watch?v=VIDEO_ID\", prompt=\"What is the main topic of this video?\")"
    inputs = {
        "youtube_url": {"type": "string", "description": "The URL of the YouTube video"},
        "prompt": {"type": "string", "description": "A question or request regarding the video content"},
    }
    output_type = "string"

    def __init__(self, settings: Settings, model: GoogleModelID):
        super().__init__(settings)
        # NOTE(review): assumes GoogleModelID is usable as a model-name string
        # (e.g. a str-based Enum) — confirm in models.py.
        self.model = model
        # ``google.genai`` is the client-based Gen AI SDK: it has no module-level
        # ``configure`` / ``GenerativeModel`` (those belong to the legacy
        # ``google.generativeai`` package), so all calls go through a Client.
        self.client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        logger.info(f"VideoUnderstandingTool initialized with model: {self.model}")

    def forward(self, youtube_url: str, prompt: str) -> str:
        """Ask the Gemini model ``prompt`` about the YouTube video at ``youtube_url``.

        Returns the model's text answer, or a message starting with
        ``"Error understanding video:"`` on any failure. Nothing is raised.
        """
        try:
            # YouTube links are passed by URI; the Gemini API fetches the video
            # itself, so no MIME type is guessed here.
            video_part = types.Part(file_data=types.FileData(file_uri=youtube_url))
            response = self.client.models.generate_content(
                model=self.model,
                contents=types.Content(parts=[video_part, types.Part(text=prompt)]),
            )
            text = response.text
            # ``text`` is None when the response carries no text parts (e.g. blocked).
            if not text:
                raise ValueError("model returned no text")
            return text
        except Exception as e:
            logger.error(f"Error understanding video from URL '{youtube_url}': {e}")
            return f"Error understanding video: {e}"
class AudioUnderstandingTool(BaseCustomTool):
    name = "audio_understanding_tool"
    description = "Analyzes a local audio file given its file path and a specific prompt/question about its content. Returns a text description or answer from the audio. Use this for tasks involving audio files. You must first download the audio file using 'get_task_file_tool'. Example: audio_understanding_tool(file_path=\"/tmp/audio.mp3\", prompt=\"What are the key ingredients mentioned?\")"
    inputs = {
        "file_path": {"type": "string", "description": "The local file path of the audio file (e.g., from get_task_file_tool)."},
        "prompt": {"type": "string", "description": "A question or request regarding the audio content."},
    }
    output_type = "string"

    def __init__(self, settings: Settings, model: GoogleModelID):
        super().__init__(settings)
        # NOTE(review): assumes GoogleModelID is usable as a model-name string
        # (e.g. a str-based Enum) — confirm in models.py.
        self.model = model
        # ``google.genai`` is the client-based Gen AI SDK: it has no module-level
        # ``configure`` / ``upload_file`` / ``GenerativeModel`` (legacy
        # ``google.generativeai`` API), so all calls go through a Client.
        self.client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        logger.info(f"AudioUnderstandingTool initialized with model: {self.model}")

    def forward(self, file_path: str, prompt: str) -> str:
        """Upload the audio at ``file_path`` and ask the Gemini model ``prompt`` about it.

        The uploaded copy is deleted from the Gemini Files API afterwards
        (best-effort). Returns the model's text answer, or a message starting
        with ``"Error understanding audio:"`` on any failure. Nothing is raised.
        """
        uploaded_file = None
        try:
            # Upload the local audio file to Gemini Files API
            uploaded_file = self.client.files.upload(file=file_path)
            logger.info(f"Uploaded audio file: {uploaded_file.uri}")
            response = self.client.models.generate_content(
                model=self.model,
                contents=[uploaded_file, types.Part(text=prompt)],
            )
            text = response.text
            # ``text`` is None when the response carries no text parts (e.g. blocked).
            if not text:
                raise ValueError("model returned no text")
            return text
        except Exception as e:
            logger.error(f"Error understanding audio from file '{file_path}': {e}")
            return f"Error understanding audio: {e}"
        finally:
            # Remove the remote copy; failure here must not mask the answer.
            if uploaded_file is not None:
                try:
                    self.client.files.delete(name=uploaded_file.name)
                except Exception as cleanup_error:
                    logger.warning(f"Could not delete uploaded audio file '{uploaded_file.name}': {cleanup_error}")
class ExcelParsingTool(BaseCustomTool):
    name = "excel_parsing_tool"
    description = "Parses an Excel (.xlsx) file given its local file path. It reads the first sheet and returns its content as a CSV formatted string. Use this for tasks involving Excel data. You must first download the Excel file using 'get_task_file_tool'. Example: excel_parsing_tool(file_path=\"/tmp/sales_data.xlsx\")"
    inputs = {"file_path": {"type": "string", "description": "The local path to the Excel file (e.g., from get_task_file_tool)."}}
    output_type = "string"

    def __init__(self, settings: Settings):
        super().__init__(settings)
        logger.info("ExcelParsingTool initialized.")

    def forward(self, file_path: str) -> str:
        """Return the default (first) worksheet of ``file_path`` as CSV text.

        On any failure — including a missing file — the error is logged and a
        string beginning with ``"Error parsing Excel file:"`` is returned.
        """
        try:
            csv_text = self._first_sheet_as_csv(file_path)
        except Exception as err:
            logger.error(f"Error parsing Excel file {file_path}: {err}")
            return f"Error parsing Excel file: {err}"
        logger.info(f"Successfully parsed Excel file: {file_path}")
        return csv_text

    def _first_sheet_as_csv(self, file_path: str) -> str:
        """Read the workbook with pandas and serialise it to CSV without the index."""
        # Fail early with a clear message rather than pandas' own error text.
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Excel file not found at: {file_path}")
        frame = pd.read_excel(file_path)
        return frame.to_csv(index=False)
class ConvertChessMoveTool(BaseCustomTool):
    name = "convert_chess_move_tool"
    description = "Converts a chess move from coordinate notation (e.g., 'e2e4') to standard algebraic notation. Requires the current piece placement as plain text. Example: convert_chess_move_tool(piece_placement=\"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR\", move=\"e2e4\")"
    inputs = {
        "piece_placement": {"type": "string", "description": "The chess piece placement in plain text (e.g., a FEN board part)."},
        "move": {"type": "string", "description": "The move in coordinate notation (e.g., 'e2e4')"},
    }
    output_type = "string"

    def __init__(self, settings: Settings, model: GoogleModelID):
        super().__init__(settings)
        # NOTE(review): assumes GoogleModelID is usable as a model-name string
        # (e.g. a str-based Enum) — confirm in models.py.
        self.model = model
        # ``google.genai`` is the client-based Gen AI SDK: it has no module-level
        # ``configure`` / ``GenerativeModel`` (legacy ``google.generativeai`` API),
        # so all calls go through a Client.
        self.client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        logger.info(f"ConvertChessMoveTool initialized with model: {self.model}")

    def forward(self, piece_placement: str, move: str) -> str:
        """Ask the Gemini model to rewrite ``move`` in algebraic notation.

        ``piece_placement`` is included in the prompt as board context. Returns
        the model's answer with surrounding whitespace removed, or a message
        starting with ``"Error converting chess move:"`` on failure.
        """
        move_message = (
            f"Convert this chess move from coordinate notation to algebraic "
            f"notation: {move}. Use the following board state for context: {piece_placement}. "
            "Do not provide any additional thinking or commentary in the response, "
            "return only the algebraic notation for the move."
        )
        try:
            response = self.client.models.generate_content(
                model=self.model,
                contents=move_message,
            )
            text = response.text
            # ``text`` is None when the response carries no text parts (e.g. blocked).
            if not text:
                raise ValueError("model returned no text")
            # Models commonly append a trailing newline; return just the notation.
            return text.strip()
        except Exception as e:
            logger.error(f"Error converting chess move: {e}")
            return f"Error converting chess move: {e}"
class BestChessMoveTool(BaseCustomTool):
    name = "best_chess_move_tool"
    description = "Gets the best chess move in coordinate notation (e.g., 'e2e4') based on a FEN (Forsyth-Edwards Notation) representation of the chess position. Example: best_chess_move_tool(fen=\"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1\")"
    inputs = {
        "fen": {"type": "string", "description": "The FEN (Forsyth-Edwards Notation) representation of the chess position. Example: 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1'"},
    }
    output_type = "string"

    def forward(self, fen: str) -> str:
        """Query the configured engine endpoint for the best move in ``fen``.

        Sends ``GET {settings.chess_eval_url}?fen=<quoted fen>&depth=15`` and
        expects JSON with ``success`` and ``bestmove`` keys. Returns the move in
        coordinate notation (e.g. ``"e2e4"``), or a message starting with
        ``"Error getting best chess move:"`` on any failure.
        """
        try:
            url = f"{self.settings.chess_eval_url}?fen={urllib.parse.quote(fen)}&depth=15"  # Depth 15 for reasonable accuracy
            response = requests.get(url, timeout=15)
            response.raise_for_status()  # Raise HTTPError for bad responses
            response_json = response.json()
            if response_json.get('success') is not True or 'bestmove' not in response_json:
                raise ValueError(f"Stockfish API returned unsuccessful response or missing 'bestmove': {response_json}")
            best_move = self._parse_best_move(response_json['bestmove'])
            logger.info(f"Successfully retrieved best chess move: {best_move} for FEN: {fen}")
            return best_move
        except Exception as e:
            logger.error(f"Error getting best chess move for FEN '{fen}': {e}")
            return f"Error getting best chess move: {e}"

    def _parse_best_move(self, raw_best_move) -> str:
        """Extract the move from engine output such as ``"bestmove e2e4 ponder e7e5"``.

        A bare ``"e2e4"`` is accepted too. Raises ValueError when the value is
        not a string, is empty, or is ``"(none)"`` (no legal move).
        """
        if not isinstance(raw_best_move, str):
            raise ValueError(f"Unexpected 'bestmove' value: {raw_best_move!r}")
        tokens = raw_best_move.split()
        # Drop the UCI "bestmove" prefix if the API includes it.
        if tokens and tokens[0] == "bestmove":
            tokens = tokens[1:]
        if not tokens or tokens[0] == "(none)":
            raise ValueError(f"No best move available in engine output: {raw_best_move!r}")
        return tokens[0]
class ChessBoardFENTool(Tool):
    name = "chess_board_fen_tool"
    description = "Generates the FEN (Forsyth-Edwards Notation) representation from a local image file of a chess board and the player whose turn it is. Returns the FEN string. You must first download the image file using 'get_task_file_tool'. Example: chess_board_fen_tool(image_path=\"/tmp/board.png\", player_turn=\"b\")"
    inputs = {
        "image_path": {"type": "string", "description": "The local file path of the chess board image (e.g., from get_task_file_tool)."},
        "player_turn": {"type": "string", "description": "The player with the next turn in the match, must be 'w' (white) or 'b' (black)."}
    }
    output_type = "string"

    def _expand_fen_rank(self, rank_str):
        """Expand one FEN rank (e.g. ``"3p4"``) into a list of 8 squares.

        Digits become that many ``' '`` placeholders; other characters are kept.

        Raises:
            ValueError: if the expanded rank does not contain exactly 8 squares.
        """
        expanded_rank = []
        for char in rank_str:
            if char.isdigit():
                expanded_rank.extend([' '] * int(char))
            else:
                expanded_rank.append(char)
        if len(expanded_rank) != 8:
            raise ValueError(f"Invalid FEN rank string (length != 8): {rank_str}")
        return expanded_rank

    def _compress_fen_rank(self, rank_list):
        """Compress a list of 8 squares back into FEN rank notation (runs of ' ' -> digit).

        Raises:
            ValueError: if ``rank_list`` does not contain exactly 8 squares.
        """
        if len(rank_list) != 8:
            raise ValueError(f"Invalid rank list (length != 8): {rank_list}")
        compressed_rank = ""
        empty_count = 0
        for char in rank_list:
            if char == ' ':
                empty_count += 1
            else:
                if empty_count > 0:
                    compressed_rank += str(empty_count)
                    empty_count = 0
                compressed_rank += char
        if empty_count > 0:
            compressed_rank += str(empty_count)
        return compressed_rank

    def _invert_mirror_fen(self, fen_string: str) -> str:
        """
        Takes a FEN string, inverts the board vertically, mirrors it horizontally,
        and returns the new FEN string representing this transformed view.
        This is often needed to convert board_to_fen output to Stockfish compatible FEN.

        The five non-board fields are passed through unchanged. On malformed
        input the error is logged and a string starting with
        ``"Error processing FEN:"`` is returned instead of raising.
        """
        try:
            # split() (not split(' ')) tolerates repeated or trailing whitespace.
            parts = fen_string.split()
            if len(parts) != 6:
                raise ValueError("FEN string must have 6 space-separated fields (board, turn, castling, ep, halfmove, fullmove).")
            board_part = parts[0]
            other_parts = parts[1:]
            rank_strings = board_part.split('/')
            if len(rank_strings) != 8:
                raise ValueError("FEN board part must have 8 ranks separated by '/'.")
            original_board = [self._expand_fen_rank(r) for r in rank_strings]
            # Vertical flip + horizontal mirror == a 180-degree rotation:
            # square (r, c) moves to (7 - r, 7 - c).
            transformed_board = [row[::-1] for row in reversed(original_board)]
            new_board_part = "/".join(self._compress_fen_rank(row) for row in transformed_board)
            return " ".join([new_board_part] + other_parts)
        except Exception as e:
            logger.error(f"Error processing FEN for inversion/mirroring: {e}. Input: '{fen_string}'")
            return f"Error processing FEN: {e}"

    def _add_fen_game_state(self, board_placement: str,
                            side_to_move: str,
                            castling: str = "-",
                            en_passant: str = "-",
                            halfmove_clock: int = 0,
                            fullmove_number: int = 1) -> str:
        """
        Appends standard game state information to a FEN board placement string.

        ``side_to_move`` may be ``'w'``/``'b'`` or ``'white'``/``'black'``
        (case-insensitive, surrounding whitespace ignored). Invalid arguments
        produce a string starting with ``"Error:"`` instead of an exception.
        """
        side_to_move_lower = str(side_to_move).strip().lower()
        # Accept spelled-out colour names as well as the FEN letters.
        side_to_move_lower = {"white": "w", "black": "b"}.get(side_to_move_lower, side_to_move_lower)
        if side_to_move_lower not in ['w', 'b']:
            return f"Error: side_to_move must be 'w' or 'b', received '{side_to_move}'"
        try:
            halfmove_clock = int(halfmove_clock)
            fullmove_number = int(fullmove_number)
            if halfmove_clock < 0:
                raise ValueError("halfmove_clock cannot be negative.")
            if fullmove_number < 1:
                raise ValueError("fullmove_number must be 1 or greater.")
        except (ValueError, TypeError):
            return (f"Error: halfmove_clock ('{halfmove_clock}') and "
                    f"fullmove_number ('{fullmove_number}') must be valid integers "
                    f"(non-negative and positive respectively).")
        full_fen = (f"{board_placement} {side_to_move_lower} {castling} "
                    f"{en_passant} {halfmove_clock} {fullmove_number}")
        return full_fen

    def forward(self, image_path: str, player_turn: str) -> str:
        """Recognise the board in ``image_path`` and return a full FEN for ``player_turn``.

        Returns the FEN string, or a message starting with ``"Error"`` if
        recognition or any FEN-processing step fails. Nothing is raised.
        """
        try:
            board_placement = get_fen_from_image_path(image_path)
            # Add game state to the board placement
            board_fen_with_state = self._add_fen_game_state(board_placement, player_turn)
            # The helpers report problems by returning an "Error..." string rather
            # than raising. Stop here so the real cause is reported instead of a
            # misleading FEN-parsing error from the next step.
            if board_fen_with_state.startswith("Error"):
                logger.error(f"Error building FEN for image '{image_path}': {board_fen_with_state}")
                return board_fen_with_state
            # Inversion makes board_to_fen output Stockfish compatible
            board_fen_inverted = self._invert_mirror_fen(board_fen_with_state)
            if board_fen_inverted.startswith("Error"):
                # Already logged by _invert_mirror_fen; do not report it as a FEN.
                return board_fen_inverted
            logger.info(f"Generated FEN from image '{image_path}': {board_fen_inverted}")
            return board_fen_inverted
        except Exception as e:
            logger.error(f"Error generating FEN from image '{image_path}': {e}")
            return f"Error generating FEN from image: {e}"