File size: 3,913 Bytes
7277ab2
 
 
 
 
 
 
 
 
b8887b3
7277ab2
 
 
b8887b3
7277ab2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8887b3
7277ab2
b8887b3
273af2d
7277ab2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8887b3
 
 
 
 
273af2d
 
b8887b3
 
 
7277ab2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8887b3
7277ab2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8887b3
 
 
 
 
7277ab2
 
b8887b3
7277ab2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
Gradio interface for plotting the input-encoding planes of a chess position.
"""

import chess
import chess.pgn
import io
import gradio as gr

from lczerolens.board import LczeroBoard, InputEncoding

from ..constants import FIGURE_DIRECTORY

def make_render(game_pgn:str, board_fen:str, input_encoding:InputEncoding, plane_index:int):
    """Build a board from the PGN or FEN input and plot one of its input planes.

    A non-blank PGN takes precedence over the FEN. If the chosen source cannot
    be parsed, a Gradio warning is shown and the starting position is used.

    Args:
        game_pgn: PGN text; the mainline of the first game in it is replayed.
        board_fen: FEN used when ``game_pgn`` is empty or whitespace only.
        input_encoding: Encoding forwarded to ``make_board_plot``.
        plane_index: Plane index forwarded to ``make_board_plot``.

    Returns:
        A tuple ``(board, board_svg_path, colorbar_svg_path)``: the board for
        the session state followed by the two paths from ``make_board_plot``.
    """
    if game_pgn and game_pgn.strip():
        try:
            game = chess.pgn.read_game(io.StringIO(game_pgn))
            if game is None:
                # read_game returns None when the text contains no game at all.
                raise ValueError("No game found in PGN.")
            # Replay from the game's own setup position (honours a [FEN]
            # header) instead of always assuming the standard start; push()
            # does not validate moves, so a wrong start corrupts the board.
            board = LczeroBoard(game.board().fen())
            for move in game.mainline_moves():
                board.push(move)
        except Exception as e:
            print(e)
            gr.Warning("Error parsing PGN, using starting position.")
            board = LczeroBoard()
        else:
            # python-chess records recoverable problems (e.g. illegal moves)
            # in game.errors rather than raising, and may cut the mainline
            # short, so tell the user the position might be partial.
            if getattr(game, "errors", None):
                print(game.errors)
                gr.Warning("PGN contains errors, the displayed position may be incomplete.")
    else:
        try:
            board = LczeroBoard(board_fen)
        except Exception as e:
            print(e)
            gr.Warning("Invalid FEN, using starting position.")
            board = LczeroBoard()
    return board, *make_board_plot(board, input_encoding, plane_index)

def make_board_plot(board:LczeroBoard, input_encoding:InputEncoding, plane_index:int):
    """Render a single input-encoding plane of ``board`` as an SVG heatmap.

    Args:
        board: Position to encode.
        input_encoding: Encoding passed to ``board.to_input_tensor``.
        plane_index: Index along the first dimension of the input tensor,
            i.e. which plane to draw.

    Returns:
        ``(board_svg_path, colorbar_svg_path)``, both under ``FIGURE_DIRECTORY``.
    """
    # Single source of truth for the output file stem. render_heatmap is given
    # "<stem>.svg"; the "<stem>_board.svg" / "<stem>_colorbar.svg" files it is
    # expected to write are the paths returned to the Gradio image components.
    stem = f"{FIGURE_DIRECTORY}/encodings"
    input_tensor = board.to_input_tensor(input_encoding=input_encoding)
    # int() guards against the slider value arriving as a float, which
    # (presumably torch) tensor indexing rejects. The selected plane is
    # flattened to the 64 squares of the board.
    plane = input_tensor[int(plane_index)].view(64)
    board.render_heatmap(
        plane,
        save_to=f"{stem}.svg",
        # Fixed colour scale, not rescaled per plane.
        vmin=0,
        vmax=1,
    )
    return f"{stem}_board.svg", f"{stem}_colorbar.svg"

with gr.Blocks() as interface:
    with gr.Row():
        # Left column: position source, encoding choice and plane selector.
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(label="Game PGN", lines=1, value="")
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                    value=chess.STARTING_FEN,
                )
                # Each choice is a (display label, InputEncoding value) pair.
                input_encoding = gr.Radio(
                    label="Input encoding",
                    choices=[
                        ("classical", InputEncoding.INPUT_CLASSICAL_112_PLANE),
                        ("repeated", InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED),
                        (
                            "no history repeated",
                            InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_REPEATED,
                        ),
                        (
                            "no history zeros",
                            InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_ZEROS,
                        ),
                    ],
                    value=InputEncoding.INPUT_CLASSICAL_112_PLANE,
                )
            with gr.Group():
                with gr.Row():
                    # Integer steps over indices 0..111 (112 planes).
                    plane_index = gr.Slider(
                        label="Plane index", minimum=0, maximum=111, step=1, value=0
                    )
        # Right column: the rendered heatmap and its colour scale.
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)
            colorbar = gr.Image(label="Colorbar", interactive=False)

    # Per-session board, so the encoding / plane controls can re-plot it
    # without re-reading the PGN or FEN text boxes.
    state_board = gr.State(value=LczeroBoard())

    render_inputs = [game_pgn, board_fen, input_encoding, plane_index]
    render_outputs = [state_board, image_board, colorbar]

    # Rebuild the board (and its plot) on page load and whenever either
    # text box is submitted.
    for _register in (interface.load, game_pgn.submit, board_fen.submit):
        _register(make_render, inputs=render_inputs, outputs=render_outputs)

    # Only re-plot the stored board when the encoding or plane changes.
    for _register in (input_encoding.change, plane_index.change):
        _register(
            make_board_plot,
            inputs=[state_board, input_encoding, plane_index],
            outputs=[image_board, colorbar],
        )