File size: 5,148 Bytes
e3868b1 ad46e7d e3868b1 ad46e7d e3868b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import os
import cairosvg
import chess.svg
import graphviz
from anytree import AnyNode
from loguru import logger
graph_params = {
"name": "Beam Search",
"format": "png",
"shape": "none",
"board_size": 300,
"fontsize": "20",
"image_pos": "tc",
"labelloc": "b",
"labelfontsize": "20",
"headport": "n",
"tailport": "s",
}
@logger.catch(level="DEBUG")
def board_to_svg(board: chess.Board, is_pruned: bool = False, size: int = 360) -> str:
"""Convert a chess.Board object to an SVG string.
Args:
board (chess.Board): chess.Board object
is_pruned (bool): whether the board is pruned or not
size (int): size of the board
Returns:
str: SVG string of the board
"""
svg = chess.svg.board(
board=board,
size=size,
lastmove=board.peek() if board.move_stack else None,
check=board.king(board.turn) if board.is_check() else None,
)
if is_pruned:
svg = svg.replace("#d18b47", "#989898")
svg = svg.replace("#ffce9e", "#d7d7d7")
svg = svg.replace("#cdd16a", "#c4c4c4")
svg = svg.replace("#aaa23b", "#999999")
return svg
@logger.catch(level="DEBUG")
def save_svg(
board: chess.Board, filename: str, is_pruned: bool = False, to_png=False
) -> None:
"""Save a chess.Board object to an image file.
Args:
board (chess.Board): chess.Board object
filename (str): filename to save the SVG file
is_pruned (bool): whether the board is pruned or not
to_png (bool): whether to save the SVG file as a PNG file
"""
xml_declaration = '<?xml version="1.0" encoding="UTF-8" standalone="no"?>'
with open(filename + ".svg", "w") as f:
f.write(
xml_declaration
+ board_to_svg(board, is_pruned=is_pruned, size=graph_params["board_size"])
)
if to_png:
cairosvg.svg2png(url=filename + ".svg", write_to=filename + ".png")
@logger.catch(level="DEBUG", reraise=True)
def plot_save_beam_search(
beam: AnyNode,
filename: str,
temp_dir: str = "../reports/figures",
intermediate_png: bool = False,
) -> None:
"""Plot and save the beam search tree.
Args:
beam (AnyNode): beam search tree
filename (str): filename to save the tree plot
temp_dir (str): temporary directory to save the SVG files
intermediate_png (bool): whether to save the SVG files as PNG files
"""
image_path = os.path.abspath(os.path.join(temp_dir, "root"))
# logger.info(image_path)
save_svg(board=beam.board, filename=image_path, to_png=intermediate_png)
dot = graphviz.Digraph(name=graph_params["name"], format=graph_params["format"])
image_path = image_path + ".png" if intermediate_png else image_path + ".svg"
# logger.info(image_path)
dot.node(
name="ROOT",
label="",
labelloc=graph_params["labelloc"],
fontsize=graph_params["fontsize"],
shape=graph_params["shape"],
image_pos=graph_params["image_pos"],
image=image_path,
)
for board in beam.descendants:
# a board is pruned if none of its children are at the same depth as the beam height
# and if it is not at the beam height itself
# and if it is not a terminal board itself or its children
is_pruned = (
not (any(board.depth == beam.height for board in board.descendants))
and not (any(board.board.outcome() for board in board.descendants))
and not (board.depth == beam.height)
and not (board.board.outcome()))
image_path = os.path.abspath(os.path.join(temp_dir, board.name))
# logger.info(image_path)
save_svg(
board=board.board,
filename=image_path,
is_pruned=is_pruned,
to_png=intermediate_png,
)
image_path = image_path + ".png" if intermediate_png else image_path + ".svg"
# logger.info(image_path)
dot.node(
name=board.name,
label="score = {:.4f}".format(board.score),
labelloc=graph_params["labelloc"],
fontsize=graph_params["fontsize"],
shape=graph_params["shape"],
image_pos=graph_params["image_pos"],
image=image_path,
)
dot.edge(
tail_name=board.parent.name,
head_name=board.name,
headlabel=board.move.uci(),
fontsize=graph_params["labelfontsize"],
headport=graph_params["headport"],
tailport=graph_params["tailport"],
)
dot.render(
filename=filename,
format=graph_params["format"],
# renderer="cairo",
# formatter="cairo",
)
#logger.info(os.listdir(temp_dir))
# log render file
#with open(filename, "r") as f:
# logger.info(f.read())
# delete all svg files
# for file in os.listdir(temp_dir):
# if file.endswith(".svg"):
# logger.info("Deleting : " + os.path.join(temp_dir, file))
# os.remove(os.path.join(temp_dir, file))
|