""" Gradio interface for plotting attention. """ import chess import gradio as gr import torch import uuid import re from . import constants, state, visualisation def compute_cache( game_pgn, board_fen, attention_layer, attention_head, comp_index, state_cache, state_board_index, ): if game_pgn == "" and board_fen != "": board = chess.Board(board_fen) fen_list = [board.fen()] else: board = chess.Board() fen_list = [board.fen()] for move in game_pgn.split(): if move.endswith("."): continue try: board.push_san(move) fen_list.append(board.fen()) except ValueError: gr.Warning(f"Invalid move {move}, stopping before it.") break state_cache = [(fen, state.model_cache(fen)) for fen in fen_list] return ( *make_plot( attention_layer, attention_head, comp_index, state_cache, state_board_index ), state_cache, ) def make_plot( attention_layer, attention_head, comp_index, state_cache, state_board_index, ): if state_cache is None: gr.Warning("Cache not computed!") return None, None, None, None, None fen, (out, cache) = state_cache[state_board_index] attn_list = [a[0, attention_head - 1] for a in cache[attention_layer - 1]] prompt_attn, *comp_attn = attn_list comp_attn.insert(0, prompt_attn[-1:]) comp_attn = [a.squeeze(0) for a in comp_attn] if len(comp_attn) != 5: raise NotImplementedError("This is not implemented yet.") config_total = meta_total = dump_total = 0 config_done = False heatmap = torch.zeros(64) h_index = 0 for i, t_o in enumerate(out[0]): try: t_attn = comp_attn[comp_index - 1][i] if (i < 3) or (i > len(out[0]) - 10): dump_total += t_attn continue t_str = state.model.tokenizer.decode(t_o) if t_str.startswith(" ") and h_index > 0: config_done = True if not config_done: if t_str == "/": dump_total += t_attn continue t_str = re.sub(r"\d", lambda m: "0" * int(m.group(0)), t_str) config_total += t_attn t_str_len = len(t_str.strip()) pre_t_attn = t_attn / t_str_len for j in range(t_str_len): heatmap[h_index + j] = pre_t_attn h_index += t_str_len else: meta_total += t_attn except IndexError: break raw_attention = comp_attn[comp_index - 1] highlited_tokens = [ (state.model.tokenizer.decode(out[0][i]), raw_attention[i]) for i in range(len(raw_attention)) ] uci_move = state.model.tokenizer.decode(out[0][-5:-1]).strip() board = chess.Board(fen) heatmap = heatmap.view(8, 8).flip(0).view(64) move = chess.Move.from_uci(uci_move) svg_board, fig = visualisation.render_heatmap( board, heatmap, arrows=[(move.from_square, move.to_square)] ) info = ( f"[Completion] Complete: '{state.model.tokenizer.decode(out[0][-5:])}'" f" Chosen: '{state.model.tokenizer.decode(out[0][-5:][comp_index-1])}'" f"\n[Distribution] Config: {config_total:.2f} Meta: {meta_total:.2f} Dump: {dump_total:.2f}" ) id = str(uuid.uuid4()) with open(f"{constants.FIGURE_DIRECTORY}/board_{id}.svg", "w") as f: f.write(svg_board) return ( board.fen(), info, fig, f"{constants.FIGURE_DIRECTORY}/board_{id}.svg", highlited_tokens, ) def previous_board( attention_layer, attention_head, comp_index, state_cache, state_board_index, ): state_board_index -= 1 if state_board_index < 0: gr.Warning("Already at first board.") state_board_index = 0 return ( *make_plot( attention_layer, attention_head, comp_index, state_cache, state_board_index ), state_board_index, ) def next_board( attention_layer, attention_head, comp_index, state_cache, state_board_index, ): state_board_index += 1 if state_board_index >= len(state_cache): gr.Warning("Already at last board.") state_board_index = len(state_cache) - 1 return ( *make_plot( attention_layer, attention_head, 
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or the board FEN that you want to"
                    " analyse (a non-empty PGN overrides the FEN)."
                )
                game_pgn = gr.Textbox(
                    label="Game PGN",
                    lines=1,
                )
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                )
                compute_cache_button = gr.Button("Compute cache")
            with gr.Group():
                with gr.Row():
                    attention_layer = gr.Slider(
                        label="Attention layer",
                        minimum=1,
                        maximum=12,
                        step=1,
                        value=1,
                    )
                    attention_head = gr.Slider(
                        label="Attention head",
                        minimum=1,
                        maximum=12,
                        step=1,
                        value=1,
                    )
                comp_index = gr.Slider(
                    label="Completion index",
                    minimum=1,
                    maximum=5,  # make_plot assumes a 5-token completion
                    step=1,
                    value=1,
                )
                with gr.Row():
                    previous_board_button = gr.Button("Previous board")
                    next_board_button = gr.Button("Next board")
            current_board_fen = gr.Textbox(
                label="Board FEN",
                lines=1,
                max_lines=1,
            )
            info = gr.Textbox(
                label="Info",
                lines=1,
                info=(
                    "'Config' refers to the board configuration tokens."
                    "\n'Meta' refers to the additional board tokens (like color or castling)."
                    "\n'Dump' refers to the rest of the tokens (including '/')."
                ),
            )
            gr.Markdown(
                "Note that only the 'Config' attention is plotted."
                "\n\nSee below for the raw attention."
            )
            raw_attention_html = gr.HighlightedText(
                label="Raw attention",
            )
        with gr.Column():
            image_board = gr.Image(label="Board")
            colorbar = gr.Plot(label="Colorbar")

    static_inputs = [
        attention_layer,
        attention_head,
        comp_index,
    ]
    static_outputs = [
        current_board_fen,
        info,
        colorbar,
        image_board,
        raw_attention_html,
    ]
    state_cache = gr.State(value=None)
    state_board_index = gr.State(value=0)

    compute_cache_button.click(
        compute_cache,
        inputs=[game_pgn, board_fen, *static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_cache],
    )
    previous_board_button.click(
        previous_board,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_board_index],
    )
    next_board_button.click(
        next_board,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_board_index],
    )
    attention_layer.change(
        make_plot,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs],
    )
    attention_head.change(
        make_plot,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs],
    )
    comp_index.change(
        make_plot,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs],
    )