# Xmaster6y's picture
# info
# 980eda6
import gradio as gr
import chess.svg
from lczerolens import LczeroBoard, LczeroModel, Lens
from . import constants
def create_board_figure(
    board: LczeroBoard,
    *,
    orientation: bool = chess.WHITE,
    arrows: str = "",
    square: str = "",
    name: str = "board",
):
    """Render ``board`` as an SVG file and return the path of that file.

    Args:
        board: Position to draw.
        orientation: Point of view forwarded to ``chess.svg.board``
            (``chess.WHITE`` by default).
        arrows: Whitespace-separated arrows, each written as a from-square
            immediately followed by a to-square, e.g. ``"e2e4 g1f3"``.
        square: Optional square name (e.g. ``"e4"``) to fill with ``#FF0000``.
        name: Base name, without extension, of the SVG file written inside
            ``constants.FIGURE_DIRECTORY``.

    Returns:
        The path of the written SVG file, as a string.

    Malformed ``arrows`` or ``square`` values do not raise: a Gradio warning
    is emitted and the corresponding decoration is left out of the figure.
    """
    try:
        # split() with no argument ignores leading/trailing/repeated
        # whitespace; split(" ") would produce empty tokens there and make
        # an otherwise valid arrow list fail to parse.
        chess_arrows = [
            (chess.parse_square(arrow[:2]), chess.parse_square(arrow[2:]))
            for arrow in arrows.split()
        ]
    except ValueError:
        # Any single bad arrow discards the whole list, as announced below.
        chess_arrows = []
        gr.Warning("Invalid arrows, using none.")

    # Same whitespace tolerance as for the arrows.
    square = square.strip()
    try:
        color_dict = {chess.parse_square(square): "#FF0000"} if square else {}
    except ValueError:
        color_dict = {}
        gr.Warning("Invalid square, using none.")

    svg_board = chess.svg.board(
        board,
        size=350,
        orientation=orientation,
        arrows=chess_arrows,
        fill=color_dict,
    )

    figure_path = f"{constants.FIGURE_DIRECTORY}/{name}.svg"
    # SVG is XML text; pin the encoding instead of relying on the locale.
    with open(figure_path, "w", encoding="utf-8") as f:
        f.write(svg_board)
    return figure_path
class OutputLens(Lens):
    """Lens whose intervention returns the saved output of the model."""

    def _intervene(self, model: LczeroModel, **kwargs) -> dict:
        """Return the result of ``model.output.save()``.

        Extra keyword arguments are accepted but not used here.
        """
        # NOTE(review): presumably a tracing-style "save" handle that the
        # Lens base class resolves into the real output — confirm upstream.
        saved_output = model.output.save()
        return saved_output
def get_info(model: LczeroModel, board: LczeroBoard) -> str:
    """Summarise the model's evaluation of ``board`` in a single line.

    Runs ``OutputLens`` on the position, reads the three ``"wdl"`` values of
    the first row, and picks the legal move whose ``"policy"`` entry is the
    largest.

    Args:
        model: Model to analyse the position with.
        board: Position to evaluate.

    Returns:
        A string of the form ``"w: 0.12, d: 0.34, l: 0.54, best: <move>"``.
        When the position has no legal move, the best move is reported as
        ``none`` instead of raising.
    """
    output = OutputLens().analyse(model, board)

    # Presumably win / draw / loss, matching the labels in the result string.
    wdl = output["wdl"]
    win, draw, loss = wdl[0, 0], wdl[0, 1], wdl[0, 2]

    legal_indices = board.get_legal_indices()
    if len(legal_indices) == 0:
        # argmax over an empty set of legal moves would raise, so a
        # position without legal moves is reported explicitly.
        best_move = "none"
    else:
        # Restrict the policy to legal moves before taking the maximum, then
        # map the position inside that subset back to the move index.
        legal_policy = output["policy"].gather(dim=1, index=legal_indices.unsqueeze(0))
        best_move_idx = legal_policy.argmax(dim=1).item()
        best_move = board.decode_move(legal_indices[best_move_idx])

    return f"w: {win:.2f}, d: {draw:.2f}, l: {loss:.2f}, best: {best_move}"