"""
Visualisation utils.
"""

import chess
import chess.svg
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch

# Shared colour scale for the heatmaps (reversed red-yellow-blue, 1000 steps).
COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000)
# Opacity applied to every square colour produced by `render_heatmap`.
ALPHA = 1.0
# Default [0, 1] normalisation; not referenced by the functions in this module.
NORM = matplotlib.colors.Normalize(vmin=0, vmax=1, clip=False)


def render_heatmap(
    board,
    heatmap,
    square=None,
    vmin=None,
    vmax=None,
    arrows=None,
    normalise="none",
):
    """
    Render a heatmap on the board.

    Parameters
    ----------
    board
        Board handed to `chess.svg.board` (presumably a `chess.Board`).
    heatmap
        One value per square, indexed 0..63. Must support `.abs()`, `.min()`,
        `.max()` and integer indexing (assumed to be a 1-D torch tensor --
        TODO confirm against callers).
    square
        Optional square name (e.g. "e4"). When it parses, it is passed as the
        `check` square of `chess.svg.board`; unparsable names are ignored.
    vmin, vmax
        Bounds of the colour scale. Default to the heatmap minimum / maximum.
        Ignored (forced to -1 / 1) when `normalise == "abs"`.
    arrows
        Optional arrows forwarded to `chess.svg.board`; defaults to no arrows.
    normalise
        "abs" rescales the heatmap by its largest absolute value and pins the
        colour scale to [-1, 1]; any other value leaves the heatmap untouched.

    Returns
    -------
    tuple
        `(svg, fig)`: the SVG produced by `chess.svg.board` and a matplotlib
        figure containing only the matching horizontal colour bar. The figure
        is already closed in pyplot, so it does not stay registered there.
    """
    if normalise == "abs":
        a_max = heatmap.abs().max()
        # An all-zero heatmap is left as is to avoid dividing by zero.
        if a_max != 0:
            heatmap = heatmap / a_max
        vmin = -1
        vmax = 1
    if vmin is None:
        vmin = heatmap.min()
    if vmax is None:
        vmax = heatmap.max()
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)

    # Map every square index to a hex colour (with alpha) for the SVG fill.
    color_dict = {}
    for square_index in range(64):
        red, green, blue = COLOR_MAP(norm(heatmap[square_index]))[:3]
        color_dict[square_index] = matplotlib.colors.to_hex(
            (red, green, blue, ALPHA), keep_alpha=True
        )

    # Stand-alone colour bar: the axes are hidden and the bar takes all the room.
    fig = plt.figure(figsize=(6, 0.6))
    ax = fig.gca()
    ax.axis("off")
    fig.colorbar(
        matplotlib.cm.ScalarMappable(norm=norm, cmap=COLOR_MAP),
        ax=ax,
        orientation="horizontal",
        fraction=1.0,
    )
    # Close this specific figure (not "whatever is current") so repeated calls
    # do not accumulate open figures; the Figure object itself stays usable.
    plt.close(fig)

    check = None
    if square is not None:
        try:
            check = chess.parse_square(square)
        except ValueError:
            # Best effort: an invalid square name simply means no highlight.
            check = None
    if arrows is None:
        arrows = []

    return (
        chess.svg.board(
            board,
            check=check,
            fill=color_dict,
            size=350,
            arrows=arrows,
        ),
        fig,
    )


def render_policy_distribution(
    policy,
    legal_moves,
    n_bins=20,
):
    """
    Render the policy distribution histogram.

    Draws two overlaid density histograms sharing the same bin edges: one for
    the policy entries whose index is in `legal_moves`, one for all the others.
    The y axis is logarithmic.

    Parameters
    ----------
    policy
        Policy values, indexable by a boolean torch mask of length 1858 and
        accepted by `np.histogram` (assumed to be a 1-D CPU torch tensor of
        1858 entries -- TODO confirm against callers).
    legal_moves
        Container of legal move indices; only membership (`in`) is used.
    n_bins
        Number of histogram bins computed over the whole policy.

    Returns
    -------
    matplotlib.figure.Figure
        The histogram figure. As with `render_heatmap`, it is closed in pyplot
        before being returned so repeated calls do not leak open figures.
    """
    # 1858 is presumably the size of the policy vector -- verify against callers.
    legal_mask = torch.tensor(
        [move in legal_moves for move in range(1858)],
        dtype=torch.bool,
    )

    fig = plt.figure(figsize=(6, 6))
    ax = fig.gca()
    # Compute the bin edges once on the full policy so both histograms align.
    _, bins = np.histogram(policy, bins=n_bins)
    ax.hist(
        policy[~legal_mask],
        bins=bins,
        alpha=0.5,
        density=True,
        label="Illegal moves",
    )
    ax.hist(
        policy[legal_mask],
        bins=bins,
        alpha=0.5,
        density=True,
        label="Legal moves",
    )
    # Use the explicit axes rather than pyplot's implicit "current axes".
    ax.set_xlabel("Policy")
    ax.set_ylabel("Density")
    ax.legend()
    ax.set_yscale("log")
    # Unregister the figure from pyplot; the returned object remains usable.
    plt.close(fig)
    return fig