import io
import traceback
from collections.abc import Mapping
from typing import List

import chess
import chess.pgn
import chess.svg
import gradio as gr
import numpy as np
import tokenizers
import torch
from tokenizers import models, pre_tokenizers, processors
from torch import Tensor as TT
from transformers import (
    AutoModelForCausalLM,
    GPT2LMHeadModel,
    PreTrainedTokenizerFast,
)

checkpoint_name = "austindavis/chess-gpt2-uci-8x8x512"


class UciTokenizer(PreTrainedTokenizerFast):
    """Word-level fast tokenizer for space-separated UCI move strings.

    Subclasses provide the vocabulary (``stoi`` / ``itos``), the pre-tokenizer
    that splits raw text into vocabulary items (``_init_pretokenizer``) and the
    routine that stitches decoded tokens back into whole UCI moves
    (``_process_str_tokens``).
    """

    _PAD_TOKEN: str
    _UNK_TOKEN: str
    _EOS_TOKEN: str
    _BOS_TOKEN: str

    # String -> integer mapping; this is the vocabulary handed to WordLevel.
    stoi: dict[str, int]
    # Integer -> string mapping used by the custom decoder.
    itos: dict[int, str]

    def __init__(
        self,
        stoi,
        itos,
        pad_token,
        unk_token,
        bos_token,
        eos_token,
        name_or_path,
    ):
        """Build the underlying ``tokenizers.Tokenizer`` and wrap it for HF.

        Args:
            stoi: token string -> id vocabulary.
            itos: id -> token string mapping used when decoding.
            pad_token, unk_token, bos_token, eos_token: special token strings.
            name_or_path: identifier recorded on the tokenizer.
        """
        self.stoi = stoi
        self.itos = itos

        self._PAD_TOKEN = pad_token
        self._UNK_TOKEN = unk_token
        self._EOS_TOKEN = eos_token
        self._BOS_TOKEN = bos_token

        # Define the model
        tok_model = models.WordLevel(vocab=self.stoi, unk_token=self._UNK_TOKEN)

        slow_tokenizer = tokenizers.Tokenizer(tok_model)
        slow_tokenizer.pre_tokenizer = self._init_pretokenizer()

        # Post processing prepends BOS unless special tokens are explicitly
        # disabled. The id is looked up in the vocab so the template can never
        # disagree with it; 1 is kept as the historical fallback.
        bos_id = self.stoi.get(bos_token, 1)
        post_proc = processors.TemplateProcessing(
            single=f"{bos_token} $0",
            pair=None,
            special_tokens=[(bos_token, bos_id)],
        )
        slow_tokenizer.post_processor = post_proc

        super().__init__(
            tokenizer_object=slow_tokenizer,
            unk_token=self._UNK_TOKEN,
            bos_token=self._BOS_TOKEN,
            eos_token=self._EOS_TOKEN,
            pad_token=self._PAD_TOKEN,
            name_or_path=name_or_path,
        )

        # Override the decode behavior so tokens are re-assembled into moves
        # joined by single spaces. Special tokens are always dropped by
        # ``_process_str_tokens``; ``skip_special_tokens`` and
        # ``clean_up_tokenization_spaces`` are accepted for API compatibility
        # only, and ``**kwargs`` absorbs any extra options ``decode`` forwards.
        def _decode(
            token_ids: int | List[int] | TT | Mapping,
            skip_special_tokens: bool = False,
            clean_up_tokenization_spaces: bool = False,
            **kwargs,
        ) -> str:
            # BatchEncoding is a UserDict (a Mapping), not a dict.
            if isinstance(token_ids, Mapping):
                token_ids = token_ids["input_ids"]
            if isinstance(token_ids, (TT, np.ndarray)):
                # A 0-d tensor/array becomes a plain int here.
                token_ids = token_ids.tolist()
            if isinstance(token_ids, int):
                return self.itos.get(token_ids, self._UNK_TOKEN)
            tokens_str = [self.itos.get(xi, self._UNK_TOKEN) for xi in token_ids]
            return " ".join(self._process_str_tokens(tokens_str))

        self._decode = _decode

    def _init_pretokenizer(self) -> pre_tokenizers.PreTokenizer:
        """Return the pre-tokenizer that splits raw text into vocab items."""
        raise NotImplementedError

    def _process_str_tokens(self, tokens_str: list[str]) -> list[str]:
        """Merge decoded token strings into complete UCI move strings."""
        raise NotImplementedError

    def get_id2square_list(self) -> list[int]:
        """Return the token-id -> board-square lookup list."""
        raise NotImplementedError


class UciTileTokenizer(UciTokenizer):
    """Uci tokenizer converting start/end tiles and promotion types each into individual tokens"""

    # Token order fixes the ids: special tokens 0-3 (BOS "<s>" is id 1),
    # squares a1..h8 are 4-67, promotion pieces q/r/b/n are 68-71.
    _VOCAB: List[str] = (
        ["<pad>", "<s>", "</s>", "<unk>"] + chess.SQUARE_NAMES + list("qrbn")
    )

    stoi = {tok: idx for idx, tok in enumerate(_VOCAB)}
    itos = dict(enumerate(_VOCAB))

    # Maps token id -> chess square index (None for special and promotion
    # tokens). Squares follow chess.SQUARE_NAMES order, file then rank:
    # a1, b1, c1, ..., f8, g8, h8.
    id2square: List[int | None] = [None] * 4 + list(range(64)) + [None] * 4

    def get_id2square_list(self) -> List[int | None]:
        """Return the token-id -> board-square lookup list."""
        return self.id2square

    def __init__(self):
        super().__init__(
            self.stoi,
            self.itos,
            pad_token="<pad>",
            unk_token="<unk>",
            bos_token="<s>",
            eos_token="</s>",
            name_or_path="austindavis/uci_tile_tokenizer",
        )

    def _init_pretokenizer(self):
        """Split on whitespace, then after every digit (i.e. after each square)."""
        pattern = tokenizers.Regex(r"\d")
        pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Whitespace(),
                pre_tokenizers.Split(pattern=pattern, behavior="merged_with_previous"),
            ]
        )
        return pre_tokenizer

    def _process_str_tokens(self, token_str: list[str]) -> list[str]:
        """Glue square/promotion tokens back into UCI moves.

        Example: ``["e2", "e4", "a7", "a8", "q"]`` -> ``["e2e4", "a7a8q"]``.
        Special tokens are skipped. A trailing incomplete move is kept as-is.
        """
        special = set(self.all_special_tokens)
        moves: list[str] = []
        pending = ""
        for token in token_str:
            if token in special:
                continue
            if len(token) == 1:
                # A promotion piece completes the pending move; reset so the
                # same from/to squares are not emitted a second time.
                moves.append(pending + token)
                pending = ""
                continue
            if len(pending) == 4:
                moves.append(pending)
                pending = token
            else:
                pending += token
        if pending:
            moves.append(pending)
        return moves


def _movetext(node: chess.pgn.GameNode) -> str:
    """Return the PGN move text (headers stripped) of the game containing *node*."""
    return str(node.root()).split("]")[-1].strip()


def _mainline_end(node: chess.pgn.GameNode) -> chess.pgn.GameNode:
    """Follow main-line continuations from *node* to the last recorded move."""
    while node.next() is not None:
        node = node.next()
    return node


def _parse_uci(uci: str) -> chess.Move | None:
    """Parse a UCI string, returning None instead of raising on bad input."""
    try:
        return chess.Move.from_uci(uci)
    except chess.InvalidMoveError:
        return None


def _load_pgn(pgn_text: str, node: chess.pgn.GameNode, board: chess.Board):
    """Replace the current game with the main line of *pgn_text*.

    Returns (board SVG, move text) for the Gradio outputs.
    """
    pgn_game = chess.pgn.read_game(io.StringIO(pgn_text))
    if pgn_game is None:
        return chess.svg.board(board=board), f"Could not parse PGN: {pgn_text}"

    board.reset()
    root = node.root()
    root.variations.clear()
    tail = root
    last_move = None
    for move in pgn_game.mainline_moves():
        board.push(move)
        # Chain each move onto the previous one to build a single main line.
        tail = tail.add_variation(move)
        last_move = move
    return chess.svg.board(board=board, lastmove=last_move), _movetext(root)


def _rank_candidates(model, tokenizer, board, generate_kwargs):
    """Sample continuations of the game and rank the distinct first moves.

    Returns (unique candidate move strings, the same candidates ordered by the
    mean of their per-step max logit over the first two generated tokens).
    """
    prefix = " ".join(m.uci() for m in board.move_stack)
    encoding = tokenizer(text=prefix, return_tensors="pt")["input_ids"]
    encoding = encoding.to(model.device)
    output = model.generate(encoding, **generate_kwargs)

    # Keep only the generated tokens (sequences are prompt + new tokens).
    new_tokens = output.sequences[:, encoding.shape[-1]:]
    first_moves = [
        text.split(" ", 1)[0] for text in tokenizer.batch_decode(new_tokens)
    ]
    unique_moves, unique_indices = np.unique(first_moves, return_index=True)

    logits = torch.stack(output.logits)  # (new tokens, batch, vocab)
    index = torch.as_tensor(unique_indices, dtype=torch.long, device=logits.device)
    logits = logits[:, index]  # (new tokens, unique candidates, vocab)
    scores = logits.max(dim=-1).values.T[:, :2].mean(dim=-1)
    order = scores.topk(len(unique_indices)).indices.tolist()

    candidates = [str(m) for m in unique_moves]
    return candidates, [candidates[i] for i in order]


def _play_model_reply(model, tokenizer, board, tail, generate_kwargs):
    """Let the model answer on *board*, appending its move after *tail*.

    Returns (board SVG, move text) when a legal reply is found; otherwise the
    board with every parseable candidate drawn as a red arrow and the raw
    candidates joined by "|".
    """
    unique_moves, ranked = _rank_candidates(model, tokenizer, board, generate_kwargs)

    for uci in ranked:
        reply = _parse_uci(uci)
        if reply is not None and reply in board.legal_moves:
            board.push(reply)
            tail = tail.add_variation(reply)
            return chess.svg.board(board=board, lastmove=reply), _movetext(tail)

    # No candidate is legal: show them all (promotions included) as arrows.
    arrows = [
        chess.svg.Arrow(bad.from_square, bad.to_square, color="red")
        for bad in map(_parse_uci, unique_moves)
        if bad  # skips unparseable (None) and null moves
    ]
    check = board.king(board.turn) if board.is_check() else None
    return (
        chess.svg.board(board=board, arrows=arrows, check=check),
        "|".join(unique_moves),
    )


def setup_app(model: GPT2LMHeadModel):
    """
    Configures a Gradio App to use the GPT model for move generation.

    The model must be compatible with a UciTileTokenizer.
    """
    tokenizer = UciTileTokenizer()

    # Initialize the chess board and the game record shared by all calls.
    board = chess.Board()
    game: chess.pgn.GameNode = chess.pgn.Game()
    game.headers["Event"] = "Example"

    generate_kwargs = {
        "max_new_tokens": 3,
        "num_return_sequences": 10,
        # temperature and multiple return sequences only make sense when
        # sampling, so request it explicitly.
        "do_sample": True,
        "temperature": 0.5,
        "output_scores": True,
        "output_logits": True,
        "return_dict_in_generate": True,
    }

    def make_move(input: str, node=game, board=board):
        """Gradio callback: apply the user's input and let the model reply.

        ``input`` is a UCI move, the word "reset", or a PGN (starting with
        "[" or "1. "). Returns (board SVG, status / move-sequence text).
        """
        text = (input or "").strip()

        if text.lower() == "reset":
            board.reset()
            node.root().variations.clear()
            return chess.svg.board(board=board), "New game!"

        try:
            if text.startswith(("[", "1. ")):
                return _load_pgn(text, node, board)

            try:
                move = chess.Move.from_uci(text)
            except chess.InvalidMoveError:
                return chess.svg.board(board=board), f"Invalid UCI format: {text}"

            if move not in board.legal_moves:
                return (
                    chess.svg.board(board=board, lastmove=move),
                    f"Illegal move: {text}",
                )

            board.push(move)
            tail = _mainline_end(node).add_variation(move)

            if board.is_game_over():
                # Nothing left for the model to play.
                return (
                    chess.svg.board(board=board, lastmove=move),
                    f"{_movetext(tail)}\nGame over: {board.result()}",
                )

            return _play_model_reply(model, tokenizer, board, tail, generate_kwargs)
        except Exception:
            # Top-level UI boundary: show the traceback instead of crashing.
            return chess.svg.board(board=board), traceback.format_exc()

    input_box = gr.Textbox(None, placeholder="Enter your move in UCI format")

    # Define the Gradio interface
    iface = gr.Interface(
        fn=make_move,
        inputs=input_box,
        outputs=["html", "text"],
        examples=[["e2e4"], ["d2d4"], ["Reset"]],
        title="Play Versus ChessGPT",
        description=(
            "Enter moves in UCI notation (e.g., e2e4 for pawn from e2 "
            "to e4). Enter 'reset' to restart the game."
        ),
        allow_flagging="never",
        submit_btn="Move",
        stop_btn="Stop",
        clear_btn="Clear w/o reset",
    )

    iface.output_components[0].label = "Board"
    iface.output_components[0].show_label = True
    iface.output_components[1].label = "Move Sequence"

    return iface


model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(checkpoint_name)
model.requires_grad_(False)

iface = setup_app(model)
iface.launch(share=True)