File size: 14,172 Bytes
476b409 bbccf68 476b409 ebda551 476b409 ebda551 476b409 bbccf68 476b409 bbccf68 476b409 bbccf68 476b409 bbccf68 476b409 bbccf68 476b409 bbccf68 476b409 bbccf68 476b409 bbccf68 476b409 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 |
import os
import json
import logging
logger = logging.getLogger(__name__)
import re
import requests
import shutil
from typing import Any
import urllib.parse
from board_to_fen.predict import get_fen_from_image_path
from google import genai
from google.genai import types
from litellm import completion
from smolagents import Tool
from settings import Settings
class BaseCustomTool(Tool):
    """Common base for the tools in this module that need application settings.

    Subclasses read values such as API keys and endpoint URLs from
    ``self.settings`` inside their ``forward`` methods.
    """

    def __init__(self, settings: "Settings"):
        # Run the smolagents ``Tool`` initialisation before attaching our state.
        super().__init__()
        # Configuration object shared by every subclass.
        self.settings = settings
class GetTaskFileTool(BaseCustomTool):
    """Download (or locally recover) the attachment belonging to an evaluation task.

    The file is fetched from ``{settings.evaluation_api_base_url}/files/{task_id}``
    and written under ``self.directory_name``. If that fails for any reason
    (network error, HTTP error status, rate limiting, ...), a copy is taken
    from the local ``files/`` directory instead.
    """

    name = "get_task_file_tool"
    description = """If a file_name is provided, download file associated with a given task_id. Get absolute file path"""
    inputs = {
        "task_id": {"type": "string", "description": "Task ID (required)"},
        "file_name": {"type": "string", "description": "File name (required)"},
    }
    output_type = "string"

    def __init__(self, settings):
        super().__init__(settings)
        # Relative to the current working directory.
        self.directory_name = "downloads"
        self.create_dir()

    def forward(self, task_id: str, file_name: str) -> str:
        """Fetch the task's file and return its absolute local path.

        Args:
            task_id: Identifier of the task whose attachment is requested.
            file_name: Name to store the file under. Only its final path
                component is used, so the file always lands inside
                ``self.directory_name``.

        Returns:
            Absolute path of the stored file, on both the download path and
            the local-fallback path.

        Raises:
            ValueError: If ``file_name`` has no usable base name.
            OSError: If the download fails and no fallback copy exists under
                ``files/``.
        """
        # Drop directory components so a crafted name such as "../../x" cannot
        # escape the download or fallback directories.
        safe_name = os.path.basename(file_name or "")
        if not safe_name:
            raise ValueError(f"Invalid file_name: '{file_name}'")
        target_path = os.path.join(self.directory_name, safe_name)
        try:
            response = requests.get(
                f"{self.settings.evaluation_api_base_url}/files/{task_id}", timeout=15
            )
            response.raise_for_status()
            with open(target_path, "wb") as file:
                file.write(response.content)
        except Exception as e:
            # Fetch the local file instead, dealing with rate limits, etc.
            logger.warning(
                f"Download of '{safe_name}' for task {task_id} failed ({e}); "
                f"falling back to local copy."
            )
            shutil.copy2(os.path.join("files", safe_name), target_path)
        # The description promises an absolute path, so return one on both branches.
        return os.path.abspath(target_path)

    def create_dir(self):
        """Ensure ``self.directory_name`` exists, creating it if necessary."""
        if os.path.exists(self.directory_name):
            logger.debug(f"Directory '{self.directory_name}' already exists.")
            return
        # exist_ok guards against the directory appearing between check and create.
        os.makedirs(self.directory_name, exist_ok=True)
        logger.info(f"Directory '{self.directory_name}' created successfully.")
class VideoUnderstandingTool(BaseCustomTool):
    """Ask a Gemini model questions about a YouTube video."""

    name = "VideoUnderstanding"
    description = "Prompt a YouTube video with questions to understand its content."
    inputs = {
        "youtube_url": {"type": "string", "description": "The URL of the YouTube video"},
        "prompt": {"type": "string", "description": "A question or request regarding the video"},
    }
    output_type = "string"

    def __init__(self, settings, model):
        super().__init__(settings)
        # Model identifier passed to ``client.models.generate_content``.
        self.model = model

    def forward(self, youtube_url: str, prompt: str) -> str:
        """Send the video URL plus ``prompt`` to Gemini and return its answer.

        Args:
            youtube_url: URL of the YouTube video, passed as the file URI.
            prompt: Question or instruction about the video.

        Returns:
            The model's text response. On failure, an error message string
            (previously ``False``, which violated ``output_type = "string"``).
        """
        client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        try:
            video_description = client.models.generate_content(
                model=self.model,
                contents=types.Content(
                    parts=[
                        # The YouTube URL is handed to the API directly as file data.
                        types.Part(file_data=types.FileData(file_uri=youtube_url)),
                        types.Part(text=prompt),
                    ]
                ),
            )
            return video_description.text
        except Exception as e:
            # Keep the agent running: log with traceback and report the failure as text.
            logger.exception(f"Error understanding video: {e}")
            return f"Error understanding video: {e}"
class AudioUnderstandingTool(BaseCustomTool):
    """Ask a Gemini model questions about a local audio file."""

    name = "AudioUnderstanding"
    description = "Prompt a local audio file with questions to understand its content."
    inputs = {
        "file_path": {"type": "string", "description": "The local file of the audio"},
        "prompt": {"type": "string", "description": "A question or request regarding the audio"},
    }
    output_type = "string"

    def __init__(self, settings, model):
        super().__init__(settings)
        # Model identifier passed to ``client.models.generate_content``.
        self.model = model

    def forward(self, file_path: str, prompt: str) -> str:
        """Upload ``file_path`` to Gemini, ask ``prompt`` about it, return the answer.

        Args:
            file_path: Path of the local audio file to upload.
            prompt: Question or instruction about the audio.

        Returns:
            The model's text response. On failure, an error message string
            (previously ``False``, which violated ``output_type = "string"``).
        """
        client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        try:
            # NOTE(review): the uploaded file is not deleted afterwards; confirm
            # that relying on server-side expiry is acceptable.
            audio_file = client.files.upload(file=file_path)
            audio_description = client.models.generate_content(
                model=self.model,
                contents=[prompt, audio_file],
            )
            return audio_description.text
        except Exception as e:
            # Keep the agent running: log with traceback and report the failure as text.
            logger.exception(f"Error understanding audio: {e}")
            return f"Error understanding audio: {e}"
class ConvertChessMoveTool(BaseCustomTool):
    """Use an LLM (via litellm) to translate a coordinate move into algebraic notation."""

    name = "ConvertChessMove"
    description = "Convert a chess move from coordinate notation to algebraic notation."
    inputs = {
        "piece_placement": {"type": "string", "description": "The chess piece placement in plain text"},
        "move": {"type": "string", "description": "The move in coordinate notation (e.g., e2e4)"},
    }
    output_type = "string"

    def __init__(self, settings, model):
        super().__init__(settings)
        # litellm model name used for the conversion request.
        self.model = model

    def forward(self, piece_placement: str, move: str) -> str:
        """Return ``move`` rewritten in algebraic notation for the given position.

        Args:
            piece_placement: Description of the board the move is played on.
            move: Move in coordinate notation, e.g. ``e2e4``.

        Returns:
            The model's answer with surrounding whitespace removed. This is an
            empty string if the model returned no content.
        """
        move_message = (
            "Convert this chess move from coordinate notation to algebraic "
            f"notation: {move}. Use the following piece placement: {piece_placement}. "
            "Do not provide any additional thinking or commentary in the response, "
            "the algebraic notation only."
        )
        messages = [{"content": move_message, "role": "user"}]
        response = completion(
            model=self.model,
            temperature=0.0,  # deterministic answer for a mechanical conversion
            messages=messages,
            api_key=self.settings.openrouter_api_key.get_secret_value(),
        )
        content = response.choices[0].message.content
        # Models often pad the answer with newlines; callers want the bare move.
        return (content or "").strip()
class BestChessMoveTool(BaseCustomTool):
    """Query the configured chess-evaluation HTTP API for the best move in a position."""

    name = "BestChessMove"
    description = "Get best chess move in coordinate notation based on a FEN representation."
    inputs = {
        "fen": {
            "type": "string",
            # Implicit concatenation keeps source indentation out of the text the LLM sees.
            "description": (
                "The FEN (Forsyth-Edwards Notation) representation of the chess position. "
                "Example rn1q1rk1/pp2b1pp/2p2n2/3p1pB1/3P4/1QP2N2/PP1N1PPP/R4RK1 b - - 1 11"
            ),
        },
    }
    output_type = "string"

    def forward(self, fen: str) -> str:
        """Return the engine's best move for ``fen`` in coordinate notation.

        The request goes to ``settings.chess_eval_url`` with a search depth of 15.

        Args:
            fen: Full FEN string of the position.

        Returns:
            The best move (e.g. ``e2e4``). On failure, an error message string
            (previously ``None``, which violated ``output_type = "string"``).
        """
        url = f"{self.settings.chess_eval_url}?fen={urllib.parse.quote(fen)}&depth=15"
        try:
            response = requests.get(url, timeout=15)
            # Parse the body once, and only when the request succeeded.
            payload = response.json() if response.status_code == 200 else {}
            if payload.get("success") is True:
                # Presumably "bestmove" looks like "bestmove e2e4 ponder ...";
                # the move is the second whitespace-separated token.
                return payload["bestmove"].split()[1]
            message = f"Error getting chess evaluation: {response.status_code}"
        except Exception as e:
            message = f"Error getting chess evaluation: {e}"
        logger.error(message)
        return message
class ChessBoardFENTool(Tool):
    """Build a full FEN string from an image of a chess board.

    The piece placement is recognised with ``get_fen_from_image_path``. The
    game-state fields are filled with defaults (no castling, no en passant,
    clocks 0/1). The board is then rotated 180 degrees, which a comment in
    ``forward`` says makes the output Stockfish compatible.
    """

    name = "ChessBoardFEN"
    description = "Get the FEN representation from an image of a chess board and a player turn."
    inputs = {
        "image_path": {"type": "string", "description": "The local file of the chess board image"},
        "player_turn": {"type": "string",
                        "description": "The player with the next turn in the match, black or white"}
    }
    output_type = "string"

    def _expand_fen_rank(self, rank_str):
        """
        Expands a single rank string from FEN notation (e.g., 'p2b4')
        into a list of 8 characters representing the squares.
        Uses ' ' for empty squares.

        Raises:
            ValueError: If the expanded rank does not have exactly 8 squares.
        """
        expanded_rank = []
        for char in rank_str:
            if char.isdigit():
                # Add number of empty squares specified by the digit
                expanded_rank.extend([' '] * int(char))
            else:
                # Add the piece character
                expanded_rank.append(char)
        # Validate rank length
        if len(expanded_rank) != 8:
            raise ValueError(f"Invalid FEN rank string (length != 8): {rank_str}")
        return expanded_rank

    def _compress_fen_rank(self, rank_list):
        """
        Compresses a list of 8 characters (representing a rank)
        back into FEN rank notation (e.g., turns [' ', 'K', ...] into '1K6').
        Assumes ' ' represents an empty square.

        Raises:
            ValueError: If ``rank_list`` does not have exactly 8 entries.
        """
        if len(rank_list) != 8:
            raise ValueError(f"Invalid rank list (length != 8): {rank_list}")
        compressed_rank = ""
        empty_count = 0
        for char in rank_list:
            if char == ' ':
                empty_count += 1
            else:
                # If we encountered a piece after empty squares, add the count
                if empty_count > 0:
                    compressed_rank += str(empty_count)
                    empty_count = 0
                # Add the piece
                compressed_rank += char
        # If the rank ends with empty squares, add the final count
        if empty_count > 0:
            compressed_rank += str(empty_count)
        return compressed_rank

    def _invert_mirror_fen(self, fen_string):
        """
        Takes a FEN string, inverts the board vertically, mirrors it horizontally,
        and returns the new FEN string representing this transformed view.
        The other FEN fields (turn, castling, etc.) are preserved.

        Returns:
            The transformed FEN, or an "Error processing FEN: ..." message
            string if the input cannot be parsed (no exception is raised).
        """
        try:
            # 1. Split FEN into parts
            parts = fen_string.strip().split(' ')
            if len(parts) != 6:
                raise ValueError("FEN string must have 6 space-separated fields.")
            board_part = parts[0]
            other_parts = parts[1:]  # Side-to-move, castling, ep, halfmove, fullmove
            # 2. Parse the board part into an 8x8 representation
            rank_strings = board_part.split('/')
            if len(rank_strings) != 8:
                raise ValueError("FEN board part must have 8 ranks separated by '/'.")
            # original_board[0] corresponds to rank 8, original_board[7] to rank 1
            original_board = [self._expand_fen_rank(r) for r in rank_strings]
            # 3. Create a new empty 8x8 board for the transformed state
            # Using ' ' as the placeholder for empty squares
            transformed_board = [[' ' for _ in range(8)] for _ in range(8)]
            # 4. Apply the inversion (vertical flip) and mirror (horizontal flip),
            # i.e. a 180-degree rotation of the board
            for r in range(8):  # Iterate through original rows (ranks 8 down to 1)
                for c in range(8):  # Iterate through original columns (files a to h)
                    # The piece at original [r][c] moves to transformed [7-r][7-c]
                    transformed_board[7 - r][7 - c] = original_board[r][c]
            # 5. Generate the new FEN board string from the transformed board
            # Read ranks from top (index 0 = rank 8) to bottom (index 7 = rank 1)
            new_rank_strings = [self._compress_fen_rank(row) for row in transformed_board]
            new_board_part = "/".join(new_rank_strings)
            # 6. Reassemble the full FEN string
            return " ".join([new_board_part] + other_parts)
        except Exception as e:
            # Return error message if parsing or processing fails
            return f"Error processing FEN: {e}. Input: '{fen_string}'"

    def _add_fen_game_state(self, board_placement,
                            side_to_move,
                            castling="-",
                            en_passant="-",
                            halfmove_clock=0,
                            fullmove_number=1):
        """
        Appends standard game state information to a FEN board placement string.
        Args:
            board_placement (str): The board layout part of the FEN string
                (e.g., "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR").
            side_to_move (str): The active color ('w' for White, 'b' for Black).
                Case-insensitive, will be converted to lowercase.
            castling (str, optional): Castling availability string (e.g., "KQkq", "-").
                Defaults to "-".
            en_passant (str, optional): En passant target square string (e.g., "e3", "-").
                Defaults to "-".
            halfmove_clock (int, optional): The number of halfmoves since the last
                capture or pawn advance. Defaults to 0.
            fullmove_number (int, optional): The number of the full move. Starts at 1
                and increments after Black's move. Defaults to 1.
        Returns:
            str: The complete FEN string including the game state,
                or an error message string if inputs are invalid.
        """
        # Validate side_to_move
        side_to_move_lower = str(side_to_move).lower()
        if side_to_move_lower not in ['w', 'b']:
            return f"Error: side_to_move must be 'w' or 'b', received '{side_to_move}'"
        # Validate clock values (should be non-negative integers, fullmove >= 1)
        try:
            halfmove_clock = int(halfmove_clock)
            fullmove_number = int(fullmove_number)
            if halfmove_clock < 0:
                raise ValueError("halfmove_clock cannot be negative.")
            if fullmove_number < 1:
                raise ValueError("fullmove_number must be 1 or greater.")
        except (ValueError, TypeError):
            return (f"Error: halfmove_clock ('{halfmove_clock}') and "
                    f"fullmove_number ('{fullmove_number}') must be valid integers "
                    f"(non-negative and positive respectively).")
        # Assemble the full FEN string using the validated/defaulted values
        # Note: castling and en_passant strings are used directly as passed or defaulted.
        # More complex validation could be added for them if needed.
        full_fen = (f"{board_placement} {side_to_move_lower} {castling} "
                    f"{en_passant} {halfmove_clock} {fullmove_number}")
        return full_fen

    def forward(self, image_path: str, player_turn: str) -> str:
        """Return the FEN for the board in ``image_path`` with ``player_turn`` to move.

        Args:
            image_path: Local path of the chess board image.
            player_turn: Side to move. Accepts "white"/"black" (as the tool
                description advertises) or the FEN letters "w"/"b"; matching is
                case-insensitive and ignores surrounding whitespace.

        Returns:
            The rotated full FEN string, or an error message string if
            ``player_turn`` is not recognised or the FEN cannot be processed.
        """
        # The description asks for "black or white", but _add_fen_game_state only
        # accepts 'w'/'b'. Without this mapping, every caller that follows the
        # description would get an error string, which was then mangled by the
        # inversion step.
        turn_aliases = {"white": "w", "w": "w", "black": "b", "b": "b"}
        side_to_move = turn_aliases.get(str(player_turn).strip().lower())
        if side_to_move is None:
            return (f"Error: player_turn must be 'white' or 'black', "
                    f"received '{player_turn}'")
        board_placement = get_fen_from_image_path(image_path)
        board_fen = self._add_fen_game_state(board_placement, side_to_move)
        # Inversion makes board_to_fen output Stockfish compatible
        return self._invert_mirror_fen(board_fen)
|