File size: 8,622 Bytes
09865df
868584e
09865df
 
 
 
4016c5c
868584e
09865df
d8e6f68
868584e
 
ba1e298
09865df
46f6011
e3868b1
09865df
d8e6f68
 
 
 
c97d40f
868584e
09865df
 
 
 
 
 
0254047
09865df
 
 
 
 
 
 
 
 
 
 
868584e
09865df
99333cf
09865df
 
0254047
868584e
 
 
 
 
 
 
09865df
 
 
 
 
 
868584e
 
 
 
09865df
 
 
 
 
 
 
868584e
 
 
 
 
 
 
 
 
 
09865df
 
d8e6f68
 
ad46e7d
09865df
 
0f2cd73
09865df
 
868584e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09865df
 
 
868584e
 
 
 
 
09865df
 
868584e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09865df
46f6011
868584e
 
 
 
 
 
 
 
 
 
 
09865df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import os
import sys

import chess.pgn
import gradio as gr
import torch
from loguru import logger
import numpy as np

from src.data.data_utils import clean_board
from src.engine.agents.policies import beam_search, eval_board, one_depth_eval
from src.engine.agents.viz_utils import plot_save_beam_search, save_svg, board_to_svg
from src.models.model_space import MultiInputConv

# TEMP_DIR = "./demos/temp/"
CHKPT = "checkpoint.pt"

# All paths are resolved relative to the directory of the launched script
# (sys.argv[0]): the checkpoint is read from "<dir>/checkpoints/" and the
# rendered images are written to "<dir>/temp/".
file = sys.argv[0]
DIR_PATH = os.path.dirname(file)
TEMP_DIR = os.path.join(DIR_PATH, "temp")

# map_location="cpu" lets a checkpoint that holds CUDA tensors deserialize on
# a CPU-only host. load_state_dict copies the values into the parameters the
# model already owns, so the model's device is not affected by this choice.
chkpt = torch.load(
    os.path.join(DIR_PATH, f"checkpoints/{CHKPT}"),
    map_location="cpu",
)
model = MultiInputConv()
model.load_state_dict(state_dict=chkpt["model_state_dict"])
model.eval()  # the demo only runs inference

# Every callback below writes its images into TEMP_DIR, so it must exist
# before the UI is served.
os.makedirs(name=TEMP_DIR, exist_ok=True)


@logger.catch(level="DEBUG", reraise=True)
def evaluate_board(board: chess.Board):
    """Render a position as SVG and score it with the loaded model.

    The input is first passed through ``clean_board``; the cleaned board is
    then rendered with ``save_svg`` and scored with ``eval_board`` using the
    module-level ``model``.

    Args:
        board (chess.Board): position to evaluate, in any form that
            ``clean_board`` accepts

    Returns:
        tuple: ``(svg_path, score)`` where ``svg_path`` is the path of
            ``board.svg`` inside ``TEMP_DIR`` and ``score`` is the value
            returned by ``eval_board``

    """
    cleaned = clean_board(board=board)

    # save_svg is handed the extension-less stem; presumably it appends
    # ".svg" itself, which is why the returned path carries the extension.
    svg_stem = os.path.join(TEMP_DIR, "board")
    save_svg(board=cleaned, filename=svg_stem, to_png=False)

    score = eval_board(model=model, board=cleaned)
    svg_path = os.path.join(TEMP_DIR, "board.svg")
    return svg_path, score


@logger.catch(level="DEBUG", reraise=True)
def plot_beam_search(board: chess.Board,
                     depth: int,
                     beam_width: int,
                     player_strategy: str,
                     opponent_strategy: str,
                     player_top_k: int,
                     opponent_top_k: int):
    """Run a beam search from a position and render the resulting tree.

    The board is normalised with ``clean_board``, searched with
    ``beam_search`` (scores bounded to [-100, 100]) and the result is drawn
    by ``plot_save_beam_search`` into ``TEMP_DIR``.

    Args:
        board (chess.Board): starting position, in any form that
            ``clean_board`` accepts
        depth (int): depth of the search
        beam_width (int): width of the beam
        player_strategy (str): sampling strategy for the player
        opponent_strategy (str): sampling strategy for the opponent
        player_top_k (int): top-k value for the player
        opponent_top_k (int): top-k value for the opponent

    Returns:
        str: path of ``beam_search.png`` inside ``TEMP_DIR``

    """
    search_options = {
        "depth": depth,
        "beam_width": beam_width,
        "player_strategy": player_strategy,
        "opponent_strategy": opponent_strategy,
        "player_top_k": player_top_k,
        "opponent_top_k": opponent_top_k,
        "min_score": -100,
        "max_score": 100,
    }
    start_position = clean_board(board=board)
    search_tree = beam_search(model=model, board=start_position, **search_options)

    # The plotting helper receives the extension-less stem; the PNG path
    # returned below is what the Gradio image component displays.
    output_stem = os.path.join(TEMP_DIR, "beam_search")
    plot_save_beam_search(
        beam=search_tree,
        filename=output_stem,
        temp_dir=TEMP_DIR,
        intermediate_png=True,
    )

    return os.path.join(TEMP_DIR, "beam_search.png")


@logger.catch(level="DEBUG", reraise=True)
def get_one_depth_eval(board: chess.Board):
    """Score every position one move away and prepare the UI updates.

    The candidate boards returned by ``one_depth_eval`` are ordered by score
    (descending when ``board.turn`` is truthy, ascending otherwise), each one
    is rendered to ``board_<rank>.svg`` in ``TEMP_DIR``, and the matching
    moves are converted to strings.

    Args:
        board (chess.Board): position to expand, in any form that
            ``clean_board`` accepts

    Returns:
        tuple: ``(gallery_update, dropdown_update, move_labels)`` where
            ``gallery_update`` sets the gallery value to a list of
            ``(svg_path, score_text)`` pairs, ``dropdown_update`` sets the
            dropdown choices to the move strings, and ``move_labels`` is the
            same list of move strings in ranked order
    """
    position = clean_board(board=board)

    boards_per_input, moves_per_input, scores_per_input = one_depth_eval(
        model=model, boards=[position], min_score=-100, max_score=100
    )

    # A single board was submitted, so only entry 0 of each result is used.
    order = np.argsort(scores_per_input[0])
    if position.turn:
        # Truthy turn (chess.WHITE is True in python-chess): best = highest.
        order = order[::-1]

    ranked_scores = np.array(scores_per_input[0])[order]
    ranked_boards = np.array(boards_per_input[0])[order]
    ranked_moves = np.array(moves_per_input[0])[order]

    # Render each candidate and pair its image with its score as the caption.
    gallery_items = []
    for rank, (candidate, score) in enumerate(zip(ranked_boards, ranked_scores)):
        save_svg(board=candidate, filename=os.path.join(TEMP_DIR, f"board_{rank}"), to_png=False)
        gallery_items.append((os.path.join(TEMP_DIR, f"board_{rank}.svg"), str(score)))

    move_labels = [str(move) for move in ranked_moves]

    # The state output gets its own copy so it does not alias the dropdown choices.
    return (gr.update(value=gallery_items),
            gr.update(choices=move_labels),
            list(move_labels))

@logger.catch(level="DEBUG", reraise=True)
def select_dropdown_item(moves, evt: gr.SelectData):
    """Mirror a selection event into the move dropdown.

    Args:
        moves (list): list of moves, indexed by the position of the selected item
        evt (gr.SelectData): selection event; its ``index`` picks the move

    Returns:
        the ``gr.update`` result that sets the dropdown value to ``moves[evt.index]``
    """
    chosen_move = moves[evt.index]
    return gr.update(value=chosen_move)

@logger.catch(level="DEBUG", reraise=True)
def update_run_fen(fen, dropdown):
    """Play the selected move and re-run the one-depth evaluation.

    Args:
        fen (str): FEN board
        dropdown (str): selected move

    Returns:
        tuple: ``(new_fen, gallery_update, dropdown_update, move_labels)``:
            the FEN after the move, followed by the three values returned by
            ``get_one_depth_eval`` for that new position
    """
    position = clean_board(board=fen)
    # NOTE(review): the dropdown labels come from str(move) in
    # get_one_depth_eval; confirm push_san accepts that notation for every
    # move type (castling, promotions).
    position.push_san(dropdown)
    next_fen = position.fen()

    gallery_update, dropdown_update, move_labels = get_one_depth_eval(next_fen)

    return next_fen, gallery_update, dropdown_update, move_labels

def update_top_k_visibility(strategy):
    """Show a top-k slider only when the "top-k" strategy is selected.

    Args:
        strategy (str): currently selected sampling strategy

    Returns:
        the ``gr.update`` result with ``visible`` set to whether ``strategy``
        equals "top-k"
    """
    show_slider = strategy == "top-k"
    return gr.update(visible=show_slider)


# Default position pre-filled in every FEN/PGN textbox
# (black to move after 1. d4 d5 2. c4).
DEFAULT_FEN = "rnbqkbnr/ppp1pppp/8/3p4/2PP4/8/PP2PPPP/RNBQKBNR b KQkq - 0 2"

# Three-tab UI: beam-search plot, one-depth move explorer, single-board scoring.
with gr.Blocks() as demo:
    gr.Markdown("Explore the model")
    with gr.Tab("Beam search"):
        with gr.Row():

            with gr.Column():

                board = gr.Textbox(
                    value=DEFAULT_FEN,
                    label="Provide FEN or PGN board:",
                )
                depth_slider = gr.Slider(value=4, minimum=1, maximum=10, step=1, label="Choose beam depth")
                width_slider = gr.Slider(value=4, minimum=1, maximum=10, step=1, label="Choose beam width")

                player_strategy = gr.Dropdown(
                    label="Select the player sampling strategy :",
                    choices=["greedy", "top-k"],
                    value="greedy",
                    interactive=True,
                    allow_custom_value=False,
                )

                opponent_strategy = gr.Dropdown(
                    label="Select the opponent sampling strategy :",
                    choices=["greedy", "top-k"],
                    value="greedy",
                    interactive=True,
                    allow_custom_value=False,
                )

                # Both top-k sliders start hidden because the default
                # strategy is "greedy"; they are revealed by the handlers below.
                player_top_k = gr.Slider(
                    value=5,
                    minimum=5,
                    maximum=20,
                    step=1,
                    label="Choose player top-k",
                    interactive=True,
                    visible=False,
                )

                opponent_top_k = gr.Slider(
                    value=2,
                    minimum=2,
                    maximum=20,
                    step=1,
                    label="Choose opponent top-k",
                    interactive=True,
                    visible=False,
                )

                btn = gr.Button("Run beam search")

            with gr.Column():
                beam = gr.Image()

        # Reuse the shared helper instead of duplicating the same lambda twice.
        player_strategy.change(fn=update_top_k_visibility, inputs=[player_strategy], outputs=[player_top_k])
        opponent_strategy.change(fn=update_top_k_visibility, inputs=[opponent_strategy], outputs=[opponent_top_k])
        # Keep the player's top-k strictly above the beam width: both its
        # minimum and its current value are reset to beam_width + 1.
        width_slider.change(fn=lambda x: gr.update(minimum=x + 1, value=x + 1), inputs=[width_slider], outputs=[player_top_k])
        btn.click(fn=plot_beam_search,
                  inputs=[board, depth_slider, width_slider, player_strategy, opponent_strategy, player_top_k, opponent_top_k],
                  outputs=beam)


    with gr.Tab("One-depth eval"):
        # Move labels in gallery order, so a gallery click can be mapped back to a move.
        moves = gr.State(value=[])

        # NOTE: `board` and `btn` are rebound here; the first tab's handlers
        # were already wired to the previous components above.
        board = gr.Textbox(
            value=DEFAULT_FEN,
            label="Provide FEN or PGN board:",
        )
        btn = gr.Button("Get one-depth evaluation")

        gallery = gr.Gallery(
            label="Legal boards from one-depth eval",
            show_label=False,
            elem_id="gallery",
            columns=6,
            interactive=False
        )

        dropdown = gr.Dropdown(
            label="Select the next move :",
            interactive=True,
        )

        btn_replace = gr.Button("Append selected move and run evaluation")

        btn.click(fn=get_one_depth_eval, inputs=[board], outputs=[gallery, dropdown, moves])
        gallery.select(fn=select_dropdown_item, inputs=[moves], outputs=dropdown)
        btn_replace.click(fn=update_run_fen, inputs=[board, dropdown], outputs=[board, gallery, dropdown, moves])

    with gr.Tab("Score a board"):
        gr.Interface(
            fn=evaluate_board,
            inputs=[
                gr.Textbox(
                    value=DEFAULT_FEN,
                    label="Provide FEN or PGN board:",
                ),
            ],
            outputs=["image", "text"],
            allow_flagging="never",
        )

# Start the Gradio server only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    demo.launch()