| |
|
|
| use rand::Rng; |
| use shakmaty::{ |
| Chess, Color, EnPassantMode, Move, MoveList, Piece, Position, Role, Square, |
| }; |
| use shakmaty::fen::Fen; |
|
|
| use crate::types::Termination; |
| use crate::vocab; |
|
|
| |
| #[inline] |
| pub fn our_sq_to_shakmaty(sq: u8) -> Square { |
| |
| |
| |
| let file = sq % 8; |
| let rank = sq / 8; |
| Square::from_coords( |
| shakmaty::File::new(file as u32), |
| shakmaty::Rank::new(rank as u32), |
| ) |
| } |
|
|
| |
| #[inline] |
| pub fn shakmaty_sq_to_ours(sq: Square) -> u8 { |
| let file = sq.file() as u8; |
| let rank = sq.rank() as u8; |
| rank * 8 + file |
| } |
|
|
| |
| pub fn move_to_token(m: &Move) -> u16 { |
| let (src, dst) = match m { |
| Move::Normal { from, to, .. } => (*from, *to), |
| Move::EnPassant { from, to } => (*from, *to), |
| Move::Castle { king, rook } => { |
| |
| let king_sq = *king; |
| let rook_sq = *rook; |
| let dst = if rook_sq.file() > king_sq.file() { |
| |
| Square::from_coords(shakmaty::File::G, king_sq.rank()) |
| } else { |
| |
| Square::from_coords(shakmaty::File::C, king_sq.rank()) |
| }; |
| (king_sq, dst) |
| } |
| Move::Put { .. } => panic!("Put moves not supported in standard chess"), |
| }; |
|
|
| let src_idx = shakmaty_sq_to_ours(src); |
| let dst_idx = shakmaty_sq_to_ours(dst); |
|
|
| |
| if let Move::Normal { promotion: Some(role), .. } = m { |
| let promo_type = match role { |
| Role::Queen => 0, |
| Role::Rook => 1, |
| Role::Bishop => 2, |
| Role::Knight => 3, |
| _ => panic!("Invalid promotion role: {:?}", role), |
| }; |
| vocab::promo_token(src_idx, dst_idx, promo_type) |
| .expect("Promotion move should have a valid promo pair") |
| } else { |
| vocab::base_grid_token(src_idx, dst_idx) |
| } |
| } |
|
|
| |
| |
| pub fn token_to_move(pos: &Chess, token: u16) -> Option<Move> { |
| |
| vocab::decompose_token(token)?; |
| let legal = pos.legal_moves(); |
|
|
| for m in &legal { |
| if move_to_token(m) == token { |
| return Some(m.clone()); |
| } |
| } |
|
|
| None |
| } |
|
|
| |
| |
| pub fn piece_to_code(piece: Option<Piece>) -> i8 { |
| match piece { |
| None => 0, |
| Some(p) => { |
| let base = match p.role { |
| Role::Pawn => 1, |
| Role::Knight => 2, |
| Role::Bishop => 3, |
| Role::Rook => 4, |
| Role::Queen => 5, |
| Role::King => 6, |
| }; |
| if p.color == Color::White { base } else { base + 6 } |
| } |
| } |
| } |
|
|
| |
| #[derive(Clone)] |
| pub struct GameState { |
| pos: Chess, |
| move_history: Vec<u16>, |
| position_hashes: Vec<u64>, |
| halfmove_clock: u32, |
| } |
|
|
| impl GameState { |
| pub fn new() -> Self { |
| let pos = Chess::default(); |
| let hash = Self::position_hash(&pos); |
| Self { |
| pos, |
| move_history: Vec::new(), |
| position_hashes: vec![hash], |
| halfmove_clock: 0, |
| } |
| } |
|
|
| |
| |
| fn position_hash(pos: &Chess) -> u64 { |
| use std::hash::{Hash, Hasher}; |
| use std::collections::hash_map::DefaultHasher; |
| let mut hasher = DefaultHasher::new(); |
|
|
| |
| for sq in Square::ALL { |
| let piece = pos.board().piece_at(sq); |
| piece.hash(&mut hasher); |
| } |
|
|
| |
| pos.turn().hash(&mut hasher); |
|
|
| |
| pos.castles().castling_rights().hash(&mut hasher); |
|
|
| |
| |
| pos.legal_ep_square().hash(&mut hasher); |
|
|
| hasher.finish() |
| } |
|
|
| pub fn position(&self) -> &Chess { |
| &self.pos |
| } |
|
|
| pub fn turn(&self) -> Color { |
| self.pos.turn() |
| } |
|
|
| pub fn is_white_to_move(&self) -> bool { |
| self.pos.turn() == Color::White |
| } |
|
|
| pub fn ply(&self) -> usize { |
| self.move_history.len() |
| } |
|
|
| pub fn move_history(&self) -> &[u16] { |
| &self.move_history |
| } |
|
|
| pub fn halfmove_clock(&self) -> u32 { |
| self.halfmove_clock |
| } |
|
|
| |
| pub fn legal_move_tokens(&self) -> Vec<u16> { |
| let legal = self.pos.legal_moves(); |
| legal.iter().map(|m| move_to_token(m)).collect() |
| } |
|
|
| |
| pub fn legal_moves(&self) -> MoveList { |
| self.pos.legal_moves() |
| } |
|
|
| |
| pub fn make_move(&mut self, token: u16) -> Result<(), String> { |
| let m = token_to_move(&self.pos, token) |
| .ok_or_else(|| format!("Token {} is not a legal move at ply {}", token, self.ply()))?; |
|
|
| |
| let is_pawn = match &m { |
| Move::Normal { role, .. } => *role == Role::Pawn, |
| Move::EnPassant { .. } => true, |
| Move::Castle { .. } => false, |
| Move::Put { .. } => false, |
| }; |
| let is_capture = m.is_capture(); |
|
|
| if is_pawn || is_capture { |
| self.halfmove_clock = 0; |
| } else { |
| self.halfmove_clock += 1; |
| } |
|
|
| self.pos.play_unchecked(m); |
| self.move_history.push(token); |
| let hash = Self::position_hash(&self.pos); |
| self.position_hashes.push(hash); |
|
|
| Ok(()) |
| } |
|
|
| |
| pub fn check_termination(&self, max_ply: usize) -> Option<Termination> { |
| let legal = self.pos.legal_moves(); |
|
|
| |
| |
| |
| if legal.is_empty() { |
| if self.pos.is_check() { |
| return Some(Termination::Checkmate); |
| } else { |
| return Some(Termination::Stalemate); |
| } |
| } |
|
|
| if self.ply() >= max_ply { |
| return Some(Termination::PlyLimit); |
| } |
|
|
| |
| if self.halfmove_clock >= 150 { |
| return Some(Termination::SeventyFiveMoveRule); |
| } |
|
|
| |
| if self.is_fivefold_repetition() { |
| return Some(Termination::FivefoldRepetition); |
| } |
|
|
| |
| if self.pos.is_insufficient_material() { |
| return Some(Termination::InsufficientMaterial); |
| } |
|
|
| None |
| } |
|
|
| pub fn is_fivefold_repetition(&self) -> bool { |
| let current = self.position_hashes.last().unwrap(); |
| let count = self.position_hashes.iter().filter(|h| *h == current).count(); |
| count >= 5 |
| } |
|
|
| |
| |
| pub fn legal_move_grid(&self) -> [u64; 64] { |
| let mut grid = [0u64; 64]; |
| let legal = self.pos.legal_moves(); |
|
|
| for m in &legal { |
| let token = move_to_token(m); |
| if let Some((src, dst, _promo)) = vocab::decompose_token(token) { |
| grid[src as usize] |= 1u64 << dst; |
| } |
| } |
|
|
| grid |
| } |
|
|
| |
| |
| pub fn legal_promo_mask(&self) -> [[bool; 4]; 44] { |
| let mut mask = [[false; 4]; 44]; |
| let legal = self.pos.legal_moves(); |
|
|
| for m in &legal { |
| if let Move::Normal { from, to, promotion: Some(role), .. } = m { |
| let src = shakmaty_sq_to_ours(*from); |
| let dst = shakmaty_sq_to_ours(*to); |
| if let Some(pair_idx) = vocab::promo_pair_index(src, dst) { |
| let promo_type = match role { |
| Role::Queen => 0, |
| Role::Rook => 1, |
| Role::Bishop => 2, |
| Role::Knight => 3, |
| _ => continue, |
| }; |
| mask[pair_idx][promo_type] = true; |
| } |
| } |
| } |
|
|
| mask |
| } |
|
|
| |
| pub fn board_array(&self) -> [[i8; 8]; 8] { |
| let mut board = [[0i8; 8]; 8]; |
| for rank in 0..8 { |
| for file in 0..8 { |
| let sq = Square::from_coords( |
| shakmaty::File::new(file as u32), |
| shakmaty::Rank::new(rank as u32), |
| ); |
| board[rank][file] = piece_to_code(self.pos.board().piece_at(sq)); |
| } |
| } |
| board |
| } |
|
|
| |
| pub fn castling_rights_bits(&self) -> u8 { |
| let rights = self.pos.castles().castling_rights(); |
| let mut bits = 0u8; |
| if rights.contains(Square::H1) { bits |= 1; } |
| if rights.contains(Square::A1) { bits |= 2; } |
| if rights.contains(Square::H8) { bits |= 4; } |
| if rights.contains(Square::A8) { bits |= 8; } |
| bits |
| } |
|
|
| |
| pub fn ep_square(&self) -> i8 { |
| match self.pos.legal_ep_square() { |
| Some(sq) => shakmaty_sq_to_ours(sq) as i8, |
| None => -1, |
| } |
| } |
|
|
| pub fn is_check(&self) -> bool { |
| self.pos.is_check() |
| } |
|
|
| |
| |
| |
| |
| |
| pub fn legal_moves_structured(&self) -> (Vec<u16>, Vec<(u16, Vec<u8>)>) { |
| let legal = self.pos.legal_moves(); |
| let mut grid_indices: Vec<u16> = Vec::with_capacity(legal.len()); |
| let mut promo_map: Vec<(u16, Vec<u8>)> = Vec::new(); |
| let mut seen_promo_flat: u16 = u16::MAX; |
|
|
| for m in &legal { |
| let token = move_to_token(m); |
| let (src, dst, promo) = vocab::decompose_token(token).unwrap(); |
| let flat_idx = (src as u16) * 64 + (dst as u16); |
|
|
| if promo == 0 { |
| grid_indices.push(flat_idx); |
| } else { |
| let pair_idx = vocab::promo_pair_index(src, dst).unwrap(); |
| let promo_type = promo - 1; |
|
|
| if flat_idx != seen_promo_flat { |
| |
| grid_indices.push(flat_idx); |
| promo_map.push((pair_idx as u16, vec![promo_type])); |
| seen_promo_flat = flat_idx; |
| } else { |
| |
| promo_map.last_mut().unwrap().1.push(promo_type); |
| } |
| } |
| } |
|
|
| (grid_indices, promo_map) |
| } |
|
|
| |
| pub fn legal_moves_grid_mask(&self) -> [bool; 4096] { |
| let legal = self.pos.legal_moves(); |
| let mut mask = [false; 4096]; |
| for m in &legal { |
| let token = move_to_token(m); |
| let (src, dst, _promo) = vocab::decompose_token(token).unwrap(); |
| let flat_idx = (src as usize) * 64 + (dst as usize); |
| mask[flat_idx] = true; |
| } |
| mask |
| } |
|
|
| |
| |
| |
| |
| pub fn legal_moves_full(&self) -> (Vec<u16>, Vec<(u16, Vec<u8>)>, [bool; 4096]) { |
| let legal = self.pos.legal_moves(); |
| let mut grid_indices: Vec<u16> = Vec::with_capacity(legal.len()); |
| let mut promo_map: Vec<(u16, Vec<u8>)> = Vec::new(); |
| let mut seen_promo_flat: u16 = u16::MAX; |
| let mut mask = [false; 4096]; |
|
|
| for m in &legal { |
| let token = move_to_token(m); |
| let (src, dst, promo) = vocab::decompose_token(token).unwrap(); |
| let flat_idx = (src as u16) * 64 + (dst as u16); |
|
|
| mask[flat_idx as usize] = true; |
|
|
| if promo == 0 { |
| grid_indices.push(flat_idx); |
| } else { |
| let pair_idx = vocab::promo_pair_index(src, dst).unwrap(); |
| let promo_type = promo - 1; |
|
|
| if flat_idx != seen_promo_flat { |
| grid_indices.push(flat_idx); |
| promo_map.push((pair_idx as u16, vec![promo_type])); |
| seen_promo_flat = flat_idx; |
| } else { |
| promo_map.last_mut().unwrap().1.push(promo_type); |
| } |
| } |
| } |
|
|
| (grid_indices, promo_map, mask) |
| } |
|
|
| |
| pub fn make_move_uci(&mut self, token: u16) -> Result<String, String> { |
| let uci = vocab::token_to_uci(token) |
| .ok_or_else(|| format!("Token {} has no UCI representation", token))?; |
| self.make_move(token)?; |
| Ok(uci) |
| } |
|
|
| |
| |
| pub fn uci_position_string(&self) -> String { |
| if self.move_history.is_empty() { |
| return "position startpos".to_string(); |
| } |
| let mut s = String::with_capacity(24 + self.move_history.len() * 6); |
| s.push_str("position startpos moves"); |
| for &token in &self.move_history { |
| s.push(' '); |
| s.push_str(&vocab::token_to_uci(token).unwrap()); |
| } |
| s |
| } |
|
|
| |
| pub fn fen(&self) -> String { |
| let setup = self.pos.to_setup(EnPassantMode::Legal); |
| let fen = Fen::try_from(setup).expect("valid position should produce valid FEN"); |
| fen.to_string() |
| } |
|
|
| |
| |
| pub fn make_random_move(&mut self, rng: &mut impl Rng) -> Option<u16> { |
| let legal = self.pos.legal_moves(); |
| if legal.is_empty() { |
| return None; |
| } |
| let idx = rng.gen_range(0..legal.len()); |
| let m = &legal[idx]; |
| let token = move_to_token(m); |
| |
| self.make_move(token).ok(); |
| Some(token) |
| } |
|
|
| |
| |
| pub fn from_move_tokens(tokens: &[u16]) -> Result<Self, String> { |
| let mut state = Self::new(); |
| for (i, &token) in tokens.iter().enumerate() { |
| state.make_move(token).map_err(|e| format!("ply {}: {}", i, e))?; |
| } |
| Ok(state) |
| } |
|
|
| |
| |
| pub fn play_random_to_end(&mut self, rng: &mut impl Rng, max_ply: usize) -> Termination { |
| loop { |
| if let Some(term) = self.check_termination(max_ply) { |
| return term; |
| } |
| if self.make_random_move(rng).is_none() { |
| return Termination::Stalemate; |
| } |
| } |
| } |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Plays a sequence of (from, to) square-index pairs as plain grid moves.
    fn play_grid_moves(state: &mut GameState, moves: &[(u8, u8)]) {
        for &(src, dst) in moves {
            state
                .make_move(vocab::base_grid_token(src, dst))
                .unwrap_or_else(|e| panic!("move {}->{} rejected: {}", src, dst, e));
        }
    }

    #[test]
    fn test_square_conversion_roundtrip() {
        for i in 0..64u8 {
            let sq = our_sq_to_shakmaty(i);
            assert_eq!(shakmaty_sq_to_ours(sq), i, "Roundtrip failed for {}", i);
        }
    }

    #[test]
    fn test_initial_legal_moves() {
        let state = GameState::new();
        let tokens = state.legal_move_tokens();
        assert_eq!(tokens.len(), 20, "Starting position should have 20 legal moves");
    }

    #[test]
    fn test_make_move() {
        let mut state = GameState::new();
        // e2e4
        let token = vocab::base_grid_token(12, 28);
        state.make_move(token).unwrap();
        assert_eq!(state.ply(), 1);
        assert_eq!(state.turn(), Color::Black);
    }

    #[test]
    fn test_legal_move_grid() {
        let state = GameState::new();
        let grid = state.legal_move_grid();
        let total: u32 = grid.iter().map(|g| g.count_ones()).sum();
        assert_eq!(total, 20);
    }

    #[test]
    fn test_castling_token() {
        let src = shakmaty_sq_to_ours(Square::E1);
        let dst = shakmaty_sq_to_ours(Square::G1);
        assert_eq!(src, 4);
        assert_eq!(dst, 6);
        let token = vocab::base_grid_token(src, dst);
        let uci = vocab::token_to_uci(token).unwrap();
        assert_eq!(uci, "e1g1");
    }

    #[test]
    fn test_castling_move_is_encoded_as_king_step() {
        let mut state = GameState::new();
        // 1. e4 e5 2. Nf3 Nc6 3. Bc4 Nf6
        play_grid_moves(
            &mut state,
            &[(12, 28), (52, 36), (6, 21), (57, 42), (5, 26), (62, 45)],
        );
        let castle = vocab::base_grid_token(4, 6);
        assert!(state.legal_move_tokens().contains(&castle));
        state.make_move(castle).unwrap();

        let board = state.board_array();
        assert_eq!(board[0][6], 6, "white king should be on g1");
        assert_eq!(board[0][5], 4, "white rook should be on f1");
        assert_eq!(board[0][4], 0);
        assert_eq!(board[0][7], 0);
        // White lost both rights; Black keeps h8 (4) and a8 (8).
        assert_eq!(state.castling_rights_bits(), 12);
    }

    #[test]
    fn test_fivefold_repetition() {
        let mut state = GameState::new();
        let shuffle = [(6, 21), (62, 45), (21, 6), (45, 62)];
        for _ in 0..3 {
            play_grid_moves(&mut state, &shuffle);
        }
        // Start position seen 4 times.
        assert!(!state.is_fivefold_repetition());
        play_grid_moves(&mut state, &shuffle);
        assert!(state.is_fivefold_repetition());
        assert!(matches!(
            state.check_termination(1000),
            Some(Termination::FivefoldRepetition)
        ));
    }

    #[test]
    fn test_illegal_token_rejected() {
        let mut state = GameState::new();
        // e2e5 is not legal from the start position.
        let token = vocab::base_grid_token(12, 36);
        assert!(token_to_move(state.position(), token).is_none());
        assert!(state.make_move(token).is_err());
        assert_eq!(state.ply(), 0);
    }

    #[test]
    fn test_uci_and_fen() {
        let mut state = GameState::new();
        assert_eq!(state.uci_position_string(), "position startpos");
        assert_eq!(
            state.fen(),
            "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
        );
        play_grid_moves(&mut state, &[(12, 28)]);
        assert_eq!(state.uci_position_string(), "position startpos moves e2e4");
        // No black pawn can capture en passant, so no legal ep square.
        assert_eq!(state.ep_square(), -1);
    }

    #[test]
    fn test_random_game_replays_from_history() {
        use rand::rngs::StdRng;
        use rand::SeedableRng;

        let mut rng = StdRng::seed_from_u64(7);
        let mut state = GameState::new();
        let _term = state.play_random_to_end(&mut rng, 200);
        assert!(state.ply() <= 200);
        assert!(state.check_termination(200).is_some());

        let replayed = GameState::from_move_tokens(state.move_history()).unwrap();
        assert_eq!(replayed.fen(), state.fen());
        assert_eq!(replayed.halfmove_clock(), state.halfmove_clock());
    }
}
|
|