Spaces:
Running
Running
""" | |
Gradio interface for plotting the input encoding planes of a board.
""" | |
import chess | |
import chess.pgn | |
import io | |
import gradio as gr | |
from lczerolens.board import LczeroBoard | |
from ..constants import FIGURE_DIRECTORY | |
def _board_from_pgn(game_pgn: str) -> LczeroBoard:
    """Replay the mainline of ``game_pgn``; fall back to the start position on error."""
    try:
        game = chess.pgn.read_game(io.StringIO(game_pgn))
        if game is None:
            raise ValueError("No game found in PGN.")
        # read_game is lenient: illegal or unparsable moves are collected in
        # game.errors instead of being raised, which would otherwise yield a
        # silently truncated game.
        if game.errors:
            raise ValueError(f"PGN parse errors: {game.errors}")
        # Start from the game's own initial position so that PGNs carrying a
        # [SetUp]/[FEN] header are replayed from the right position.
        board = LczeroBoard(game.board().fen())
        for move in game.mainline_moves():
            board.push(move)
    except Exception as e:
        print(e)
        gr.Warning("Error parsing PGN, using starting position.")
        board = LczeroBoard()
    return board


def _board_from_fen(board_fen: str) -> LczeroBoard:
    """Build a board from ``board_fen``; fall back to the start position if it is invalid."""
    try:
        return LczeroBoard(board_fen)
    except Exception as e:
        print(e)
        gr.Warning("Invalid FEN, using starting position.")
        return LczeroBoard()


def make_render(game_pgn: str, board_fen: str, plane_index: int):
    """Build the board from the PGN (preferred) or FEN input and render one plane.

    Args:
        game_pgn: PGN text of a game. When it contains any non-whitespace text
            it overrides ``board_fen``, and the board is set to the final
            position of the game's mainline.
        board_fen: FEN string used when ``game_pgn`` is blank.
        plane_index: Index of the input-tensor plane to render.

    Returns:
        ``(board, board_svg_path, colorbar_svg_path)``. The board is returned
        first so the UI can store it and re-render other planes later.

    Invalid input shows a Gradio warning and falls back to the starting
    position instead of raising.
    """
    # Whitespace-only PGN counts as "not provided" so the FEN is used instead.
    if game_pgn and game_pgn.strip():
        board = _board_from_pgn(game_pgn)
    else:
        board = _board_from_fen(board_fen)
    return board, *make_board_plot(board, plane_index)
def make_board_plot(board: LczeroBoard, plane_index: int):
    """Render one plane of ``board``'s input encoding as a heatmap.

    Args:
        board: Position whose ``to_input_tensor()`` encoding is plotted.
        plane_index: Index of the plane to render. Gradio sliders are
            documented to pass their value as a float, so it is cast to
            ``int`` before indexing.

    Returns:
        ``(board_svg_path, colorbar_svg_path)``: the two SVG files derived
        from the ``encodings`` stem passed to ``render_heatmap``.
    """
    input_tensor = board.to_input_tensor()
    # Tensor indexing rejects float indices; the slider may deliver e.g. 3.0.
    plane = input_tensor[int(plane_index)].view(64)
    # NOTE(review): this output path is fixed and shared by every session, so
    # concurrent users may overwrite each other's figures. Confirm whether
    # per-request file names are needed.
    stem = f"{FIGURE_DIRECTORY}/encodings"
    board.render_heatmap(
        plane,
        save_to=f"{stem}.svg",
        vmin=0,
        vmax=1,
    )
    # Presumably render_heatmap writes "<stem>_board.svg" and
    # "<stem>_colorbar.svg" for save_to="<stem>.svg". Verify against lczerolens.
    return f"{stem}_board.svg", f"{stem}_colorbar.svg"
with gr.Blocks() as interface:
    with gr.Row():
        # Left column: position inputs and the plane selector.
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(label="Game PGN", lines=1, value="")
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                    value=chess.STARTING_FEN,
                )
            with gr.Group():
                with gr.Row():
                    plane_index = gr.Slider(
                        label="Plane index",
                        minimum=0,
                        maximum=111,
                        step=1,
                        value=0,
                    )
        # Right column: the rendered heatmap and its colorbar.
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)
            colorbar = gr.Image(label="Colorbar", interactive=False)

    # Holds the last parsed board so slider moves can re-render without
    # re-parsing the PGN/FEN input.
    state_board = gr.State(value=LczeroBoard())

    render_inputs = [game_pgn, board_fen, plane_index]
    render_outputs = [state_board, image_board, colorbar]

    # Rebuild the board and re-render on page load and whenever either text
    # box is submitted; registration order matches load -> PGN -> FEN.
    for register_event in (interface.load, game_pgn.submit, board_fen.submit):
        register_event(
            make_render,
            inputs=render_inputs,
            outputs=render_outputs,
        )

    # Only the plane changed: re-plot the stored board.
    plane_index.change(
        make_board_plot,
        inputs=[state_board, plane_index],
        outputs=[image_board, colorbar],
    )