Spaces:
Running
Running
File size: 1,962 Bytes
3333fb8 980eda6 3333fb8 980eda6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import gradio as gr
import chess.svg
from lczerolens import LczeroBoard, LczeroModel, Lens
from . import constants
def create_board_figure(
    board: LczeroBoard,
    *,
    orientation: bool = chess.WHITE,
    arrows: str = "",
    square: str = "",
    name: str = "board",
):
    """Render ``board`` to an SVG file and return the file's path.

    Args:
        board: Position to draw.
        orientation: Side drawn at the bottom (``chess.WHITE`` / ``chess.BLACK``).
        arrows: Whitespace-separated arrows, each written as a from-square
            followed by a to-square (e.g. ``"e2e4 g1f3"``). If any entry fails
            to parse, a Gradio warning is shown and no arrows are drawn.
        square: Optional square name (e.g. ``"e4"``) to fill with ``#FF0000``.
            An unparsable name shows a Gradio warning and is ignored.
        name: Base file name (no extension) inside ``constants.FIGURE_DIRECTORY``.

    Returns:
        The path of the SVG file that was written.
    """
    try:
        # split() with no argument ignores leading/trailing/repeated whitespace;
        # split(" ") produced empty tokens that failed to parse and caused
        # every arrow to be discarded. parse_square only accepts lowercase names.
        chess_arrows = [
            (chess.parse_square(arrow[:2]), chess.parse_square(arrow[2:]))
            for arrow in (arrows or "").lower().split()
        ]
    except ValueError:
        chess_arrows = []
        gr.Warning("Invalid arrows, using none.")

    square = (square or "").strip().lower()
    try:
        color_dict = {chess.parse_square(square): "#FF0000"} if square else {}
    except ValueError:
        color_dict = {}
        gr.Warning("Invalid square, using none.")

    svg_board = chess.svg.board(
        board,
        size=350,
        orientation=orientation,
        arrows=chess_arrows,
        fill=color_dict,
    )
    path = f"{constants.FIGURE_DIRECTORY}/{name}.svg"
    with open(path, "w", encoding="utf-8") as f:
        f.write(svg_board)
    return path
class OutputLens(Lens):
    """Lens whose intervention simply saves and hands back ``model.output``."""

    def _intervene(self, model: LczeroModel, **kwargs) -> dict:
        """Return the result of ``model.output.save()``.

        Extra keyword arguments are accepted but not used. Presumably this is
        invoked by ``Lens.analyse`` (see ``get_info``) — confirm against the
        lczerolens ``Lens`` base class.
        """
        saved_output = model.output.save()
        return saved_output
def get_info(model: LczeroModel, board: LczeroBoard):
    """Summarise ``model``'s evaluation of ``board`` as a one-line string.

    Runs the model on the board through ``OutputLens``. Formats the three
    ``wdl`` values together with the legal move that has the highest policy
    score.

    Args:
        model: Network to evaluate.
        board: Position to evaluate.

    Returns:
        A string like ``"w: 0.40, d: 0.35, l: 0.25, best: e2e4"``. ``best`` is
        ``None`` when the position has no legal moves.
    """
    output = OutputLens().analyse(model, board)

    # Assumes "wdl" is (batch, 3) ordered win/draw/loss — TODO confirm.
    wdl = output["wdl"]
    win, draw, loss = wdl[0, 0], wdl[0, 1], wdl[0, 2]

    legal_indices = board.get_legal_indices()
    if legal_indices.numel() == 0:
        # No legal moves (e.g. checkmate or stalemate): argmax over an empty
        # tensor would raise, so report no best move instead of crashing.
        best_move = None
    else:
        policy = output["policy"]
        # gather requires the index on the same device as the policy tensor;
        # this is a no-op when both are already on the CPU.
        legal_scores = policy.gather(
            dim=1, index=legal_indices.to(policy.device).unsqueeze(0)
        )
        best_move_idx = legal_scores.argmax(dim=1).item()
        best_move = board.decode_move(legal_indices[best_move_idx])

    return f"w: {win:.2f}, d: {draw:.2f}, l: {loss:.2f}, best: {best_move}"
|