"""
Gradio interface for browsing dictionary features.

For a selected feature index, the boards on which that feature activates most
strongly are rendered as heatmaps, each shown with its colour bar.
"""

import uuid

import chess
import gradio as gr
import torch

from lczerolens.encodings import encode_move

from src import constants, global_variables, visualisation

# Layout of the result grid. render_feature_index returns one board image and
# one colour bar per cell, so the handler and the UI below both use these.
N_ROWS = 4
N_COLS = 4
N_BOARDS = N_ROWS * N_COLS

# Number of squares on a chess board: the N_SQUARES consecutive rows of
# ``opt_features`` starting at ``idx - pixel_index`` are read as one board.
N_SQUARES = 64

# NOTE(review): offset (relative to ``current_depth``) into ``moves_opt`` of the
# move drawn as an arrow. The meaning of 6 is not visible here -- confirm
# against the dataset construction.
MOVE_OFFSET = 6


def render_feature_index(file_id, feature_index):
    """Render the top-activating boards for one dictionary feature.

    Args:
        file_id: Per-session identifier used to name the SVG files written to
            ``constants.FIGURES_FOLER``. A fresh UUID is generated when None.
        feature_index: Column of ``global_variables.f_ds["opt_features"]`` to
            inspect.

    Returns:
        A flat tuple ``(file_id, *board_paths, *colorbars)``: ``N_BOARDS`` SVG
        file paths followed by ``N_BOARDS`` figures produced by
        ``visualisation.render_heatmap``, matching the Gradio outputs.

    Raises:
        IndexError: if the rows of a top-activating board fall outside
            ``opt_features``.
    """
    if file_id is None:
        file_id = str(uuid.uuid4())

    dataset = global_variables.f_ds
    opt_features = dataset["opt_features"]
    # Fetch the column once; indexing the dataset by column name inside the
    # loop would rebuild it on every iteration.
    pixel_indices = dataset["pixel_index"]
    top_indices = opt_features[:, feature_index].topk(N_BOARDS).indices

    board_images = []
    colorbars = []
    for rank, idx in enumerate(top_indices):
        row = dataset[idx.item()]

        # The board's per-square activations are the N_SQUARES consecutive
        # rows starting at ``idx - pixel_index``. Slice them in one go; clone
        # so the result is an independent, contiguous tensor (as torch.stack
        # previously produced).
        start = int(idx) - int(pixel_indices[idx])
        if start < 0:
            raise IndexError(
                f"board for row {int(idx)} starts at negative index {start}"
            )
        features = opt_features[start:start + N_SQUARES, feature_index].clone()
        if features.numel() != N_SQUARES:
            raise IndexError(
                f"board for row {int(idx)} runs past the end of opt_features"
            )

        board = chess.Board(row["opt_fen"])
        uci_move = row["moves_opt"][row["current_depth"] + MOVE_OFFSET]
        move = chess.Move.from_uci(uci_move)

        # python-chess: ``board.turn`` is True when White is to move. With
        # Black to move the ranks are mirrored -- presumably the features are
        # stored from the side-to-move's perspective; confirm.
        if board.turn:
            heatmap = features
        else:
            heatmap = features.view(8, 8).flip(0).reshape(N_SQUARES)

        svg_board, fig = visualisation.render_heatmap(
            board,
            heatmap,
            arrows=[(move.from_square, move.to_square)],
        )

        svg_path = f"{constants.FIGURES_FOLER}/{file_id}_{rank}.svg"
        with open(svg_path, "w", encoding="utf-8") as f:
            f.write(svg_board)
        board_images.append(svg_path)
        colorbars.append(fig)

    return file_id, *board_images, *colorbars


with gr.Blocks() as interface:
    with gr.Row():
        feature_index = gr.Slider(
            label="Feature index",
            minimum=0,
            maximum=constants.DICTIONARY_SIZE - 1,
            step=1,
            value=0,
        )

    # One image + colour-bar pair per grid cell, in row-major order so that
    # cell k shows the k-th strongest activation.
    board_images = []
    colorbars = []
    for i in range(N_ROWS):
        with gr.Row():
            for j in range(N_COLS):
                with gr.Column():
                    with gr.Group():
                        idx = N_COLS * i + j
                        with gr.Row():
                            board_images.append(gr.Image(label=f"Board {idx}"))
                        with gr.Row():
                            colorbars.append(gr.Plot(label=f"Colorbar {idx}"))

    # Session-scoped id so each user's SVG files get distinct names.
    file_id = gr.State(None)
    feature_index.change(
        render_feature_index,
        inputs=[file_id, feature_index],
        outputs=[file_id, *board_images, *colorbars],
    )