Xmaster6y's picture
new working demo
3333fb8
raw
history blame
3 kB
"""
Gradio interface for plotting the board's input-encoding planes.
"""
import chess
import chess.pgn
import io
import gradio as gr
from lczerolens.board import LczeroBoard
from ..constants import FIGURE_DIRECTORY
def make_render(game_pgn: str, board_fen: str, plane_index: int):
    """Build the board to analyse and render one of its input planes.

    A non-blank PGN takes precedence over the FEN. If the selected input
    cannot be parsed, a Gradio warning is shown and the standard starting
    position is used instead.

    Args:
        game_pgn: PGN text of a game; its mainline is replayed move by move.
        board_fen: FEN string, used only when ``game_pgn`` is blank.
        plane_index: Index of the input-encoding plane to plot.

    Returns:
        Tuple ``(board, board_svg_path, colorbar_svg_path)``.
    """
    # Whitespace-only PGN text is treated as "no PGN" so the FEN is used.
    if game_pgn and game_pgn.strip():
        try:
            board = _board_from_pgn(game_pgn)
        except Exception as e:  # UI boundary: report and fall back.
            print(e)
            gr.Warning("Error parsing PGN, using starting position.")
            board = LczeroBoard()
    else:
        try:
            board = LczeroBoard(board_fen)
        except Exception as e:  # UI boundary: report and fall back.
            print(e)
            gr.Warning("Invalid FEN, using starting position.")
            board = LczeroBoard()
    return board, *make_board_plot(board, plane_index)


def _board_from_pgn(game_pgn: str) -> LczeroBoard:
    """Replay the mainline of ``game_pgn`` onto a new LczeroBoard.

    Raises:
        ValueError: if the text contains no game or the parser reported errors.
    """
    game = chess.pgn.read_game(io.StringIO(game_pgn))
    # read_game returns None when no game can be found in the text.
    if game is None:
        raise ValueError("no game found in PGN")
    # read_game does not raise on illegal/unparsable moves; it records them in
    # game.errors and returns a truncated game that would otherwise be shown
    # without any warning.
    if game.errors:
        raise ValueError(f"PGN contains errors: {game.errors}")
    # Start from the game's own initial position (honours a FEN/SetUp header)
    # so the replayed moves match the position they were parsed against;
    # board.push() does not check legality.
    board = LczeroBoard(game.board().fen())
    for move in game.mainline_moves():
        board.push(move)
    return board
def make_board_plot(board: LczeroBoard, plane_index: int):
    """Render one input-encoding plane of ``board`` as a heatmap.

    Args:
        board: Position whose ``to_input_tensor()`` planes are plotted.
        plane_index: Index of the plane to plot (the UI slider spans 0-111).

    Returns:
        Tuple ``(board_svg_path, colorbar_svg_path)``.
    """
    # Gradio documents Slider values as floats; tensor indexing needs an int.
    plane_index = int(plane_index)
    # NOTE(review): every session writes to these same files; concurrent
    # users could overwrite each other's figures.
    stem = f"{FIGURE_DIRECTORY}/encodings"
    input_tensor = board.to_input_tensor()
    board.render_heatmap(
        input_tensor[plane_index].view(64),
        save_to=f"{stem}.svg",
        vmin=0,
        vmax=1,
    )
    # NOTE(review): assumes render_heatmap derives "<stem>_board.svg" and
    # "<stem>_colorbar.svg" from save_to — confirm against lczerolens.
    return f"{stem}_board.svg", f"{stem}_colorbar.svg"
with gr.Blocks() as interface:
    with gr.Row():
        # Left column: position input and plane selection.
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(label="Game PGN", lines=1, value="")
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                    value=chess.STARTING_FEN,
                )
            with gr.Group():
                with gr.Row():
                    plane_index = gr.Slider(
                        label="Plane index",
                        minimum=0,
                        maximum=111,
                        step=1,
                        value=0,
                    )
        # Right column: the rendered plane and its colorbar.
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)
            colorbar = gr.Image(label="Colorbar", interactive=False)

    # Session state holding the last board built by make_render, so moving
    # the slider only re-plots instead of re-parsing the position.
    state_board = gr.State(value=LczeroBoard())

    render_inputs = [game_pgn, board_fen, plane_index]
    render_outputs = [state_board, image_board, colorbar]

    # Rebuild and re-render on page load and whenever either position field
    # is submitted (registered in the same order as before).
    for register in (interface.load, game_pgn.submit, board_fen.submit):
        register(make_render, inputs=render_inputs, outputs=render_outputs)

    plane_index.change(
        make_board_plot,
        inputs=[state_board, plane_index],
        outputs=[image_board, colorbar],
    )