import torch
import torch.nn as nn
from typing import List, Dict, Tuple

from huggingface_hub import PyTorchModelHubMixin

from mapping import MAX_HALFMOVES, MAX_FULLMOVES, EMPTY_SQ_IDX, PIECE_TO_IDX, SQUARE_TO_IDX, IDX_TO_UCI_MOVE


class FENTokenizer(nn.Module):
    """Convert FEN strings (and repetition counts) into a sequence of tokens."""

    def __init__(self, hidden_size, dtype):
        super().__init__()

        self.side_embed = nn.Embedding(2, hidden_size, dtype=dtype)

        self.castling_embed_k = nn.Parameter(torch.randn(1, 1, hidden_size, dtype=dtype))
        self.castling_embed_q = nn.Parameter(torch.randn(1, 1, hidden_size, dtype=dtype))
        self.castling_embed_K = nn.Parameter(torch.randn(1, 1, hidden_size, dtype=dtype))
        self.castling_embed_Q = nn.Parameter(torch.randn(1, 1, hidden_size, dtype=dtype))
        self.no_castling_embed = nn.Parameter(torch.randn(1, 1, hidden_size, dtype=dtype))

        # 12 piece types (6 per side) plus the empty-square index.
        self.piece_embed = nn.Embedding(13, hidden_size, dtype=dtype)

        self.no_en_passant_embed = nn.Parameter(torch.randn(1, 1, hidden_size, dtype=dtype))

        self.half_move_embed = nn.Embedding(MAX_HALFMOVES, hidden_size, dtype=dtype)
        self.full_move_embed = nn.Embedding(MAX_FULLMOVES, hidden_size, dtype=dtype)
        self.repetition_embed = nn.Embedding(3, hidden_size, dtype=dtype)
        self.pos_embed = nn.Embedding(64, hidden_size, dtype=dtype)

    def _parse_fen_string(self, fen_str: str) -> Dict:
        """Split a FEN string into its six space-separated fields."""
        parts = fen_str.split()
        if len(parts) != 6:
            raise ValueError(f"Invalid FEN string: {fen_str}. Expected 6 fields")
        return {
            "piece_placement": parts[0],
            "side_to_move": parts[1],
            "castling": parts[2],
            "en_passant": parts[3],
            "halfmove_clock": parts[4],
            "fullmove_number": parts[5],
        }

    def forward(self, fen_list: List[str], repetitions: torch.Tensor) -> torch.Tensor:
        """
        Args:
            fen_list: list of FEN strings
            repetitions: 1-D tensor of shape (n_fen,) with each position's repetition count

        Returns:
            Tensor of shape (n_fen, 73, hidden_size), where the 73 tokens consist of:
                64 piece tokens (FEN's first field) +
                1 side-to-move token (FEN's second field) +
                4 castling-rights tokens (FEN's third field) +
                1 en-passant target token (FEN's fourth field) +
                1 halfmove-clock token (FEN's fifth field) +
                1 fullmove-number token (FEN's sixth field) +
                1 repetition-count token (from the repetitions input)
        """
        batch_size = len(fen_list)
        assert batch_size == repetitions.shape[0]
        assert len(repetitions.size()) == 1
        batch_tokens = []
        device = self.side_embed.weight.device

        # Positional embeddings for all 64 squares, shared across the whole batch.
        square_indices = torch.arange(64, device=device)
        all_pos_embeds = self.pos_embed(square_indices)

        for fen_str in fen_list:
            parsed_fen = self._parse_fen_string(fen_str)
            tokens = []

            # Field 1: piece placement. Walk the board description rank by rank
            # (rank 8 first) and fill a 64-entry index tensor that defaults to
            # the empty-square index.
            piece_indices = torch.full((64,), EMPTY_SQ_IDX, dtype=torch.long, device=device)
            current_rank = 7
            current_file = 0
            for char in parsed_fen["piece_placement"]:
                if char == '/':
                    current_rank -= 1
                    current_file = 0
                elif char.isdigit():
                    current_file += int(char)
                elif char in PIECE_TO_IDX:
                    sq_idx = current_rank * 8 + current_file
                    if 0 <= sq_idx < 64:
                        piece_indices[sq_idx] = PIECE_TO_IDX[char]
                    else:
                        raise ValueError(f"Invalid FEN piece placement: {parsed_fen['piece_placement']}")
                    current_file += 1
                else:
                    raise ValueError(f"Invalid character in FEN piece placement: {char}")

            piece_embeds = self.piece_embed(piece_indices)
            board_tokens = piece_embeds + all_pos_embeds
            tokens.append(board_tokens)

            # Field 2: side to move (0 = white, 1 = black).
            side_idx = 0 if parsed_fen["side_to_move"] == 'w' else 1
            side_token = self.side_embed(torch.tensor(side_idx, device=device)).unsqueeze(0)
            tokens.append(side_token)

            # Field 3: castling rights, one token per right (K, Q, k, q); absent
            # rights use the shared no-castling embedding.
            castling_str = parsed_fen["castling"]
            castling_tokens = torch.cat([
                self.castling_embed_K if 'K' in castling_str else self.no_castling_embed.expand(1, 1, -1),
                self.castling_embed_Q if 'Q' in castling_str else self.no_castling_embed.expand(1, 1, -1),
                self.castling_embed_k if 'k' in castling_str else self.no_castling_embed.expand(1, 1, -1),
                self.castling_embed_q if 'q' in castling_str else self.no_castling_embed.expand(1, 1, -1)
            ], dim=1).squeeze(0)
            tokens.append(castling_tokens)

            # Field 4: en-passant target square, reusing the square positional embedding.
            en_passant_str = parsed_fen["en_passant"]
            if en_passant_str == '-':
                en_passant_token = self.no_en_passant_embed.squeeze(0)
            else:
                if en_passant_str in SQUARE_TO_IDX:
                    sq_idx = SQUARE_TO_IDX[en_passant_str]
                    en_passant_token = self.pos_embed(torch.tensor(sq_idx, device=device)).unsqueeze(0)
                else:
                    raise ValueError(f"Invalid en passant square: {en_passant_str}")
            tokens.append(en_passant_token)

            # Field 5: halfmove clock, clamped to the embedding table's range.
            try:
                half_move_int = int(parsed_fen["halfmove_clock"])
            except ValueError:
                raise ValueError(f"Invalid halfmove clock value: {parsed_fen['halfmove_clock']}")

            half_move_clamped = torch.clamp(torch.tensor(half_move_int, device=device), 0, MAX_HALFMOVES - 1)
            half_move_token = self.half_move_embed(half_move_clamped).unsqueeze(0)
            tokens.append(half_move_token)

            # Field 6: fullmove number, shifted to a zero-based index and clamped.
            try:
                full_move_int = int(parsed_fen["fullmove_number"])
            except ValueError:
                raise ValueError(f"Invalid fullmove number value: {parsed_fen['fullmove_number']}")

            full_move_clamped = torch.clamp(torch.tensor(full_move_int, device=device), 1, MAX_FULLMOVES) - 1
            full_move_token = self.full_move_embed(full_move_clamped).unsqueeze(0)
            tokens.append(full_move_token)

            # 72 tokens per position so far: 64 + 1 + 4 + 1 + 1 + 1.
            fen_embedding = torch.cat(tokens, dim=0)
            batch_tokens.append(fen_embedding)

        batch_tokens = torch.stack(batch_tokens, dim=0)

        # Repetition count: map 1 / 2 / 3+ occurrences to embedding indices 0 / 1 / 2.
        repetitions = repetitions - 1
        repetitions = torch.clamp(repetitions, 0, 2)
        repetition_tokens = self.repetition_embed(repetitions)
        repetition_tokens = repetition_tokens.unsqueeze(1)

        return torch.cat([batch_tokens, repetition_tokens], dim=1)

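
# Usage sketch for FENTokenizer (illustrative only; assumes `mapping` provides the
# constants imported above). Tokenizing the standard starting position yields one
# 73-token sequence:
#
#   tokenizer = FENTokenizer(hidden_size=256, dtype=torch.float32)
#   start_fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
#   tokens = tokenizer([start_fen], torch.tensor([1]))
#   # tokens.shape == (1, 73, 256)
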

class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network: down_proj(up_proj(x) * SiLU(gate_proj(x)))."""

    def __init__(self,
                 d_model,
                 dim_feedforward,
                 dropout: float,
                 bias_up: bool = False,
                 bias_gate: bool = False,
                 bias_down: bool = True,
                 dtype=None):
        super().__init__()
        self.up_proj = nn.Linear(d_model, dim_feedforward, bias=bias_up, dtype=dtype)
        self.gate_proj = nn.Linear(d_model, dim_feedforward, bias=bias_gate, dtype=dtype)
        self.down_proj = nn.Linear(dim_feedforward, d_model, bias=bias_down, dtype=dtype)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Gated activation, with dropout applied to the gate branch.
        x = self.up_proj(x) * self.dropout(nn.functional.silu(self.gate_proj(x)))
        return self.down_proj(x)


class TransformerEncoderLayer(nn.Module):
    """Custom transformer encoder layer with RMSNorm and SwiGLUFFN."""

    def __init__(self,
                 d_model: int,
                 nhead: int,
                 dim_feedforward: int,
                 dropout: float,
                 batch_first: bool = True,
                 norm_first: bool = False,
                 dtype=None):
        super().__init__()
        self.norm_first = norm_first

        self.norm1 = nn.RMSNorm(d_model, dtype=dtype)
        self.dropout_sa = nn.Dropout(dropout)
        self.self_attn = nn.MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            bias=False,
            batch_first=batch_first,
            dtype=dtype
        )

        self.norm2 = nn.RMSNorm(d_model, dtype=dtype)
        self.dropout_ff = nn.Dropout(dropout)
        self.mlp = SwiGLUFFN(
            d_model,
            dim_feedforward,
            dropout=dropout,
            bias_up=False,
            bias_gate=False,
            bias_down=True,
            dtype=dtype
        )

    def forward(self, x, return_attention=False):
        if self.norm_first:
            # Pre-norm: normalize before each sub-block, add the residual outside the norm.
            if return_attention:
                x_norm = self.norm1(x)
                attn_output, attn_weights = self._sa_block(x_norm, return_attention=True)
                x = x + attn_output
                x = x + self._ff_block(self.norm2(x))
                return x, attn_weights
            else:
                x = x + self._sa_block(self.norm1(x))
                x = x + self._ff_block(self.norm2(x))
                return x
        else:
            # Post-norm: apply the residual first, then normalize.
            if return_attention:
                attn_output, attn_weights = self._sa_block(x, return_attention=True)
                x = self.norm1(x + attn_output)
                x = self.norm2(x + self._ff_block(x))
                return x, attn_weights
            else:
                x = self.norm1(x + self._sa_block(x))
                x = self.norm2(x + self._ff_block(x))
                return x

    def _sa_block(self, x, return_attention=False):
        if return_attention:
            # Per-head attention weights of shape (batch, nhead, seq, seq).
            attn_output, attn_weights = self.self_attn(x, x, x, need_weights=True, average_attn_weights=False)
            return self.dropout_sa(attn_output), attn_weights
        else:
            x = self.self_attn(x, x, x)[0]
            return self.dropout_sa(x)

    def _ff_block(self, x):
        x = self.mlp(x)
        return self.dropout_ff(x)

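
# Shape sketch for TransformerEncoderLayer (illustrative): a pre-norm layer maps
# (batch, seq, d_model) to the same shape and can optionally return per-head
# attention weights.
#
#   layer = TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=1024,
#                                   dropout=0.0, norm_first=True)
#   y, attn = layer(torch.randn(2, 73, 256), return_attention=True)
#   # y.shape == (2, 73, 256); attn.shape == (2, 8, 73, 73)
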
class ChessFormerModel(nn.Module, PyTorchModelHubMixin):
    """Transformer encoder over FEN tokens with a policy head (move logits) and a value head."""

    def __init__(self,
                 num_blocks,
                 hidden_size,
                 intermediate_size,
                 num_heads,
                 dropout: float = 0.00,
                 possible_moves: int = len(IDX_TO_UCI_MOVE),
                 dtype=None):
        super().__init__()
        self.fen_tokenizer = FENTokenizer(hidden_size, dtype=dtype)

        # Learned query tokens appended to the sequence: one for the action (policy)
        # head and one for the value head.
        self.act_token = nn.Parameter(torch.randn((1, 1, hidden_size), dtype=dtype) * 0.02)
        self.val_token = nn.Parameter(torch.randn((1, 1, hidden_size), dtype=dtype) * 0.02)

        self.act_proj = nn.Linear(hidden_size, possible_moves, dtype=dtype)
        self.val_proj = nn.Linear(hidden_size, 1, dtype=dtype)

        self.blocks = nn.ModuleList(
            TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=num_heads,
                dim_feedforward=intermediate_size,
                dropout=dropout,
                batch_first=True,
                norm_first=True,
                dtype=dtype
            ) for _ in range(num_blocks)
        )
        self.dtype = dtype
        self.possible_moves = possible_moves

        self.final_norm = nn.RMSNorm(hidden_size)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize weights."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)
            elif isinstance(m, nn.LayerNorm):
                if hasattr(m, 'weight'):
                    nn.init.constant_(m.weight, 1.0)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.RMSNorm):
                if hasattr(m, 'weight'):
                    nn.init.constant_(m.weight, 1.0)

        # Re-initialize the standalone nn.Parameter embeddings, which the module-type
        # checks above do not cover.
        tokenizer_params = dict(self.fen_tokenizer.named_parameters())

        params_to_init = [
            self.act_token, self.val_token,
            tokenizer_params.get('castling_embed_k'), tokenizer_params.get('castling_embed_q'),
            tokenizer_params.get('castling_embed_K'), tokenizer_params.get('castling_embed_Q'),
            tokenizer_params.get('no_castling_embed'), tokenizer_params.get('no_en_passant_embed')
        ]

        for param in params_to_init:
            if param is not None and param.requires_grad:
                nn.init.normal_(param, std=0.02)

    def forward(self, fen: List[str], repetitions: torch.Tensor, return_attention: bool = False) -> Tuple:
        x = self.fen_tokenizer(fen, repetitions)
        bs = x.shape[0]
        # Append the learned action and value query tokens to the 73 FEN tokens.
        x = torch.cat([x, self.act_token.expand(bs, -1, -1), self.val_token.expand(bs, -1, -1)], dim=1)

        attention_maps = [] if return_attention else None

        for block in self.blocks:
            if return_attention:
                x, attn = block(x, return_attention=True)
                attention_maps.append(attn)
            else:
                x = block(x)

        x = self.final_norm(x)

        # The action token is second-to-last, the value token last.
        act = x[:, -2, :]
        val = x[:, -1, :]
        act_logits = self.act_proj(act)
        val = self.val_proj(val)

        if return_attention:
            return act_logits, val.squeeze(1), attention_maps
        else:
            return act_logits, val.squeeze(1)

def load_model(ckpt_path):
    checkpoint = torch.load(ckpt_path)
    model_config = checkpoint["model_config"]
    model = ChessFormerModel(**model_config)
    model.load_state_dict(checkpoint["model_state_dict"])
    return model

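
# Inference sketch (illustrative; assumes a checkpoint dict containing the
# "model_config" and "model_state_dict" entries that load_model expects, and that
# IDX_TO_UCI_MOVE maps policy-head indices to UCI move strings):
#
#   model = load_model("./ckpts/chessformer-sl_13.pth")
#   model.eval()
#   start_fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
#   with torch.no_grad():
#       act_logits, value = model([start_fen], torch.tensor([1]))
#   best_move = IDX_TO_UCI_MOVE[act_logits.argmax(dim=-1).item()]
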

if __name__ == "__main__":
    checkpoint = torch.load("./ckpts/chessformer-sl_13.pth", map_location=torch.device("cpu"))
    model = ChessFormerModel(**checkpoint["config"])
    model.load_state_dict(checkpoint["model_state_dict"])

    model.push_to_hub("kaupane/ChessFormer-SL")