chess_test / utils.py
Br0wks's picture
Chess Challenge submission by Br0wks
b83b66b verified
"""
Utility functions for the Chess Challenge.
This module provides helper functions for:
- Parameter counting and budget analysis
- Model registration with Hugging Face
- Move validation with python-chess
"""
from __future__ import annotations
from typing import Dict, Optional, TYPE_CHECKING
import torch.nn as nn
if TYPE_CHECKING:
from src.model import ChessConfig
def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
"""
Count the number of parameters in a model.
Args:
model: The PyTorch model.
trainable_only: If True, only count trainable parameters.
Returns:
Total number of parameters.
"""
if trainable_only:
return sum(p.numel() for p in model.parameters() if p.requires_grad)
return sum(p.numel() for p in model.parameters())
def count_parameters_by_component(model: nn.Module) -> Dict[str, int]:
"""
Count parameters broken down by model component.
Args:
model: The PyTorch model.
Returns:
Dictionary mapping component names to parameter counts.
"""
counts = {}
for name, module in model.named_modules():
if len(list(module.children())) == 0: # Leaf module
param_count = sum(p.numel() for p in module.parameters(recurse=False))
if param_count > 0:
counts[name] = param_count
return counts
def estimate_parameters(config: "ChessConfig") -> Dict[str, int]:
"""
Estimate the parameter count for a given configuration.
This is useful for planning your architecture before building the model.
Args:
config: Model configuration.
Returns:
Dictionary with estimated parameter counts by component.
"""
V = config.vocab_size
d = config.n_embd
L = config.n_layer
n_ctx = config.n_ctx
n_inner = config.n_inner
estimates = {
"token_embeddings": V * d,
"position_embeddings": n_ctx * d,
"attention_qkv_per_layer": 3 * d * d,
"attention_proj_per_layer": d * d,
"ffn_per_layer": 2 * d * n_inner,
"layernorm_per_layer": 4 * d, # 2 LayerNorms, each with weight and bias
"final_layernorm": 2 * d,
}
# Calculate totals
per_layer = (
estimates["attention_qkv_per_layer"] +
estimates["attention_proj_per_layer"] +
estimates["ffn_per_layer"] +
estimates["layernorm_per_layer"]
)
estimates["total_transformer_layers"] = L * per_layer
# LM head (tied with embeddings by default)
if config.tie_weights:
estimates["lm_head"] = 0
estimates["lm_head_note"] = "Tied with token embeddings"
else:
estimates["lm_head"] = V * d
# Grand total
estimates["total"] = (
estimates["token_embeddings"] +
estimates["position_embeddings"] +
estimates["total_transformer_layers"] +
estimates["final_layernorm"] +
estimates["lm_head"]
)
return estimates
def print_parameter_budget(config: "ChessConfig", limit: int = 1_000_000) -> None:
"""
Print a formatted parameter budget analysis.
Args:
config: Model configuration.
limit: Parameter limit to compare against.
"""
estimates = estimate_parameters(config)
print("=" * 60)
print("PARAMETER BUDGET ANALYSIS")
print("=" * 60)
print(f"\nConfiguration:")
print(f" vocab_size (V) = {config.vocab_size}")
print(f" n_embd (d) = {config.n_embd}")
print(f" n_layer (L) = {config.n_layer}")
print(f" n_head = {config.n_head}")
print(f" n_ctx = {config.n_ctx}")
print(f" n_inner = {config.n_inner}")
print(f" tie_weights = {config.tie_weights}")
print(f"\nParameter Breakdown:")
print(f" Token Embeddings: {estimates['token_embeddings']:>10,}")
print(f" Position Embeddings: {estimates['position_embeddings']:>10,}")
print(f" Transformer Layers: {estimates['total_transformer_layers']:>10,}")
print(f" Final LayerNorm: {estimates['final_layernorm']:>10,}")
if config.tie_weights:
print(f" LM Head: {'(tied)':>10}")
else:
print(f" LM Head: {estimates['lm_head']:>10,}")
print(f" " + "-" * 30)
print(f" TOTAL: {estimates['total']:>10,}")
print(f"\nBudget Status:")
print(f" Limit: {limit:>10,}")
print(f" Used: {estimates['total']:>10,}")
print(f" Remaining:{limit - estimates['total']:>10,}")
if estimates['total'] <= limit:
print(f"\n Within budget! ({estimates['total'] / limit * 100:.1f}% used)")
else:
print(f"\n OVER BUDGET by {estimates['total'] - limit:,} parameters!")
print("=" * 60)
def validate_move_with_chess(move: str, board_fen: Optional[str] = None) -> bool:
"""
Validate a move using python-chess.
This function converts the dataset's extended UCI format to standard UCI
and validates it against the current board state.
Args:
move: Move in extended UCI format (e.g., "WPe2e4", "BNg8f6(x)").
board_fen: FEN string of the current board state (optional).
Returns:
True if the move is legal, False otherwise.
"""
try:
import chess
except ImportError:
raise ImportError("python-chess is required for move validation. "
"Install it with: pip install python-chess")
# Parse the extended UCI format
# Format: [W|B][Piece][from_sq][to_sq][suffix]
# Example: WPe2e4, BNg8f6(x), WKe1g1(o)
if len(move) < 6:
return False
# Extract components
color = move[0] # W or B
piece = move[1] # P, N, B, R, Q, K
from_sq = move[2:4] # e.g., "e2"
to_sq = move[4:6] # e.g., "e4"
# Check for promotion
promotion = None
if "=" in move:
promo_idx = move.index("=")
promotion = move[promo_idx + 1].lower()
# Create board
board = chess.Board(board_fen) if board_fen else chess.Board()
# Build UCI move string
uci_move = from_sq + to_sq
if promotion:
uci_move += promotion
try:
move_obj = chess.Move.from_uci(uci_move)
return move_obj in board.legal_moves
except (ValueError, chess.InvalidMoveError):
return False
def convert_extended_uci_to_uci(move: str) -> str:
"""
Convert extended UCI format to standard UCI format.
Args:
move: Move in extended UCI format (e.g., "WPe2e4").
Returns:
Move in standard UCI format (e.g., "e2e4").
"""
if len(move) < 6:
return move
# Extract squares
from_sq = move[2:4]
to_sq = move[4:6]
# Check for promotion
promotion = ""
if "=" in move:
promo_idx = move.index("=")
promotion = move[promo_idx + 1].lower()
return from_sq + to_sq + promotion
def convert_uci_to_extended(
uci_move: str,
board_fen: str,
) -> str:
"""
Convert standard UCI format to extended UCI format.
Args:
uci_move: Move in standard UCI format (e.g., "e2e4").
board_fen: FEN string of the current board state.
Returns:
Move in extended UCI format (e.g., "WPe2e4").
"""
try:
import chess
except ImportError:
raise ImportError("python-chess is required for move conversion.")
board = chess.Board(board_fen)
move = chess.Move.from_uci(uci_move)
# Get color
color = "W" if board.turn == chess.WHITE else "B"
# Get piece
piece = board.piece_at(move.from_square)
piece_letter = piece.symbol().upper() if piece else "P"
# Build extended UCI
from_sq = chess.square_name(move.from_square)
to_sq = chess.square_name(move.to_square)
result = f"{color}{piece_letter}{from_sq}{to_sq}"
# Add promotion
if move.promotion:
result += f"={chess.piece_symbol(move.promotion).upper()}"
# Add suffix for captures
if board.is_capture(move):
result += "(x)"
# Add suffix for check/checkmate
board.push(move)
if board.is_checkmate():
if "(x)" in result:
result = result.replace("(x)", "(x+*)")
else:
result += "(+*)"
elif board.is_check():
if "(x)" in result:
result = result.replace("(x)", "(x+)")
else:
result += "(+)"
board.pop()
# Handle castling notation
if board.is_castling(move):
if move.to_square in [chess.G1, chess.G8]: # Kingside
result = result.replace("(x)", "").replace("(+)", "") + "(o)"
else: # Queenside
result = result.replace("(x)", "").replace("(+)", "") + "(O)"
return result