|
import os |
|
import sys |
|
|
|
import chess.pgn |
|
import gradio as gr |
|
import torch |
|
from loguru import logger |
|
import numpy as np |
|
|
|
from src.data.data_utils import clean_board |
|
from src.engine.agents.policies import beam_search, eval_board, one_depth_eval |
|
from src.engine.agents.viz_utils import plot_save_beam_search, save_svg, board_to_svg |
|
from src.models.model_space import MultiInputConv |
|
|
|
|
|
CHKPT = "checkpoint.pt"

# Resolve paths relative to this source file. ``sys.argv[0]`` names whatever
# started the process: the ``gradio`` CLI in reload mode, another module that
# imports this one, or a bare "app.py" whose dirname is "". In each of those
# cases the checkpoint and temp paths end up wrong or cwd-relative.
file = os.path.abspath(__file__)

DIR_PATH = os.path.dirname(file)

# Scratch directory for the SVG/PNG files handed to the Gradio components.
TEMP_DIR = os.path.join(DIR_PATH, "temp")

# map_location="cpu" lets a checkpoint saved on a GPU machine load on a
# CPU-only host; nothing in this module moves the model off the CPU.
chkpt = torch.load(os.path.join(DIR_PATH, "checkpoints", CHKPT), map_location="cpu")

model = MultiInputConv()

model.load_state_dict(state_dict=chkpt["model_state_dict"])

# The app only runs inference, so put the module in evaluation mode.
model.eval()

os.makedirs(name=TEMP_DIR, exist_ok=True)
|
|
|
|
|
@logger.catch(level="DEBUG", reraise=True)
def evaluate_board(board: chess.Board):
    """Render a board to SVG and score it with the loaded model.

    Args:
        board (str | chess.Board): board description as typed in the UI
            (FEN or PGN text). It is normalised by ``clean_board`` before use.

    Returns:
        tuple: ``(svg_path, score)``, where ``svg_path`` is the ``board.svg``
        file written to ``TEMP_DIR`` and ``score`` is whatever
        ``eval_board`` returns for the board. The two values feed the
        ``["image", "text"]`` outputs of the "Score a board" tab.
    """
    board = clean_board(board=board)

    # save_svg takes a filename stem; with to_png=False the ".svg" file is the result.
    svg_stem = os.path.join(TEMP_DIR, "board")
    save_svg(board=board, filename=svg_stem, to_png=False)

    return f"{svg_stem}.svg", eval_board(model=model, board=board)
|
|
|
|
|
@logger.catch(level="DEBUG", reraise=True)
def plot_beam_search(board: chess.Board,
                     depth: int,
                     beam_width: int,
                     player_strategy: str,
                     opponent_strategy: str,
                     player_top_k: int,
                     opponent_top_k: int):
    """Run a beam search from a board and render the search tree.

    Args:
        board (str | chess.Board): board description (FEN or PGN text from
            the UI). It is normalised by ``clean_board`` before use.
        depth (int): depth of the search.
        beam_width (int): width of the beam.
        player_strategy (str): player sampling strategy ("greedy" or "top-k"
            in the UI).
        opponent_strategy (str): opponent sampling strategy ("greedy" or
            "top-k" in the UI).
        player_top_k (int): top-k value for the player.
        opponent_top_k (int): top-k value for the opponent.

    Returns:
        str: path of ``beam_search.png`` inside ``TEMP_DIR``, presumably
        produced by ``plot_save_beam_search`` from the ``beam_search`` stem.
    """
    board = clean_board(board=board)

    # gr.Slider documents its value as a float. These are counts, so coerce
    # them to int (a no-op when Gradio already delivers ints).
    beam = beam_search(model=model,
                       board=board,
                       depth=int(depth),
                       beam_width=int(beam_width),
                       player_strategy=player_strategy,
                       opponent_strategy=opponent_strategy,
                       player_top_k=int(player_top_k),
                       opponent_top_k=int(opponent_top_k),
                       # Same score bounds as get_one_depth_eval uses.
                       min_score=-100,
                       max_score=100)

    plot_save_beam_search(
        beam=beam,
        filename=os.path.join(TEMP_DIR, "beam_search"),
        temp_dir=TEMP_DIR,
        intermediate_png=True,
    )

    return os.path.join(TEMP_DIR, "beam_search.png")
|
|
|
|
|
@logger.catch(level="DEBUG", reraise=True)
def get_one_depth_eval(board: chess.Board):
    """Score every legal continuation of a board and render them for the gallery.

    Args:
        board (str | chess.Board): board description (FEN or PGN text from
            the UI, or a FEN string from ``update_run_fen``). It is
            normalised by ``clean_board`` before use.

    Returns:
        tuple: three values matching the ``[gallery, dropdown, moves]``
        outputs:

        - a ``gr.update`` whose value is a list of ``(svg_path, score_text)``
          pairs, one per legal move, in ranked order;
        - a ``gr.update`` setting the dropdown choices to the moves in the
          same order;
        - the list of move strings in that order (kept in a ``gr.State``).
    """
    board = clean_board(board=board)

    legal_boards, legal_moves, scores = one_depth_eval(
        model=model, boards=[board], min_score=-100, max_score=100
    )
    # one_depth_eval works on a batch; a single board was passed, so take entry 0.
    child_boards, child_moves, child_scores = legal_boards[0], legal_moves[0], scores[0]

    # Ascending by score. board.turn is True when White is to move; the order
    # is then reversed so the highest score comes first (presumably scores
    # are from White's point of view).
    order = np.argsort(child_scores)
    if board.turn:
        order = order[::-1]

    sorted_scores = np.array(child_scores)[order]
    sorted_boards = [child_boards[i] for i in order]
    move_names = [str(child_moves[i]) for i in order]

    # Write one SVG per continuation; file names follow the ranked order.
    gallery_items = []
    for rank, (child, score) in enumerate(zip(sorted_boards, sorted_scores)):
        svg_stem = os.path.join(TEMP_DIR, f"board_{rank}")
        save_svg(board=child, filename=svg_stem, to_png=False)
        gallery_items.append((f"{svg_stem}.svg", str(score)))

    return (gr.update(value=gallery_items),
            gr.update(choices=move_names),
            move_names)
|
|
|
@logger.catch(level="DEBUG", reraise=True)
def select_dropdown_item(moves, evt: gr.SelectData):
    """Copy the move behind a clicked gallery tile into the move dropdown.

    Args:
        moves (list): move strings in gallery order (the ``gr.State`` filled
            by ``get_one_depth_eval``).
        evt (gr.SelectData): select event injected by Gradio; ``evt.index``
            is the position of the clicked tile.

    Returns:
        A ``gr.update`` that sets the dropdown value to ``moves[evt.index]``.
    """
    clicked_move = moves[evt.index]
    return gr.update(value=clicked_move)
|
|
|
@logger.catch(level="DEBUG", reraise=True)
def update_run_fen(fen, dropdown):
    """Play the selected move, then re-run the one-depth evaluation.

    Args:
        fen (str): current board text (FEN or PGN). It is normalised by
            ``clean_board`` before use.
        dropdown (str): move chosen in the dropdown, applied with
            ``push_san``.

    Returns:
        tuple: ``(new_fen, gallery_update, dropdown_update, moves)``. This is
        the FEN after the move, followed by the three values returned by
        ``get_one_depth_eval`` for the new position.

    Raises:
        gr.Error: if no move has been selected.
    """
    if not dropdown:
        # Without this guard, clicking before choosing a move would fail
        # inside python-chess with an opaque error. Show a readable message.
        raise gr.Error("Select a move in the dropdown first.")

    board = clean_board(board=fen)
    board.push_san(dropdown)

    new_fen = board.fen()
    boards, update_dropdown, moves = get_one_depth_eval(new_fen)

    return new_fen, boards, update_dropdown, moves
|
|
|
def update_top_k_visibility(strategy):
    """Show a top-k slider only when the "top-k" sampling strategy is chosen.

    Args:
        strategy (str): selected sampling strategy ("greedy" or "top-k" in
            the UI).

    Returns:
        A ``gr.update`` whose ``visible`` flag is True exactly when
        ``strategy == "top-k"``.
    """
    show_slider = strategy == "top-k"
    return gr.update(visible=show_slider)
|
|
|
|
|
# Position pre-filled in every board textbox of the UI.
DEFAULT_FEN = "rnbqkbnr/ppp1pppp/8/3p4/2PP4/8/PP2PPPP/RNBQKBNR b KQkq - 0 2"

with gr.Blocks() as demo:
    gr.Markdown("Explore the model")

    with gr.Tab("Beam search"):
        with gr.Row():
            with gr.Column():
                # Own names for this tab's widgets, so the "One-depth eval"
                # tab below does not silently rebind them.
                beam_board = gr.Textbox(
                    value=DEFAULT_FEN,
                    label="Provide FEN or PGN board:",
                )
                depth_slider = gr.Slider(value=4, minimum=1, maximum=10, step=1, label="Choose beam depth")
                width_slider = gr.Slider(value=4, minimum=1, maximum=10, step=1, label="Choose beam width")

                player_strategy = gr.Dropdown(
                    label="Select the player sampling strategy :",
                    choices=["greedy", "top-k"],
                    value="greedy",
                    interactive=True,
                    allow_custom_value=False,
                )
                opponent_strategy = gr.Dropdown(
                    label="Select the opponent sampling strategy :",
                    choices=["greedy", "top-k"],
                    value="greedy",
                    interactive=True,
                    allow_custom_value=False,
                )

                # Hidden until the matching strategy dropdown is set to "top-k".
                player_top_k = gr.Slider(value=5,
                                         minimum=5,
                                         maximum=20,
                                         step=1,
                                         label="Choose player top-k",
                                         interactive=True,
                                         visible=False)
                opponent_top_k = gr.Slider(value=2,
                                           minimum=2,
                                           maximum=20,
                                           step=1,
                                           label="Choose opponent top-k",
                                           interactive=True,
                                           visible=False)

                beam_btn = gr.Button("Run beam search")

            with gr.Column():
                beam = gr.Image()

        player_strategy.change(fn=update_top_k_visibility, inputs=[player_strategy], outputs=[player_top_k])
        opponent_strategy.change(fn=update_top_k_visibility, inputs=[opponent_strategy], outputs=[opponent_top_k])
        # Keep the player's top-k slider strictly above the beam width:
        # its minimum and value both become width + 1.
        width_slider.change(fn=lambda width: gr.update(minimum=width + 1, value=width + 1),
                            inputs=[width_slider],
                            outputs=[player_top_k])
        beam_btn.click(fn=plot_beam_search,
                       inputs=[beam_board, depth_slider, width_slider, player_strategy, opponent_strategy,
                               player_top_k, opponent_top_k],
                       outputs=beam)

    with gr.Tab("One-depth eval"):
        # Move strings in gallery order, shared between the gallery and the dropdown.
        moves = gr.State(value=[])

        board = gr.Textbox(
            value=DEFAULT_FEN,
            label="Provide FEN or PGN board:",
        )
        btn = gr.Button("Get one-depth evaluation")

        gallery = gr.Gallery(
            label="Legal boards from one-depth eval",
            show_label=False,
            elem_id="gallery",
            columns=6,
            interactive=False,
        )

        dropdown = gr.Dropdown(
            label="Select the next move :",
            interactive=True,
        )

        btn_replace = gr.Button("Append selected move and run evaluation")

        btn.click(fn=get_one_depth_eval, inputs=[board], outputs=[gallery, dropdown, moves])
        # Clicking a gallery tile selects the corresponding move in the dropdown.
        gallery.select(fn=select_dropdown_item, inputs=[moves], outputs=dropdown)
        btn_replace.click(fn=update_run_fen, inputs=[board, dropdown], outputs=[board, gallery, dropdown, moves])

    with gr.Tab("Score a board"):
        gr.Interface(
            fn=evaluate_board,
            inputs=[
                gr.Textbox(
                    value=DEFAULT_FEN,
                    label="Provide FEN or PGN board:",
                ),
            ],
            outputs=["image", "text"],
            allow_flagging="never",
        )


if __name__ == "__main__":
    demo.launch()
|
|