Upload 19 files
Browse files- .gitattributes +3 -0
- activation_heatmap.py +436 -0
- assets/DejaVuSans-Bold.ttf +3 -0
- assets/DejaVuSans.ttf +3 -0
- assets/DejavuSansMono-5m7L.ttf +3 -0
- assets/custom.css +59 -0
- assets/favicon.ico +0 -0
- board2planes.py +97 -0
- constants.py +24 -0
- datasets/test_set.csv +0 -0
- global_data.py +539 -0
- layouts/visualization_demo.slides.json +4 -0
- leela_board.py +617 -0
- leela_utils.py +0 -0
- models/.DS_Store +0 -0
- models/leela-large-official.onnx +3 -0
- python_chess_customized_svg.py +414 -0
- svg_pieces.py +31 -0
- utils.py +32 -0
- visualization_demo.py +230 -0
.gitattributes
CHANGED
|
@@ -36,3 +36,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 36 |
our_visualization/assets/DejaVuSans-Bold.ttf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
our_visualization/assets/DejaVuSans.ttf filter=lfs diff=lfs merge=lfs -text
|
| 38 |
our_visualization/assets/DejavuSansMono-5m7L.ttf filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
our_visualization/assets/DejaVuSans-Bold.ttf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
our_visualization/assets/DejaVuSans.ttf filter=lfs diff=lfs merge=lfs -text
|
| 38 |
our_visualization/assets/DejavuSansMono-5m7L.ttf filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/DejaVuSans-Bold.ttf filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
assets/DejaVuSans.ttf filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
assets/DejavuSansMono-5m7L.ttf filter=lfs diff=lfs merge=lfs -text
|
activation_heatmap.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import chess
|
| 2 |
+
import dash
|
| 3 |
+
import plotly.graph_objs as go
|
| 4 |
+
from plotly.subplots import make_subplots
|
| 5 |
+
from global_data import global_data
|
| 6 |
+
|
| 7 |
+
from svg_pieces import get_svg_board
|
| 8 |
+
from dash import dcc, html, Input, Output, State
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import time
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
from plotly.io import to_json
|
| 16 |
+
|
| 17 |
+
import pickle
|
| 18 |
+
|
| 19 |
+
V_GAP = 0.15
|
| 20 |
+
LAYOUT_MARGIN_V = 40
|
| 21 |
+
|
| 22 |
+
def heatmap_data(head):
|
| 23 |
+
data = global_data.get_head_data(head)
|
| 24 |
+
return data
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def heatmap_figure():
|
| 28 |
+
if global_data.model is None:
|
| 29 |
+
return {}
|
| 30 |
+
start = time.time()
|
| 31 |
+
fig = make_figure()
|
| 32 |
+
print('make fig:', time.time() - start)
|
| 33 |
+
|
| 34 |
+
start = time.time()
|
| 35 |
+
fig = add_heatmap_traces(fig)
|
| 36 |
+
print('add traces:', time.time() - start)
|
| 37 |
+
with open("./test_activations_starting.pkl", 'wb') as f:
|
| 38 |
+
print("saving activations")
|
| 39 |
+
pickle.dump(global_data.activations, f)
|
| 40 |
+
start = time.time()
|
| 41 |
+
fig = add_layout(fig)
|
| 42 |
+
print('add layout total:', time.time() - start)
|
| 43 |
+
|
| 44 |
+
start = time.time()
|
| 45 |
+
|
| 46 |
+
if global_data.selected_layer == 'Smolgen':
|
| 47 |
+
with open('fig_as_json_no_pieces.json', 'w') as f:
|
| 48 |
+
f.write(to_json(fig, pretty=True))
|
| 49 |
+
|
| 50 |
+
if not global_data.visualization_mode_is_64x64:# and global_data.selected_layer != 'Smolgen':
|
| 51 |
+
fig = add_pieces(fig)
|
| 52 |
+
print('add pieces:', time.time() - start)
|
| 53 |
+
|
| 54 |
+
if global_data.selected_layer == 'Smolgen':
|
| 55 |
+
with open('fig_as_json.json', 'w') as f:
|
| 56 |
+
f.write(to_json(fig, pretty=True))
|
| 57 |
+
|
| 58 |
+
return fig
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def heatmap():
|
| 62 |
+
start = time.time()
|
| 63 |
+
# We need to recalculate graph when grid size changes, other wise layout is a mess (Dash bug?). Use hidden Div's children to trigger callback for graph recalc.
|
| 64 |
+
# Otherwise, we can just recalculate figure part and frontend rendering will be much faster
|
| 65 |
+
#
|
| 66 |
+
graph = html.Div(id='graph-container', children=[heatmap_graph()],
|
| 67 |
+
style={'height': '100%', 'width': '100%', "overflow": "auto"#, 'textAlign': 'center'#, "display": "flex", "justifyContent":"center"
|
| 68 |
+
})
|
| 69 |
+
print('GRAPH CREATION:', time.time() - start)
|
| 70 |
+
return graph
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def heatmap_graph():
|
| 74 |
+
fig = heatmap_figure()
|
| 75 |
+
|
| 76 |
+
config = {
|
| 77 |
+
'displaylogo': False,
|
| 78 |
+
'displayModeBar': True,
|
| 79 |
+
'modeBarButtonsToRemove': ['zoom', 'pan', 'select', 'zoomIn', 'zoomOut', 'autoScale', 'resetScale'],
|
| 80 |
+
'toImageButtonOptions': {
|
| 81 |
+
'format': global_data.export_format,
|
| 82 |
+
'scale': global_data.export_scale
|
| 83 |
+
}}
|
| 84 |
+
|
| 85 |
+
style = {'height': global_data.figure_container_height, 'width': '100%'}#, 'margin': '0 auto'}
|
| 86 |
+
|
| 87 |
+
graph = dcc.Graph(figure=fig, id='graph', style=style,
|
| 88 |
+
responsive='auto',#True, # True,
|
| 89 |
+
config=config
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# graph = html.Div(id='graph-container', children=[graph], style={'height': '100%', 'width': '100%', "overflow": "auto"
|
| 93 |
+
# })
|
| 94 |
+
# graph_component.children = [graph]
|
| 95 |
+
|
| 96 |
+
global_data.cache_figure(fig)
|
| 97 |
+
|
| 98 |
+
return graph
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def make_figure():
|
| 102 |
+
#print('assumed key', global_data.subplot_rows, global_data.subplot_cols, global_data.visualization_mode_is_64x64, global_data.selected_head if not global_data.show_all_heads else -1)
|
| 103 |
+
#print('key', global_data.get_figure_cache_key())
|
| 104 |
+
#print('all keys', global_data.figure_cache.keys())
|
| 105 |
+
fig = global_data.get_cached_figure()
|
| 106 |
+
if fig is None:
|
| 107 |
+
if global_data.show_all_heads:
|
| 108 |
+
titles = [f"Head {i + 1}" for i in range(global_data.number_of_heads)]
|
| 109 |
+
print('MAKING SUBPLOTS', 'rows:', global_data.subplot_rows, 'cols:', global_data.subplot_cols)
|
| 110 |
+
print('NUMBER OF HEADS:', global_data.number_of_heads)
|
| 111 |
+
fig = make_subplots(rows=global_data.subplot_rows, cols=global_data.subplot_cols, subplot_titles=titles,
|
| 112 |
+
horizontal_spacing=global_data.heatmap_horizontal_gap / global_data.subplot_cols,
|
| 113 |
+
vertical_spacing=V_GAP / global_data.subplot_rows,
|
| 114 |
+
)
|
| 115 |
+
else:
|
| 116 |
+
print('CREATING 1x1')
|
| 117 |
+
titles = [f"head {global_data.selected_head +1}"]
|
| 118 |
+
fig = make_subplots(rows=1, cols=1, subplot_titles=titles)#go.Figure()#make_subplots(rows=1, cols=1, subplot_titles=titles)
|
| 119 |
+
|
| 120 |
+
return fig
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def add_layout(fig):
|
| 124 |
+
start = time.time()
|
| 125 |
+
|
| 126 |
+
#coloraxis1 = None
|
| 127 |
+
#if global_data.visualization_mode_is_64x64:
|
| 128 |
+
# if global_data.colorscale_mode == '3':
|
| 129 |
+
# coloraxis1 = {'colorscale': 'Viridis'}
|
| 130 |
+
|
| 131 |
+
coloraxis = None
|
| 132 |
+
if global_data.colorscale_mode == 'mode3':
|
| 133 |
+
cmin = np.amin(global_data.activations[:, :, :])
|
| 134 |
+
cmax = np.amax(global_data.activations[:, :, :])
|
| 135 |
+
coloraxis = {'colorscale': 'Viridis', 'colorbar': {'ypad': 0} , 'cmin': cmin, 'cmax': cmax, 'showscale': global_data.show_colorscale}
|
| 136 |
+
|
| 137 |
+
if global_data.check_if_figure_is_cached():
|
| 138 |
+
print('Using existing layout')
|
| 139 |
+
fig.update_layout({'coloraxis1': coloraxis}, overwrite=True)
|
| 140 |
+
return fig
|
| 141 |
+
|
| 142 |
+
layout = go.Layout(
|
| 143 |
+
# title='Plot title goes here',
|
| 144 |
+
margin={'t': LAYOUT_MARGIN_V, 'b': LAYOUT_MARGIN_V, 'r': 40, 'l': 40},
|
| 145 |
+
coloraxis1=coloraxis,
|
| 146 |
+
modebar={'orientation': 'v'}
|
| 147 |
+
#coloraxis={'colorscale': 'Viridis'}
|
| 148 |
+
#pa
|
| 149 |
+
#plot_bgcolor='rgb(0,0,0)',
|
| 150 |
+
#paper_bgcolor="black"
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
fig.update_layout(layout)
|
| 154 |
+
# fig['layout'].update(layout)
|
| 155 |
+
|
| 156 |
+
print('update layout:', time.time() - start)
|
| 157 |
+
|
| 158 |
+
start = time.time()
|
| 159 |
+
fig = update_axis(fig)
|
| 160 |
+
print('update axis:', time.time() - start)
|
| 161 |
+
# print(fig)
|
| 162 |
+
return fig
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def update_axis(fig):
|
| 166 |
+
if global_data.visualization_mode_is_64x64:
|
| 167 |
+
tickvals_x = list(range(0, 64, 4))
|
| 168 |
+
tickvals_y = list(range(3, 67, 4))#list(range(0, 64, 4))#list(range(3, 67, 4))
|
| 169 |
+
if global_data.board.turn or global_data.selected_layer == 'Smolgen':
|
| 170 |
+
ticktext_x = [x + y for x, y in zip('ae' * 8, '1122334455667788')]
|
| 171 |
+
#tickvals = list(range(0, 64))
|
| 172 |
+
#ticktext_x = [x + y for x, y in zip('abcdefg' * 8, '1'*8 + '2'*8 + '3'*8 + '4'*8 + '5'*8 + '6'*8 + '7'*8 + '8'*8)]
|
| 173 |
+
ticktext_y = ticktext_x[::-1]
|
| 174 |
+
else:
|
| 175 |
+
ticktext_x = [x + y for x, y in zip('ae' * 8, '1122334455667788'[::-1])]
|
| 176 |
+
ticktext_y = ticktext_x[::-1]
|
| 177 |
+
showticklabels = True
|
| 178 |
+
#ticklabelstep = 4
|
| 179 |
+
val_range = [-0.5, 63.5]
|
| 180 |
+
ticks = 'outside'
|
| 181 |
+
title_x = {'text': "Keys ('to' square)", 'standoff': 1}
|
| 182 |
+
title_y = {'text': "Queries ('from' square)", 'standoff': 1}
|
| 183 |
+
else:
|
| 184 |
+
title_x = None
|
| 185 |
+
title_y = None
|
| 186 |
+
tickvals_x = list(range(8)) # [0, 1, 2, 3, 4, 5, 6, 7]
|
| 187 |
+
tickvals_y = tickvals_x
|
| 188 |
+
ticktext_x = [letter for letter in 'abcdefgh']
|
| 189 |
+
ticktext_y = [letter for letter in '12345678']
|
| 190 |
+
showticklabels = True
|
| 191 |
+
#ticklabelstep = 1
|
| 192 |
+
val_range = [-0.5, 7.5]
|
| 193 |
+
ticks = ''
|
| 194 |
+
|
| 195 |
+
if not global_data.show_all_heads or (global_data.subplot_cols == 1 and global_data.subplot_rows == 1):
|
| 196 |
+
constraintowards_x = 'center'
|
| 197 |
+
else:
|
| 198 |
+
constraintowards_x = 'right'
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
fig.update_xaxes(title=title_x,
|
| 202 |
+
range=val_range,
|
| 203 |
+
# ticklen=50,
|
| 204 |
+
zeroline=False,
|
| 205 |
+
showgrid=False,
|
| 206 |
+
scaleanchor='y',
|
| 207 |
+
constrain='domain',
|
| 208 |
+
constraintoward=constraintowards_x,
|
| 209 |
+
ticks=ticks, # ticks,
|
| 210 |
+
ticktext=ticktext_x,
|
| 211 |
+
tickvals=tickvals_x,
|
| 212 |
+
showticklabels=showticklabels,
|
| 213 |
+
# mirror='ticks',
|
| 214 |
+
fixedrange=True,
|
| 215 |
+
#ticklabelstep=ticklabelstep,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
fig.update_yaxes(title=title_y,
|
| 219 |
+
range=val_range,
|
| 220 |
+
zeroline=False,
|
| 221 |
+
showgrid=False,
|
| 222 |
+
scaleanchor='x',
|
| 223 |
+
constrain='domain',
|
| 224 |
+
constraintoward='top',
|
| 225 |
+
ticks=ticks, # ticks,
|
| 226 |
+
ticktext=ticktext_y,
|
| 227 |
+
tickvals=tickvals_y,
|
| 228 |
+
showticklabels=showticklabels,
|
| 229 |
+
# mirror='allticks',
|
| 230 |
+
# side='top',
|
| 231 |
+
fixedrange=True,
|
| 232 |
+
#ticklabelstep=ticklabelstep
|
| 233 |
+
)
|
| 234 |
+
return fig
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def calc_colorbar(row, col):
|
| 238 |
+
row = global_data.subplot_rows - row + 1 #invert
|
| 239 |
+
#row = global_data.subplot_rows - row - 1 #invert
|
| 240 |
+
|
| 241 |
+
dy = (1/global_data.subplot_rows)
|
| 242 |
+
dx = (1/global_data.subplot_cols)
|
| 243 |
+
|
| 244 |
+
offset = 1/global_data.subplot_cols - 2*(global_data.heatmap_horizontal_gap/(global_data.subplot_cols))/4#global_data.colorscale_x_offset#(494.1125)/2239.2#1/global_data.subplot_cols - 3*(global_data.heatmap_horizontal_gap/(global_data.subplot_cols))/4
|
| 245 |
+
|
| 246 |
+
if global_data.heatmap_h == 0:
|
| 247 |
+
len = (1 - V_GAP/global_data.subplot_rows) / global_data.subplot_rows #- #V_GAP/global_data.subplot_rows
|
| 248 |
+
lenmode = 'fraction'
|
| 249 |
+
offset2 = len / 2
|
| 250 |
+
else:
|
| 251 |
+
#total_h = global_data.heatmap_fig_h * global_data.heatmap_h + (global_data.subplot_rows - 1)
|
| 252 |
+
len = global_data.heatmap_h/(global_data.heatmap_fig_h - 2*LAYOUT_MARGIN_V)
|
| 253 |
+
lenmode = 'fraction'
|
| 254 |
+
#offset2 = len / 2
|
| 255 |
+
#lenmode = 'pixels'
|
| 256 |
+
#offset2 = 1 - len/(global_data.subplot_rows*len + V_GAP) #1/global_data.subplot_rows - (V_GAP/global_data.subplot_rows)
|
| 257 |
+
|
| 258 |
+
#tot_h = global_data.heatmap_fig_h - 2*LAYOUT_MARGIN_V
|
| 259 |
+
#max_h = ((1 - V_GAP)) / global_data.subplot_rows
|
| 260 |
+
#cur_h = len
|
| 261 |
+
offset2 = 1 - (global_data.subplot_rows-1)*(dy + (V_GAP / global_data.subplot_rows)/global_data.subplot_rows) - len/2 #0#len/2 #+ (max_h - cur_h)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
#offset = global_data.colorscale_x_offset
|
| 265 |
+
#shift = (global_data.heatmap_w + 20 + 20 + global_data.heatmap_gap)/global_data.heatmap_fig_w
|
| 266 |
+
#cx = (col - 1) * shift + offset
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
cx = (col-1)*(dx + (global_data.heatmap_horizontal_gap / global_data.subplot_cols)/global_data.subplot_cols) + offset
|
| 270 |
+
cy = (row-1)*(dy + (V_GAP / global_data.subplot_rows)/global_data.subplot_rows) + offset2
|
| 271 |
+
#cy = (global_data.subplot_rows - 1 - row) * (dy + (V_GAP / global_data.subplot_rows) / global_data.subplot_rows) + offset2
|
| 272 |
+
|
| 273 |
+
#######################
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
colorbar=dict(len=len, y=cy, x=cx, ypad=0, xpad=0, ticklabelposition='inside', ticks='inside', ticklen=3, lenmode=lenmode,
|
| 277 |
+
tickfont=dict(color='#7e807f'))
|
| 278 |
+
|
| 279 |
+
return colorbar
|
| 280 |
+
|
| 281 |
+
def add_heatmap_trace(fig, row, col, head=None):
|
| 282 |
+
# print('ADDING heatmap', row, col)
|
| 283 |
+
if head is None:
|
| 284 |
+
head = (row - 1) * global_data.subplot_cols + (col - 1)
|
| 285 |
+
data = heatmap_data(head)
|
| 286 |
+
|
| 287 |
+
if data is None:
|
| 288 |
+
return fig
|
| 289 |
+
|
| 290 |
+
if global_data.visualization_mode_is_64x64:
|
| 291 |
+
xgap, ygap = 0, 0
|
| 292 |
+
#hovertemplate = 'Query: <b>%{y}</b> <br> Key: <b>%{x}</b> <br> value: <b>%{z}</b><extra></extra>'
|
| 293 |
+
hovertemplate = 'Query: <b>%{customdata[0]}</b> <br>Key: <b>%{customdata[1]}</b> <br>value: <b>%{z:.5f}</b><extra></extra>'
|
| 294 |
+
if global_data.board.turn or global_data.selected_layer == 'Smolgen':
|
| 295 |
+
customdata_x = [[letter + ind for ind in '12345678' for letter in 'abcdefgh'] for _ in range(64)]
|
| 296 |
+
customdata_y = [[letter + ind for _ in range(64)] for ind in '12345678'[::-1] for letter in 'abcdefgh'[::-1]]
|
| 297 |
+
else:
|
| 298 |
+
customdata_x = [[letter + ind for ind in '12345678'[::-1] for letter in 'abcdefgh'] for _ in range(64)]
|
| 299 |
+
customdata_y = [[letter + ind for _ in range(64)] for ind in '12345678' for letter in 'abcdefgh'[::-1]]
|
| 300 |
+
|
| 301 |
+
#customdata_y = [[letter + ind for _ in range(64)] for ind in '12345678' for letter in 'abcdefgh']
|
| 302 |
+
customdata = np.moveaxis([customdata_y, customdata_x], 0, -1)#[customdata_x, customdata_y]
|
| 303 |
+
|
| 304 |
+
else:
|
| 305 |
+
xgap, ygap = 2, 2
|
| 306 |
+
hovertemplate = '<b>%{x}%{y}</b>: <b>%{z}</b><extra></extra>'
|
| 307 |
+
customdata = None
|
| 308 |
+
|
| 309 |
+
coloraxis = None
|
| 310 |
+
#Colorscale
|
| 311 |
+
|
| 312 |
+
#if global_data.visualization_mode_is_64x64:
|
| 313 |
+
# if global_data.colorscale_mode == '3':
|
| 314 |
+
# coloraxis = 'coloraxis1'
|
| 315 |
+
|
| 316 |
+
coloraxis = None
|
| 317 |
+
colorscale = 'Viridis'
|
| 318 |
+
colorbar = None
|
| 319 |
+
#if global_data.show_colorscale and global_data.colorscale_mode == 'mode3':
|
| 320 |
+
if global_data.colorscale_mode == 'mode3':
|
| 321 |
+
coloraxis = 'coloraxis1'
|
| 322 |
+
colorscale = None
|
| 323 |
+
|
| 324 |
+
elif global_data.show_colorscale and not global_data.colorscale_mode == 'mode3' and not (not global_data.show_all_heads or (global_data.subplot_cols == 1 and global_data.subplot_rows == 1)):
|
| 325 |
+
colorbar = calc_colorbar(row, col)
|
| 326 |
+
|
| 327 |
+
zmin, zmax = None, None
|
| 328 |
+
|
| 329 |
+
if global_data.colorscale_mode == 'mode2':
|
| 330 |
+
pass
|
| 331 |
+
zmin = np.amin(global_data.activations[head, :, :])
|
| 332 |
+
zmax = np.amax(global_data.activations[head, :, :])
|
| 333 |
+
#print('ZMINMAX M2', head, zmin, zmax)
|
| 334 |
+
|
| 335 |
+
elif global_data.colorscale_mode == 'mode1':
|
| 336 |
+
zmin = np.amin(data)
|
| 337 |
+
zmax = np.amax(data)
|
| 338 |
+
|
| 339 |
+
#print('Trace data shape', data.shape)
|
| 340 |
+
trace = go.Heatmap(
|
| 341 |
+
z=data,
|
| 342 |
+
colorscale=colorscale,
|
| 343 |
+
showscale=global_data.show_colorscale,#True,
|
| 344 |
+
colorbar=colorbar,
|
| 345 |
+
#colorbar=dict(len=len, y=cy, x=cx, ypad=0, ticklabelposition='inside', ticks='inside', ticklen=3,
|
| 346 |
+
# tickfont=dict(color='#7e807f')),
|
| 347 |
+
xgap=xgap,
|
| 348 |
+
ygap=ygap,
|
| 349 |
+
hovertemplate=hovertemplate,
|
| 350 |
+
customdata=customdata,
|
| 351 |
+
zmin=zmin,
|
| 352 |
+
zmax=zmax,
|
| 353 |
+
coloraxis=coloraxis
|
| 354 |
+
#zmin=zmin,
|
| 355 |
+
#zmax=zmax
|
| 356 |
+
)
|
| 357 |
+
fig.add_trace(trace, row=row, col=col)
|
| 358 |
+
return fig
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def add_heatmap_traces(fig):
|
| 362 |
+
print('adding traces, rows:', global_data.subplot_rows, 'cols:', global_data.subplot_cols)
|
| 363 |
+
#adding traces is quick so we don't bother using cached values. Wipe old traces and add new.
|
| 364 |
+
fig.data = []
|
| 365 |
+
if global_data.show_all_heads:
|
| 366 |
+
for row in range(global_data.subplot_rows):
|
| 367 |
+
for col in range(global_data.subplot_cols):
|
| 368 |
+
fig = add_heatmap_trace(fig, row + 1, col + 1)
|
| 369 |
+
else:
|
| 370 |
+
fig = add_heatmap_trace(fig, 1, 1, global_data.selected_head)
|
| 371 |
+
return fig
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def add_pieces(fig):
|
| 375 |
+
if global_data.selected_layer != 'Smolgen':
|
| 376 |
+
board = global_data.board
|
| 377 |
+
|
| 378 |
+
else:
|
| 379 |
+
board = chess.Board(fen=None) #Empty board, we want to draw only the focused square
|
| 380 |
+
board_svg = get_svg_board(board, global_data.focused_square_ind, True)
|
| 381 |
+
|
| 382 |
+
images = [dict(
|
| 383 |
+
source=board_svg,
|
| 384 |
+
xref="x"+str(i),
|
| 385 |
+
yref="y"+str(i),
|
| 386 |
+
x=3.5,
|
| 387 |
+
y=3.5,
|
| 388 |
+
sizex=8,
|
| 389 |
+
sizey=8,
|
| 390 |
+
xanchor='center',
|
| 391 |
+
yanchor='middle',
|
| 392 |
+
sizing="stretch",
|
| 393 |
+
)
|
| 394 |
+
for i in range(2, 2+255)
|
| 395 |
+
]
|
| 396 |
+
images = [dict(
|
| 397 |
+
source=board_svg,
|
| 398 |
+
xref="x",
|
| 399 |
+
yref="y",
|
| 400 |
+
x=3.5,
|
| 401 |
+
y=3.5,
|
| 402 |
+
sizex=8,
|
| 403 |
+
sizey=8,
|
| 404 |
+
xanchor='center',
|
| 405 |
+
yanchor='middle',
|
| 406 |
+
sizing="stretch",
|
| 407 |
+
)] + images
|
| 408 |
+
|
| 409 |
+
fig.layout.images = images
|
| 410 |
+
return fig
|
| 411 |
+
board_svg = get_svg_board(board, global_data.focused_square_ind, True)
|
| 412 |
+
if global_data.check_if_figure_is_cached():
|
| 413 |
+
print('USING CACHED')
|
| 414 |
+
for img in fig.layout.images:
|
| 415 |
+
img['source'] = board_svg
|
| 416 |
+
else:
|
| 417 |
+
fig.add_layout_image(
|
| 418 |
+
dict(
|
| 419 |
+
source=board_svg,
|
| 420 |
+
xref="x",
|
| 421 |
+
yref="y",
|
| 422 |
+
x=3.5,
|
| 423 |
+
y=3.5,
|
| 424 |
+
sizex=8,
|
| 425 |
+
sizey=8,
|
| 426 |
+
xanchor='center',
|
| 427 |
+
yanchor='middle',
|
| 428 |
+
sizing="stretch",
|
| 429 |
+
),
|
| 430 |
+
row='all',
|
| 431 |
+
col='all',
|
| 432 |
+
exclude_empty_subplots=True,
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
return fig
|
| 436 |
+
|
assets/DejaVuSans-Bold.ttf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e6476c1b80502924294eed40894c5b18e06c181444ca953e5334262df9c27724
|
| 3 |
+
size 705684
|
assets/DejaVuSans.ttf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7da195a74c55bef988d0d48f9508bd5d849425c1770dba5d7bfc6ce9ed848954
|
| 3 |
+
size 757076
|
assets/DejavuSansMono-5m7L.ttf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dfbac4c793ca4b896c34af5c7ccbbaf37924b46cd521a5ecff9130b9c331f575
|
| 3 |
+
size 333636
|
assets/custom.css
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*[class="model-loading"][data-dash-is-loading="true"]::before{
|
| 2 |
+
content: "Loading model...";
|
| 3 |
+
display: inline-block;
|
| 4 |
+
color: red;
|
| 5 |
+
visibility: visible;
|
| 6 |
+
font-size: 16px;
|
| 7 |
+
/*transition-delay: 1s;
|
| 8 |
+
transition-property: font-size;
|
| 9 |
+
transition-duration: 200ms;*/
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
.completely-hidden{
|
| 13 |
+
display: none;
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
.hidden-but-reserve-space{
|
| 17 |
+
visibility: hidden;
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
#link{
|
| 21 |
+
text-decoration: underline;
|
| 22 |
+
cursor: pointer;
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
.header-container {
|
| 26 |
+
border-bottom: thin lightgrey solid;
|
| 27 |
+
box-sizing: border-box;
|
| 28 |
+
white-space: nowrap;
|
| 29 |
+
overflow-y: auto;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
.header-control-container {
|
| 33 |
+
margin-left: 20px;
|
| 34 |
+
margin-right: 20px;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
body {
|
| 38 |
+
margin: 0;
|
| 39 |
+
padding: 0;
|
| 40 |
+
font-family: 'BundledDejavuSans';
|
| 41 |
+
font-size: 14px;
|
| 42 |
+
-moz-user-select: none;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
@font-face {
|
| 46 |
+
font-family: 'BundledDejavuSansMono';
|
| 47 |
+
src: url('DejavuSansMono-5m7L.ttf');
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
@font-face {
|
| 51 |
+
font-family: 'BundledDejavuSans';
|
| 52 |
+
src: url('DejaVuSans.ttf');
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
@font-face {
|
| 56 |
+
font-family: 'BundledDejavuSans';
|
| 57 |
+
src: url('DejaVuSans-Bold.ttf');
|
| 58 |
+
font-weight: bold;
|
| 59 |
+
}
|
assets/favicon.ico
ADDED
|
|
board2planes.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#Author: https://github.com/Arcturai
|
| 2 |
+
|
| 3 |
+
import chess
|
| 4 |
+
import numpy as np
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
WPAWN = chess.Piece(chess.PAWN, chess.WHITE)
|
| 9 |
+
WKNIGHT = chess.Piece(chess.KNIGHT, chess.WHITE)
|
| 10 |
+
WBISHOP = chess.Piece(chess.BISHOP, chess.WHITE)
|
| 11 |
+
WROOK = chess.Piece(chess.ROOK, chess.WHITE)
|
| 12 |
+
WQUEEN = chess.Piece(chess.QUEEN, chess.WHITE)
|
| 13 |
+
WKING = chess.Piece(chess.KING, chess.WHITE)
|
| 14 |
+
BPAWN = chess.Piece(chess.PAWN, chess.BLACK)
|
| 15 |
+
BKNIGHT = chess.Piece(chess.KNIGHT, chess.BLACK)
|
| 16 |
+
BBISHOP = chess.Piece(chess.BISHOP, chess.BLACK)
|
| 17 |
+
BROOK = chess.Piece(chess.ROOK, chess.BLACK)
|
| 18 |
+
BQUEEN = chess.Piece(chess.QUEEN, chess.BLACK)
|
| 19 |
+
BKING = chess.Piece(chess.KING, chess.BLACK)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def assign_piece2(planes, piece_step, row, col):
|
| 23 |
+
planes[piece_step][row][col] = 1
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
DISPATCH2 = {}
|
| 27 |
+
|
| 28 |
+
DISPATCH2[str(WPAWN)] = lambda retval, row, col: assign_piece2(retval, 0, row, col)
|
| 29 |
+
DISPATCH2[str(WKNIGHT)] = lambda retval, row, col: assign_piece2(retval, 1, row, col)
|
| 30 |
+
DISPATCH2[str(WBISHOP)] = lambda retval, row, col: assign_piece2(retval, 2, row, col)
|
| 31 |
+
DISPATCH2[str(WROOK)] = lambda retval, row, col: assign_piece2(retval, 3, row, col)
|
| 32 |
+
DISPATCH2[str(WQUEEN)] = lambda retval, row, col: assign_piece2(retval, 4, row, col)
|
| 33 |
+
DISPATCH2[str(WKING)] = lambda retval, row, col: assign_piece2(retval, 5, row, col)
|
| 34 |
+
DISPATCH2[str(BPAWN)] = lambda retval, row, col: assign_piece2(retval, 6, row, col)
|
| 35 |
+
DISPATCH2[str(BKNIGHT)] = lambda retval, row, col: assign_piece2(retval, 7, row, col)
|
| 36 |
+
DISPATCH2[str(BBISHOP)] = lambda retval, row, col: assign_piece2(retval, 8, row, col)
|
| 37 |
+
DISPATCH2[str(BROOK)] = lambda retval, row, col: assign_piece2(retval, 9, row, col)
|
| 38 |
+
DISPATCH2[str(BQUEEN)] = lambda retval, row, col: assign_piece2(retval, 10, row, col)
|
| 39 |
+
DISPATCH2[str(BKING)] = lambda retval, row, col: assign_piece2(retval, 11, row, col)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def append_plane(planes, ones):
|
| 43 |
+
if ones:
|
| 44 |
+
return np.append(planes, np.ones((1, 8, 8), dtype=np.float), axis=0)
|
| 45 |
+
else:
|
| 46 |
+
return np.append(planes, np.zeros((1, 8, 8), dtype=np.float), axis=0)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def fill_planes(board):
|
| 50 |
+
planes = np.zeros((12, 8, 8), dtype=np.float)
|
| 51 |
+
for row in range(8):
|
| 52 |
+
for col in range(8):
|
| 53 |
+
piece = str(board.piece_at(chess.SQUARES[row * 8 + col]))
|
| 54 |
+
if piece != "None":
|
| 55 |
+
DISPATCH2[piece](planes, row, col)
|
| 56 |
+
planes = append_plane(planes, board.is_repetition(2))
|
| 57 |
+
return planes
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def board2planes(board_):
|
| 61 |
+
if not board_.turn:
|
| 62 |
+
board = board_.mirror()
|
| 63 |
+
else:
|
| 64 |
+
board = board_
|
| 65 |
+
|
| 66 |
+
retval = fill_planes(board)
|
| 67 |
+
|
| 68 |
+
s_board = board_.copy()
|
| 69 |
+
for i in range(7):
|
| 70 |
+
if s_board.move_stack.__len__() > 0:
|
| 71 |
+
s_board.pop()
|
| 72 |
+
b = s_board.mirror() if not board_.turn else s_board.copy()
|
| 73 |
+
retval = np.append(retval, fill_planes(b), axis=0)
|
| 74 |
+
else:
|
| 75 |
+
retval = np.append(retval, np.zeros((13, 8, 8), dtype=np.float), axis=0)
|
| 76 |
+
|
| 77 |
+
retval = append_plane(retval, bool(board.castling_rights & chess.BB_H1))
|
| 78 |
+
retval = append_plane(retval, bool(board.castling_rights & chess.BB_A1))
|
| 79 |
+
retval = append_plane(retval, bool(board.castling_rights & chess.BB_H8))
|
| 80 |
+
retval = append_plane(retval, bool(board.castling_rights & chess.BB_A8))
|
| 81 |
+
retval = append_plane(retval, not board_.turn)
|
| 82 |
+
retval = np.append(retval, np.full((1, 8, 8), fill_value=board_.halfmove_clock/99., dtype=np.float), axis=0)
|
| 83 |
+
retval = append_plane(retval, False)
|
| 84 |
+
retval = append_plane(retval, True)
|
| 85 |
+
|
| 86 |
+
return np.expand_dims(retval, axis=0)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def bulk_board2planes(boards):
|
| 90 |
+
planes = []
|
| 91 |
+
for b in boards:
|
| 92 |
+
temp = board2planes(b)
|
| 93 |
+
planes.append(temp)
|
| 94 |
+
pl = tuple(planes)
|
| 95 |
+
retval = np.concatenate(pl, axis=0)
|
| 96 |
+
return retval
|
| 97 |
+
|
constants.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def root_directory():
|
| 6 |
+
if getattr(sys, 'frozen', False):
|
| 7 |
+
# The application is frozen
|
| 8 |
+
root = os.path.dirname(sys.executable)
|
| 9 |
+
else:
|
| 10 |
+
# The application is not frozen
|
| 11 |
+
root = os.path.dirname(__file__) # os.path.dirname(os.path.abspath(__file__))#os.path.dirname(__file__)
|
| 12 |
+
return (root)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
ROOT_DIR = root_directory()
|
| 16 |
+
|
| 17 |
+
LEFT_PANE_WIDTH = 90
|
| 18 |
+
RIGHT_PANE_WIDTH = 100 - LEFT_PANE_WIDTH
|
| 19 |
+
GRAPH_PANE_HEIGHT = 100
|
| 20 |
+
HEADER_HEIGHT = 11
|
| 21 |
+
CONTENT_HEIGHT = 100 - HEADER_HEIGHT
|
| 22 |
+
|
| 23 |
+
EXPORT_FORMAT = 'png' #one of png, svg, jpeg, webp
|
| 24 |
+
EXPORT_SCALE = 1.0 #When 1.0, the figure is exported as same size as currently in the browser. Use e.g. 0.5 to scale to half.
|
datasets/test_set.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
global_data.py
ADDED
|
@@ -0,0 +1,539 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import chess.engine
|
| 2 |
+
from constants import ROOT_DIR, CONTENT_HEIGHT, LEFT_PANE_WIDTH, EXPORT_FORMAT, EXPORT_SCALE
|
| 3 |
+
from time import sleep
|
| 4 |
+
# from test_array import activations_array
|
| 5 |
+
|
| 6 |
+
from copy import deepcopy
|
| 7 |
+
|
| 8 |
+
from board2planes import board2planes
|
| 9 |
+
|
| 10 |
+
import yaml
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
from os.path import isdir, join
|
| 14 |
+
import sys
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
SIMULATE_TF = False #TODO: Remove this option, deprecated
|
| 20 |
+
# turn off tensorflow importing and generate random data to speed up development
|
| 21 |
+
DEV_MODE = False
|
| 22 |
+
SIMULATED_LAYERS = 6
|
| 23 |
+
SIMULATED_HEADS = 64
|
| 24 |
+
FIXED_ROW = None # 1 #None to disable
|
| 25 |
+
FIXED_COL = None # 5 #None to disable
|
| 26 |
+
if DEV_MODE:
|
| 27 |
+
class DummyModel:
|
| 28 |
+
def __init__(self, layers, heads):
|
| 29 |
+
self.layers = layers
|
| 30 |
+
self.heads = heads
|
| 31 |
+
|
| 32 |
+
def __call__(self, *args, **kwargs):
|
| 33 |
+
data = [np.random.rand(1, self.heads, 64, 64) for i in range(self.layers)]
|
| 34 |
+
return [None, None, None, data]
|
| 35 |
+
|
| 36 |
+
else:
|
| 37 |
+
import tensorflow as tf
|
| 38 |
+
from tensorflow.compat.v1 import ConfigProto
|
| 39 |
+
from tensorflow.compat.v1 import InteractiveSession
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# class to hold data, state and configurations
|
| 43 |
+
# Dash is stateless and in general it is very bad idea to store data in global variables on server side
|
| 44 |
+
# However, this application is ment to be run by single user on local machine, so it is safe to store data and state
|
| 45 |
+
# information on global object
|
| 46 |
+
class GlobalData:
|
| 47 |
+
def __init__(self):
|
| 48 |
+
import os
|
| 49 |
+
if not DEV_MODE:
|
| 50 |
+
# import tensorflow as tf
|
| 51 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
| 52 |
+
|
| 53 |
+
# from tensorflow.compat.v1 import ConfigProto
|
| 54 |
+
# from tensorflow.compat.v1 import InteractiveSession
|
| 55 |
+
# import chess
|
| 56 |
+
# import matplotlib.patheffects as path_effects
|
| 57 |
+
|
| 58 |
+
#config = ConfigProto()
|
| 59 |
+
#config.gpu_options.allow_growth = True
|
| 60 |
+
#session = InteractiveSession(config=config)
|
| 61 |
+
#tf.keras.backend.clear_session()
|
| 62 |
+
|
| 63 |
+
self.tmp = 0
|
| 64 |
+
self.export_format = EXPORT_FORMAT
|
| 65 |
+
self.export_scale = EXPORT_SCALE
|
| 66 |
+
self.fen = 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1' # '2kr3r/ppp2b2/2n4p/4p3/Q2Pq1pP/2P1N3/PP3PP1/R1B1KB1R w KQ - 3 18'#'6n1/1p1k4/3p4/pNp5/P1P4p/7P/1P4KP/r7 w - - 2 121'#
|
| 67 |
+
self.board = chess.Board(fen=self.fen)
|
| 68 |
+
self.focused_square_ind = 0
|
| 69 |
+
self.active_move_table_cell = None # tuple (row_ind, col_id), e.g. (12, 'White')
|
| 70 |
+
|
| 71 |
+
self.activations = None # activations_array
|
| 72 |
+
self.visualization_mode = 'ROW'
|
| 73 |
+
self.visualization_mode_is_64x64 = False
|
| 74 |
+
self.subplot_mode = 'big' #'fit' # big'#'fit'#, 'big'
|
| 75 |
+
self.subplot_cols = 0
|
| 76 |
+
self.subplot_rows = 0
|
| 77 |
+
self.number_of_heads = 0
|
| 78 |
+
self.selected_head = None
|
| 79 |
+
self.show_all_heads = True
|
| 80 |
+
|
| 81 |
+
self.show_colorscale = False
|
| 82 |
+
self.colorscale_mode = 'mode1'
|
| 83 |
+
|
| 84 |
+
self.figure_container_height = '100%' # '100%'
|
| 85 |
+
|
| 86 |
+
self.running_counter = 0 # used to pass new values to hidden indicator elements which will trigger follow-up callback
|
| 87 |
+
self.grid_has_changed = False
|
| 88 |
+
|
| 89 |
+
# self.has_subplot_grid_changed = True
|
| 90 |
+
# self.figure_layout_images = None #store layout and only recalculate when subplot grid has changed
|
| 91 |
+
# self.figure_layout_annotations = None
|
| 92 |
+
# self.need_update_axis = True
|
| 93 |
+
|
| 94 |
+
self.screen_w = 0
|
| 95 |
+
self.screen_h = 0
|
| 96 |
+
self.figure_w = 0
|
| 97 |
+
self.figure_h = 0
|
| 98 |
+
self.heatmap_w = 0
|
| 99 |
+
self.heatmap_h = 0
|
| 100 |
+
self.heatmap_fig_w = 0
|
| 101 |
+
self.heatmap_fig_h = 0
|
| 102 |
+
self.heatmap_gap = 0
|
| 103 |
+
self.colorscale_x_offset = 0
|
| 104 |
+
|
| 105 |
+
self.heatmap_horizontal_gap = 0.275
|
| 106 |
+
|
| 107 |
+
self.figure_cache = {}
|
| 108 |
+
|
| 109 |
+
self.update_grid_shape()
|
| 110 |
+
|
| 111 |
+
self.pgn_data = [] # list of boards in pgn
|
| 112 |
+
self.move_table_boards = {} # dict of boards in pgn, key is (move_table.row_ind, move_table.column_id)
|
| 113 |
+
|
| 114 |
+
if not SIMULATE_TF:
|
| 115 |
+
self.selected_layer = None
|
| 116 |
+
else:
|
| 117 |
+
self.selected_layer = 0
|
| 118 |
+
|
| 119 |
+
self.nr_of_layers_in_body = -1
|
| 120 |
+
self.has_attention_policy = False
|
| 121 |
+
|
| 122 |
+
self.model_paths = []
|
| 123 |
+
self.model_names = []
|
| 124 |
+
self.model_yamls = {} #key = model path, value = yaml of that model
|
| 125 |
+
self.model_cache = {}
|
| 126 |
+
self.find_models2()
|
| 127 |
+
self.model_path = None#self.model_paths[0] # '/home/jusufe/PycharmProjects/lc0-attention-visualizer/T12_saved_model_1M'
|
| 128 |
+
self.model = None
|
| 129 |
+
self.tfp = None #TensorflowProcess
|
| 130 |
+
if not SIMULATE_TF:
|
| 131 |
+
self.load_model()
|
| 132 |
+
self.activations_data = None
|
| 133 |
+
|
| 134 |
+
if self.model is not None or SIMULATE_TF:
|
| 135 |
+
self.update_activations_data()
|
| 136 |
+
|
| 137 |
+
if self.selected_layer is not None:
|
| 138 |
+
self.set_layer(self.selected_layer)
|
| 139 |
+
|
| 140 |
+
self.move_table_active_cell = None
|
| 141 |
+
|
| 142 |
+
self.force_update_graph = False
|
| 143 |
+
|
| 144 |
+
def set_subplot_mode(self, fit_to_page):
|
| 145 |
+
if fit_to_page == [True]:
|
| 146 |
+
self.subplot_mode = 'fit'
|
| 147 |
+
else:
|
| 148 |
+
self.subplot_mode = 'big'
|
| 149 |
+
self.update_grid_shape()
|
| 150 |
+
|
| 151 |
+
def set_screen_size(self, w, h):
|
| 152 |
+
self.screen_w = w
|
| 153 |
+
self.screen_h = h
|
| 154 |
+
|
| 155 |
+
self.figure_w = w*LEFT_PANE_WIDTH/100
|
| 156 |
+
self.figure_h = h*CONTENT_HEIGHT/100
|
| 157 |
+
print('GRAPH AREA', self.figure_w, self.figure_h)
|
| 158 |
+
|
| 159 |
+
def set_heatmap_size(self, size):
|
| 160 |
+
if size != '1':
|
| 161 |
+
#print('-----------------------HEATMAP SIZE', size)
|
| 162 |
+
# w, h = size
|
| 163 |
+
# print('TYETETETETEU', global_data.screen_w)
|
| 164 |
+
# global_data.set_screen_size(w, h)
|
| 165 |
+
#print('>>>>>: HEATMAP WIDTH', size[0])
|
| 166 |
+
#print('>>>>>: HEATMAP HEIGHT', size[1])
|
| 167 |
+
#print('>>>>>: FIG WIDTH', size[2])
|
| 168 |
+
#print('>>>>>: FIG HEIGHT', size[3])
|
| 169 |
+
#print('>>>>>: HEATMAP GAP', size[4])
|
| 170 |
+
|
| 171 |
+
self.heatmap_w = float(size[0])
|
| 172 |
+
self.heatmap_h = float(size[1])
|
| 173 |
+
self.heatmap_fig_w = float(size[2])
|
| 174 |
+
self.heatmap_fig_h = float(size[3])
|
| 175 |
+
self.heatmap_gap = round(float(size[4]), 2)
|
| 176 |
+
|
| 177 |
+
self.colorscale_x_offset = float(size[5])/self.heatmap_fig_w
|
| 178 |
+
|
| 179 |
+
if size[6] == 1:
|
| 180 |
+
self.force_update_graph = True
|
| 181 |
+
else:
|
| 182 |
+
self.force_update_graph = False
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
#if self.heatmap_gap < 30:
|
| 186 |
+
# self.heatmap_horizontal_gap += 0.025
|
| 187 |
+
|
| 188 |
+
# self.heatmap_horizontal_gap = min(0.25, self.heatmap_horizontal_gap)
|
| 189 |
+
#if self.heatmap_gap < 200:
|
| 190 |
+
# self.heatmap_horizontal_gap += -0.025
|
| 191 |
+
# self.heatmap_horizontal_gap = max(0.1, self.heatmap_horizontal_gap)
|
| 192 |
+
|
| 193 |
+
def set_colorscale_mode(self, mode, colorscale_mode, colorscale_mode_64x64, show):
|
| 194 |
+
if mode == '64x64':
|
| 195 |
+
self.colorscale_mode = colorscale_mode_64x64
|
| 196 |
+
else:
|
| 197 |
+
self.colorscale_mode = colorscale_mode
|
| 198 |
+
#print('SHOW value', show)
|
| 199 |
+
self.show_colorscale = show == [True]
|
| 200 |
+
|
| 201 |
+
def cache_figure(self, fig):
|
| 202 |
+
if not self.check_if_figure_is_cached() and fig != {}:
|
| 203 |
+
key = self.get_figure_cache_key()
|
| 204 |
+
cached_fig = deepcopy(fig)
|
| 205 |
+
cached_fig.update_layout({'coloraxis1': None}, overwrite=True)
|
| 206 |
+
#print('CACHING FIGURE:')
|
| 207 |
+
self.figure_cache[key] = cached_fig
|
| 208 |
+
|
| 209 |
+
def get_cached_figure(self):
|
| 210 |
+
if self.check_if_figure_is_cached():
|
| 211 |
+
key = self.get_figure_cache_key()
|
| 212 |
+
fig = deepcopy(self.figure_cache[key])
|
| 213 |
+
else:
|
| 214 |
+
fig = None
|
| 215 |
+
return fig
|
| 216 |
+
|
| 217 |
+
def check_if_figure_is_cached(self):
|
| 218 |
+
key = self.get_figure_cache_key()
|
| 219 |
+
return key in self.figure_cache
|
| 220 |
+
|
| 221 |
+
def get_figure_cache_key(self):
|
| 222 |
+
return (self.subplot_rows, self.subplot_cols, self.visualization_mode_is_64x64,
|
| 223 |
+
self.selected_head if not self.show_all_heads else -1, self.show_colorscale, self.colorscale_mode,
|
| 224 |
+
self.board.turn)
|
| 225 |
+
#return (self.subplot_rows, self.subplot_cols, self.visualization_mode_is_64x64, self.selected_head if not self.show_all_heads else -1, self.heatmap_horizontal_gap, self.heatmap_fig_h, self.heatmap_fig_w)
|
| 226 |
+
#return (self.subplot_rows, self.subplot_cols, self.visualization_mode_is_64x64, self.show_all_heads)
|
| 227 |
+
|
| 228 |
+
def get_side_to_move(self):
|
| 229 |
+
return ['Black', 'White'][self.board.turn]
|
| 230 |
+
|
| 231 |
+
def load_model(self):
|
| 232 |
+
if self.model_path in self.model_cache:
|
| 233 |
+
self.model, self.tfp = self.model_cache[self.model_path]
|
| 234 |
+
|
| 235 |
+
elif self.model_path is not None:
|
| 236 |
+
#net = '/home/jusufe/Projects/lc0/BT1024-3142c-swa-186000.pb.gz'
|
| 237 |
+
#yaml_path = '/home/jusufe/Downloads/cfg.yaml'
|
| 238 |
+
if not DEV_MODE:
|
| 239 |
+
net = self.model_path
|
| 240 |
+
yaml_path = self.model_yamls[self.model_path]
|
| 241 |
+
with open(yaml_path) as f:
|
| 242 |
+
cfg = f.read()
|
| 243 |
+
cfg = yaml.safe_load(cfg)
|
| 244 |
+
|
| 245 |
+
if 'dropout_rate' in cfg['model']:
|
| 246 |
+
print('Setting dropout_rate to 0.0')
|
| 247 |
+
cfg['model']['dropout_rate'] = 0.0
|
| 248 |
+
|
| 249 |
+
tfp = tfprocess.TFProcess(cfg)
|
| 250 |
+
tfp.init_net()
|
| 251 |
+
tfp.replace_weights(net, ignore_errors=True)
|
| 252 |
+
self.model = tfp.model
|
| 253 |
+
self.tfp = tfp
|
| 254 |
+
else:
|
| 255 |
+
self.model = DummyModel(SIMULATED_LAYERS, SIMULATED_HEADS)
|
| 256 |
+
self.tfp = None
|
| 257 |
+
|
| 258 |
+
else:
|
| 259 |
+
self.model = None
|
| 260 |
+
self.tfp = None
|
| 261 |
+
|
| 262 |
+
def find_models(self):
|
| 263 |
+
root = ROOT_DIR
|
| 264 |
+
models_root_folder = os.path.join(root, 'models')
|
| 265 |
+
model_folders = [f for f in os.listdir(models_root_folder) if isdir(join(models_root_folder, f))]
|
| 266 |
+
model_paths = [os.path.relpath(join(models_root_folder, f)) for f in os.listdir(models_root_folder) if
|
| 267 |
+
isdir(join(models_root_folder, f))]
|
| 268 |
+
self.model_names = model_folders
|
| 269 |
+
self.model_paths = model_paths
|
| 270 |
+
|
| 271 |
+
#print('MODELS:')
|
| 272 |
+
#print(self.model_names)
|
| 273 |
+
#print(self.model_paths)
|
| 274 |
+
|
| 275 |
+
def find_models2(self):
|
| 276 |
+
import os
|
| 277 |
+
from os.path import isdir, join
|
| 278 |
+
root = ROOT_DIR
|
| 279 |
+
models_root_folder = os.path.join(root, 'models')
|
| 280 |
+
model_folders = [f for f in os.listdir(models_root_folder) if isdir(join(models_root_folder, f))]
|
| 281 |
+
model_paths = [os.path.relpath(join(models_root_folder, f)) for f in os.listdir(models_root_folder) if
|
| 282 |
+
isdir(join(models_root_folder, f))]
|
| 283 |
+
|
| 284 |
+
models = []
|
| 285 |
+
paths = []
|
| 286 |
+
yamls = []
|
| 287 |
+
for path in model_paths:
|
| 288 |
+
yaml_files = [file for file in os.listdir(path) if file.endswith(".yaml")]
|
| 289 |
+
if len(yaml_files) != 1:
|
| 290 |
+
continue
|
| 291 |
+
model_files = [file for file in os.listdir(path) if file.endswith(".pb.gz")]
|
| 292 |
+
if len(model_files) == 0:
|
| 293 |
+
continue
|
| 294 |
+
|
| 295 |
+
models += model_files
|
| 296 |
+
paths += [os.path.relpath(join(path, f)) for f in model_files]
|
| 297 |
+
yaml_file = os.path.relpath(join(path, yaml_files[0]))
|
| 298 |
+
yamls += [yaml_file]*len(model_files)
|
| 299 |
+
|
| 300 |
+
self.model_yamls = {path: yaml_file for path, yaml_file in zip(paths, yamls)}
|
| 301 |
+
self.model_names = models
|
| 302 |
+
self.model_paths = paths#model_paths
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def update_activations_data(self):
|
| 306 |
+
|
| 307 |
+
if self.model is not None and self.selected_layer is None:
|
| 308 |
+
self.selected_layer = 0
|
| 309 |
+
|
| 310 |
+
if not SIMULATE_TF:
|
| 311 |
+
if self.selected_layer is not None and self.model is not None and self.selected_layer != 'Smolgen':
|
| 312 |
+
if not DEV_MODE:
|
| 313 |
+
inputs = board2planes(self.board)
|
| 314 |
+
inputs = tf.reshape(tf.convert_to_tensor(inputs, dtype=tf.float32), [-1, 112, 8, 8])
|
| 315 |
+
else:
|
| 316 |
+
inputs = None
|
| 317 |
+
|
| 318 |
+
outputs = self.model(inputs)
|
| 319 |
+
self.activations_data = outputs[-1]
|
| 320 |
+
for i,x in enumerate(self.activations_data):
|
| 321 |
+
print( 'LAYERS', i, x.shape)
|
| 322 |
+
|
| 323 |
+
#smolgen = self.tfp.smol_weight_gen_dense.get_weights()[0].reshape((256, 64, 64))
|
| 324 |
+
#print('Smolgen')
|
| 325 |
+
#print(type(smolgen))
|
| 326 |
+
#print(smolgen.shape)
|
| 327 |
+
#print(type(smolgen[0]))
|
| 328 |
+
#print(smolgen[0].shape)
|
| 329 |
+
#_, _, _, self.activations_data = self.model(inputs)
|
| 330 |
+
elif self.selected_layer == 'Smolgen' and self.tfp is not None and self.tfp.use_smolgen:
|
| 331 |
+
weights = self.tfp.smol_weight_gen_dense.get_weights()[0]
|
| 332 |
+
self.activations_data = weights.reshape((weights.shape[0], 64, 64))
|
| 333 |
+
print('TYPEEEEE', type(self.activations_data))
|
| 334 |
+
|
| 335 |
+
else:
|
| 336 |
+
layers = SIMULATED_LAYERS
|
| 337 |
+
heads = SIMULATED_HEADS
|
| 338 |
+
self.activations_data = [np.random.rand(1, heads, 64, 64) for i in range(layers)]
|
| 339 |
+
|
| 340 |
+
if self.model is not None:
|
| 341 |
+
|
| 342 |
+
if self.model_path not in self.model_cache:
|
| 343 |
+
self.model_cache[self.model_path] = [self.model, self.tfp]
|
| 344 |
+
|
| 345 |
+
self.update_layers_in_body_count()
|
| 346 |
+
|
| 347 |
+
#TODO: figure out better way to determine if we have policy attention weights
|
| 348 |
+
#TODO: What happens if policy vis is selected and user switches to model without policy layer? Take care of this case.
|
| 349 |
+
if self.activations_data is not None and self.activations_data[-2].shape == (1, 8, 24):
|
| 350 |
+
self.has_attention_policy = True
|
| 351 |
+
else:
|
| 352 |
+
self.has_attention_policy = False
|
| 353 |
+
# self.update_selected_activation_data()
|
| 354 |
+
# self.activations = self.activations_data[self.selected_layer]
|
| 355 |
+
|
| 356 |
+
def update_grid_shape(self):
|
| 357 |
+
# TODO: add client side callback triggered by Interval component to save window or precise container dimensions to Div
|
| 358 |
+
# TODO: Trigger server side figure update callback when dimensions are recorded and store in global_data
|
| 359 |
+
# TODO: If needed, recalculate subplot rows and cols and container scaler based on the changed dimension
|
| 360 |
+
|
| 361 |
+
def calc_cols(heads, rows):
|
| 362 |
+
if heads % rows == 0:
|
| 363 |
+
cols = int(heads / rows)
|
| 364 |
+
else:
|
| 365 |
+
cols = int(1 + heads / rows)
|
| 366 |
+
return cols
|
| 367 |
+
|
| 368 |
+
if FIXED_ROW and FIXED_COL:
|
| 369 |
+
self.subplot_cols = FIXED_COL
|
| 370 |
+
self.subplot_rows = FIXED_ROW
|
| 371 |
+
return None
|
| 372 |
+
|
| 373 |
+
heads = self.number_of_heads
|
| 374 |
+
if self.subplot_mode == 'fit':
|
| 375 |
+
max_rows_in_screen = 4
|
| 376 |
+
if heads <= 4:
|
| 377 |
+
rows = 1
|
| 378 |
+
elif heads <= 8:
|
| 379 |
+
rows = 2
|
| 380 |
+
else:
|
| 381 |
+
rows = heads // 8 + int(heads % 8 != 0)
|
| 382 |
+
|
| 383 |
+
elif self.subplot_mode == 'big':
|
| 384 |
+
#print(heads)
|
| 385 |
+
|
| 386 |
+
max_rows_in_screen = 2
|
| 387 |
+
rows = heads // 4 + int(heads % 4 != 0)
|
| 388 |
+
#print(rows)
|
| 389 |
+
|
| 390 |
+
if rows > max_rows_in_screen:
|
| 391 |
+
container_height = f'{int((rows / max_rows_in_screen) * 100)}%'
|
| 392 |
+
else:
|
| 393 |
+
container_height = '100%'
|
| 394 |
+
|
| 395 |
+
if rows != 0:
|
| 396 |
+
cols = calc_cols(heads, rows)
|
| 397 |
+
else:
|
| 398 |
+
cols = 0
|
| 399 |
+
|
| 400 |
+
if self.subplot_rows != rows or self.subplot_cols != cols:
|
| 401 |
+
self.grid_has_changed = True
|
| 402 |
+
|
| 403 |
+
self.subplot_cols = cols
|
| 404 |
+
self.subplot_rows = rows
|
| 405 |
+
|
| 406 |
+
if self.show_all_heads:
|
| 407 |
+
self.figure_container_height = container_height
|
| 408 |
+
else:
|
| 409 |
+
self.figure_container_height = '100%'
|
| 410 |
+
|
| 411 |
+
def update_selected_activation_data(self):
|
| 412 |
+
# import numpy as np
|
| 413 |
+
# self.activations = activations_array + np.random.rand(8, 64, 64)
|
| 414 |
+
if self.activations_data is not None:
|
| 415 |
+
if self.selected_layer not in ('Policy', 'Smolgen'):
|
| 416 |
+
if not DEV_MODE:
|
| 417 |
+
activations = tf.squeeze(self.activations_data[self.selected_layer], axis=0).numpy()
|
| 418 |
+
#self.activations = activations[:, ::-1, :] #Flip along y-axis
|
| 419 |
+
else:
|
| 420 |
+
activations = np.squeeze(self.activations_data[self.selected_layer], axis=0)
|
| 421 |
+
elif self.selected_layer == 'Policy':
|
| 422 |
+
print('RAW POLICY SHAPE', self.activations_data[-1].shape)
|
| 423 |
+
activations = self.activations_data[-1].numpy()
|
| 424 |
+
#print('POLICY SHAPE', activations.shape)
|
| 425 |
+
|
| 426 |
+
#print('RAW POLICY SHAPE', self.activations_data[-1].shape)
|
| 427 |
+
#activations = np.squeeze(self.activations_data[-1].numpy(), axis=0) #shape 64,64
|
| 428 |
+
#promo = np.squeeze(self.activations_data[-2].numpy(), axis=0) #shape 8,24
|
| 429 |
+
#print('promo shape:', promo.shape)
|
| 430 |
+
#if self.board.turn:
|
| 431 |
+
# pad_shape = (48, 8)
|
| 432 |
+
#else:
|
| 433 |
+
# pad_shape = (8, 48)
|
| 434 |
+
#promo_padded = np.pad(promo, (pad_shape, (0, 0)), mode='constant', constant_values=None) #shape 64,24
|
| 435 |
+
#self.activations = np.expand_dims(np.concatenate((activations, promo_padded), axis=1), axis=0)#shape 1,64,88
|
| 436 |
+
#print('POLICY SHAPE', self.activations.shape)
|
| 437 |
+
elif self.selected_layer == 'Smolgen':
|
| 438 |
+
activations = self.tfp.smol_weight_gen_dense.get_weights()[0].reshape((256, 64, 64))
|
| 439 |
+
|
| 440 |
+
self.activations = activations[:, ::-1, :] # Flip along y-axis
|
| 441 |
+
|
| 442 |
+
def set_visualization_mode(self, mode):
|
| 443 |
+
self.visualization_mode = mode
|
| 444 |
+
self.visualization_mode_is_64x64 = mode == '64x64'
|
| 445 |
+
|
| 446 |
+
def set_layer(self, layer):
|
| 447 |
+
self.selected_layer = layer
|
| 448 |
+
self.update_selected_activation_data()
|
| 449 |
+
if layer not in ('Policy', 'Smolgen'):
|
| 450 |
+
self.number_of_heads = self.activations_data[self.selected_layer].shape[1]
|
| 451 |
+
elif layer == 'Policy':
|
| 452 |
+
self.number_of_heads = 1
|
| 453 |
+
elif layer == 'Smolgen':
|
| 454 |
+
self.number_of_heads = self.activations.shape[0]
|
| 455 |
+
self.set_head(0)
|
| 456 |
+
self.update_grid_shape()
|
| 457 |
+
|
| 458 |
+
def set_head(self, head):
|
| 459 |
+
self.selected_head = head
|
| 460 |
+
|
| 461 |
+
def set_model(self, model):
|
| 462 |
+
if model != self.model_path:
|
| 463 |
+
self.model_path = model
|
| 464 |
+
self.load_model()
|
| 465 |
+
self.update_activations_data()
|
| 466 |
+
self.update_selected_activation_data()
|
| 467 |
+
self.number_of_heads = self.activations_data[self.selected_layer].shape[1]
|
| 468 |
+
if self.selected_head is None:
|
| 469 |
+
self.selected_head = 0
|
| 470 |
+
else:
|
| 471 |
+
self.selected_head = min(self.selected_head, self.number_of_heads - 1)
|
| 472 |
+
self.update_grid_shape()
|
| 473 |
+
if SIMULATE_TF:
|
| 474 |
+
sleep(2)
|
| 475 |
+
|
| 476 |
+
def update_layers_in_body_count(self):
|
| 477 |
+
# TODO: figure out robust way to separate attention layers in body from the rest. UPDATE: Use yaml
|
| 478 |
+
heads = self.activations_data[0].shape[1]
|
| 479 |
+
for ind, layer in enumerate(self.activations_data):
|
| 480 |
+
if layer.shape[1] != heads or len(layer.shape) != 4:
|
| 481 |
+
ind = ind - 1
|
| 482 |
+
break
|
| 483 |
+
self.nr_of_layers_in_body = ind + 1
|
| 484 |
+
if self.selected_layer not in ('Policy', 'Smolgen'):
|
| 485 |
+
self.selected_layer = min(self.selected_layer, self.nr_of_layers_in_body - 1)
|
| 486 |
+
|
| 487 |
+
def get_head_data(self, head):
|
| 488 |
+
|
| 489 |
+
if self.activations.shape[0] <= head:
|
| 490 |
+
return None
|
| 491 |
+
|
| 492 |
+
if self.visualization_mode == '64x64':
|
| 493 |
+
# print('64x64 selection')
|
| 494 |
+
data = self.activations[head, :, :]
|
| 495 |
+
|
| 496 |
+
elif self.visualization_mode == 'ROW':
|
| 497 |
+
# print('ROW selection')
|
| 498 |
+
if self.board.turn or self.selected_layer == 'Smolgen': #White turn to move
|
| 499 |
+
row = 63 - self.focused_square_ind
|
| 500 |
+
data = self.activations[head, row, :].reshape((8, 8))
|
| 501 |
+
else:
|
| 502 |
+
#row = self.focused_square_ind
|
| 503 |
+
multiples = self.focused_square_ind // 8
|
| 504 |
+
remainder = self.focused_square_ind % 8
|
| 505 |
+
|
| 506 |
+
a = 7 - remainder
|
| 507 |
+
b = multiples * 8
|
| 508 |
+
row = a + b
|
| 509 |
+
data = self.activations[head, row, :].reshape((8, 8))[::-1, :]
|
| 510 |
+
else:
|
| 511 |
+
# print('COL selection')
|
| 512 |
+
if self.board.turn or self.selected_layer == 'Smolgen': #White turn to move
|
| 513 |
+
col = self.focused_square_ind
|
| 514 |
+
data = self.activations[head, :, col].reshape((8, 8))[::-1, ::-1]
|
| 515 |
+
else:
|
| 516 |
+
focused = 63 - self.focused_square_ind
|
| 517 |
+
multiples = focused // 8
|
| 518 |
+
remainder = focused % 8
|
| 519 |
+
a = 7 - remainder
|
| 520 |
+
b = multiples * 8
|
| 521 |
+
col = a + b
|
| 522 |
+
#print('COL!!!!!!!!!!!!!!!!!', col, a, b, focused, self.focused_square_ind)
|
| 523 |
+
data = self.activations[head, :, col].reshape((8, 8))[:, ::-1]
|
| 524 |
+
return data
|
| 525 |
+
|
| 526 |
+
def set_fen(self, fen):
|
| 527 |
+
self.board.set_fen(fen)
|
| 528 |
+
self.fen = fen
|
| 529 |
+
self.update_activations_data()
|
| 530 |
+
self.update_selected_activation_data()
|
| 531 |
+
|
| 532 |
+
def set_board(self, board):
|
| 533 |
+
self.board = deepcopy(board)
|
| 534 |
+
self.update_activations_data()
|
| 535 |
+
self.update_selected_activation_data()
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
global_data = GlobalData()
|
| 539 |
+
print('global data created')
|
layouts/visualization_demo.slides.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"type": "slides",
|
| 3 |
+
"data": {}
|
| 4 |
+
}
|
leela_board.py
ADDED
|
@@ -0,0 +1,617 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
# uci_to_idx is a list of four dicts {uci -> NN policy index}
|
| 4 |
+
# 0 = white, no-castling
|
| 5 |
+
# 1 = white, castling
|
| 6 |
+
# 2 = black, no-castling
|
| 7 |
+
# 3 = black, castling
|
| 8 |
+
# Black moves are flipped, and castling moves are mapped to the e8a8, e8h8, e1a1, e1h1 indexes
|
| 9 |
+
# from their respective UCI names
|
| 10 |
+
|
| 11 |
+
uci_to_idx = []
|
| 12 |
+
|
| 13 |
+
# The index-to-uci list originates from here:
|
| 14 |
+
# https://github.com/glinscott/leela-chess/blob/master/lc0/src/chess/bitboard.cc
|
| 15 |
+
|
| 16 |
+
# White, no-castling
|
| 17 |
+
_idx_to_move_wn = [
|
| 18 |
+
'a1b1', 'a1c1', 'a1d1', 'a1e1', 'a1f1', 'a1g1', 'a1h1',
|
| 19 |
+
'a1a2', 'a1b2', 'a1c2', 'a1a3', 'a1b3', 'a1c3', 'a1a4',
|
| 20 |
+
'a1d4', 'a1a5', 'a1e5', 'a1a6', 'a1f6', 'a1a7', 'a1g7',
|
| 21 |
+
'a1a8', 'a1h8', 'b1a1', 'b1c1', 'b1d1', 'b1e1', 'b1f1',
|
| 22 |
+
'b1g1', 'b1h1', 'b1a2', 'b1b2', 'b1c2', 'b1d2', 'b1a3',
|
| 23 |
+
'b1b3', 'b1c3', 'b1d3', 'b1b4', 'b1e4', 'b1b5', 'b1f5',
|
| 24 |
+
'b1b6', 'b1g6', 'b1b7', 'b1h7', 'b1b8', 'c1a1', 'c1b1',
|
| 25 |
+
'c1d1', 'c1e1', 'c1f1', 'c1g1', 'c1h1', 'c1a2', 'c1b2',
|
| 26 |
+
'c1c2', 'c1d2', 'c1e2', 'c1a3', 'c1b3', 'c1c3', 'c1d3',
|
| 27 |
+
'c1e3', 'c1c4', 'c1f4', 'c1c5', 'c1g5', 'c1c6', 'c1h6',
|
| 28 |
+
'c1c7', 'c1c8', 'd1a1', 'd1b1', 'd1c1', 'd1e1', 'd1f1',
|
| 29 |
+
'd1g1', 'd1h1', 'd1b2', 'd1c2', 'd1d2', 'd1e2', 'd1f2',
|
| 30 |
+
'd1b3', 'd1c3', 'd1d3', 'd1e3', 'd1f3', 'd1a4', 'd1d4',
|
| 31 |
+
'd1g4', 'd1d5', 'd1h5', 'd1d6', 'd1d7', 'd1d8', 'e1a1',
|
| 32 |
+
'e1b1', 'e1c1', 'e1d1', 'e1f1', 'e1g1', 'e1h1', 'e1c2',
|
| 33 |
+
'e1d2', 'e1e2', 'e1f2', 'e1g2', 'e1c3', 'e1d3', 'e1e3',
|
| 34 |
+
'e1f3', 'e1g3', 'e1b4', 'e1e4', 'e1h4', 'e1a5', 'e1e5',
|
| 35 |
+
'e1e6', 'e1e7', 'e1e8', 'f1a1', 'f1b1', 'f1c1', 'f1d1',
|
| 36 |
+
'f1e1', 'f1g1', 'f1h1', 'f1d2', 'f1e2', 'f1f2', 'f1g2',
|
| 37 |
+
'f1h2', 'f1d3', 'f1e3', 'f1f3', 'f1g3', 'f1h3', 'f1c4',
|
| 38 |
+
'f1f4', 'f1b5', 'f1f5', 'f1a6', 'f1f6', 'f1f7', 'f1f8',
|
| 39 |
+
'g1a1', 'g1b1', 'g1c1', 'g1d1', 'g1e1', 'g1f1', 'g1h1',
|
| 40 |
+
'g1e2', 'g1f2', 'g1g2', 'g1h2', 'g1e3', 'g1f3', 'g1g3',
|
| 41 |
+
'g1h3', 'g1d4', 'g1g4', 'g1c5', 'g1g5', 'g1b6', 'g1g6',
|
| 42 |
+
'g1a7', 'g1g7', 'g1g8', 'h1a1', 'h1b1', 'h1c1', 'h1d1',
|
| 43 |
+
'h1e1', 'h1f1', 'h1g1', 'h1f2', 'h1g2', 'h1h2', 'h1f3',
|
| 44 |
+
'h1g3', 'h1h3', 'h1e4', 'h1h4', 'h1d5', 'h1h5', 'h1c6',
|
| 45 |
+
'h1h6', 'h1b7', 'h1h7', 'h1a8', 'h1h8', 'a2a1', 'a2b1',
|
| 46 |
+
'a2c1', 'a2b2', 'a2c2', 'a2d2', 'a2e2', 'a2f2', 'a2g2',
|
| 47 |
+
'a2h2', 'a2a3', 'a2b3', 'a2c3', 'a2a4', 'a2b4', 'a2c4',
|
| 48 |
+
'a2a5', 'a2d5', 'a2a6', 'a2e6', 'a2a7', 'a2f7', 'a2a8',
|
| 49 |
+
'a2g8', 'b2a1', 'b2b1', 'b2c1', 'b2d1', 'b2a2', 'b2c2',
|
| 50 |
+
'b2d2', 'b2e2', 'b2f2', 'b2g2', 'b2h2', 'b2a3', 'b2b3',
|
| 51 |
+
'b2c3', 'b2d3', 'b2a4', 'b2b4', 'b2c4', 'b2d4', 'b2b5',
|
| 52 |
+
'b2e5', 'b2b6', 'b2f6', 'b2b7', 'b2g7', 'b2b8', 'b2h8',
|
| 53 |
+
'c2a1', 'c2b1', 'c2c1', 'c2d1', 'c2e1', 'c2a2', 'c2b2',
|
| 54 |
+
'c2d2', 'c2e2', 'c2f2', 'c2g2', 'c2h2', 'c2a3', 'c2b3',
|
| 55 |
+
'c2c3', 'c2d3', 'c2e3', 'c2a4', 'c2b4', 'c2c4', 'c2d4',
|
| 56 |
+
'c2e4', 'c2c5', 'c2f5', 'c2c6', 'c2g6', 'c2c7', 'c2h7',
|
| 57 |
+
'c2c8', 'd2b1', 'd2c1', 'd2d1', 'd2e1', 'd2f1', 'd2a2',
|
| 58 |
+
'd2b2', 'd2c2', 'd2e2', 'd2f2', 'd2g2', 'd2h2', 'd2b3',
|
| 59 |
+
'd2c3', 'd2d3', 'd2e3', 'd2f3', 'd2b4', 'd2c4', 'd2d4',
|
| 60 |
+
'd2e4', 'd2f4', 'd2a5', 'd2d5', 'd2g5', 'd2d6', 'd2h6',
|
| 61 |
+
'd2d7', 'd2d8', 'e2c1', 'e2d1', 'e2e1', 'e2f1', 'e2g1',
|
| 62 |
+
'e2a2', 'e2b2', 'e2c2', 'e2d2', 'e2f2', 'e2g2', 'e2h2',
|
| 63 |
+
'e2c3', 'e2d3', 'e2e3', 'e2f3', 'e2g3', 'e2c4', 'e2d4',
|
| 64 |
+
'e2e4', 'e2f4', 'e2g4', 'e2b5', 'e2e5', 'e2h5', 'e2a6',
|
| 65 |
+
'e2e6', 'e2e7', 'e2e8', 'f2d1', 'f2e1', 'f2f1', 'f2g1',
|
| 66 |
+
'f2h1', 'f2a2', 'f2b2', 'f2c2', 'f2d2', 'f2e2', 'f2g2',
|
| 67 |
+
'f2h2', 'f2d3', 'f2e3', 'f2f3', 'f2g3', 'f2h3', 'f2d4',
|
| 68 |
+
'f2e4', 'f2f4', 'f2g4', 'f2h4', 'f2c5', 'f2f5', 'f2b6',
|
| 69 |
+
'f2f6', 'f2a7', 'f2f7', 'f2f8', 'g2e1', 'g2f1', 'g2g1',
|
| 70 |
+
'g2h1', 'g2a2', 'g2b2', 'g2c2', 'g2d2', 'g2e2', 'g2f2',
|
| 71 |
+
'g2h2', 'g2e3', 'g2f3', 'g2g3', 'g2h3', 'g2e4', 'g2f4',
|
| 72 |
+
'g2g4', 'g2h4', 'g2d5', 'g2g5', 'g2c6', 'g2g6', 'g2b7',
|
| 73 |
+
'g2g7', 'g2a8', 'g2g8', 'h2f1', 'h2g1', 'h2h1', 'h2a2',
|
| 74 |
+
'h2b2', 'h2c2', 'h2d2', 'h2e2', 'h2f2', 'h2g2', 'h2f3',
|
| 75 |
+
'h2g3', 'h2h3', 'h2f4', 'h2g4', 'h2h4', 'h2e5', 'h2h5',
|
| 76 |
+
'h2d6', 'h2h6', 'h2c7', 'h2h7', 'h2b8', 'h2h8', 'a3a1',
|
| 77 |
+
'a3b1', 'a3c1', 'a3a2', 'a3b2', 'a3c2', 'a3b3', 'a3c3',
|
| 78 |
+
'a3d3', 'a3e3', 'a3f3', 'a3g3', 'a3h3', 'a3a4', 'a3b4',
|
| 79 |
+
'a3c4', 'a3a5', 'a3b5', 'a3c5', 'a3a6', 'a3d6', 'a3a7',
|
| 80 |
+
'a3e7', 'a3a8', 'a3f8', 'b3a1', 'b3b1', 'b3c1', 'b3d1',
|
| 81 |
+
'b3a2', 'b3b2', 'b3c2', 'b3d2', 'b3a3', 'b3c3', 'b3d3',
|
| 82 |
+
'b3e3', 'b3f3', 'b3g3', 'b3h3', 'b3a4', 'b3b4', 'b3c4',
|
| 83 |
+
'b3d4', 'b3a5', 'b3b5', 'b3c5', 'b3d5', 'b3b6', 'b3e6',
|
| 84 |
+
'b3b7', 'b3f7', 'b3b8', 'b3g8', 'c3a1', 'c3b1', 'c3c1',
|
| 85 |
+
'c3d1', 'c3e1', 'c3a2', 'c3b2', 'c3c2', 'c3d2', 'c3e2',
|
| 86 |
+
'c3a3', 'c3b3', 'c3d3', 'c3e3', 'c3f3', 'c3g3', 'c3h3',
|
| 87 |
+
'c3a4', 'c3b4', 'c3c4', 'c3d4', 'c3e4', 'c3a5', 'c3b5',
|
| 88 |
+
'c3c5', 'c3d5', 'c3e5', 'c3c6', 'c3f6', 'c3c7', 'c3g7',
|
| 89 |
+
'c3c8', 'c3h8', 'd3b1', 'd3c1', 'd3d1', 'd3e1', 'd3f1',
|
| 90 |
+
'd3b2', 'd3c2', 'd3d2', 'd3e2', 'd3f2', 'd3a3', 'd3b3',
|
| 91 |
+
'd3c3', 'd3e3', 'd3f3', 'd3g3', 'd3h3', 'd3b4', 'd3c4',
|
| 92 |
+
'd3d4', 'd3e4', 'd3f4', 'd3b5', 'd3c5', 'd3d5', 'd3e5',
|
| 93 |
+
'd3f5', 'd3a6', 'd3d6', 'd3g6', 'd3d7', 'd3h7', 'd3d8',
|
| 94 |
+
'e3c1', 'e3d1', 'e3e1', 'e3f1', 'e3g1', 'e3c2', 'e3d2',
|
| 95 |
+
'e3e2', 'e3f2', 'e3g2', 'e3a3', 'e3b3', 'e3c3', 'e3d3',
|
| 96 |
+
'e3f3', 'e3g3', 'e3h3', 'e3c4', 'e3d4', 'e3e4', 'e3f4',
|
| 97 |
+
'e3g4', 'e3c5', 'e3d5', 'e3e5', 'e3f5', 'e3g5', 'e3b6',
|
| 98 |
+
'e3e6', 'e3h6', 'e3a7', 'e3e7', 'e3e8', 'f3d1', 'f3e1',
|
| 99 |
+
'f3f1', 'f3g1', 'f3h1', 'f3d2', 'f3e2', 'f3f2', 'f3g2',
|
| 100 |
+
'f3h2', 'f3a3', 'f3b3', 'f3c3', 'f3d3', 'f3e3', 'f3g3',
|
| 101 |
+
'f3h3', 'f3d4', 'f3e4', 'f3f4', 'f3g4', 'f3h4', 'f3d5',
|
| 102 |
+
'f3e5', 'f3f5', 'f3g5', 'f3h5', 'f3c6', 'f3f6', 'f3b7',
|
| 103 |
+
'f3f7', 'f3a8', 'f3f8', 'g3e1', 'g3f1', 'g3g1', 'g3h1',
|
| 104 |
+
'g3e2', 'g3f2', 'g3g2', 'g3h2', 'g3a3', 'g3b3', 'g3c3',
|
| 105 |
+
'g3d3', 'g3e3', 'g3f3', 'g3h3', 'g3e4', 'g3f4', 'g3g4',
|
| 106 |
+
'g3h4', 'g3e5', 'g3f5', 'g3g5', 'g3h5', 'g3d6', 'g3g6',
|
| 107 |
+
'g3c7', 'g3g7', 'g3b8', 'g3g8', 'h3f1', 'h3g1', 'h3h1',
|
| 108 |
+
'h3f2', 'h3g2', 'h3h2', 'h3a3', 'h3b3', 'h3c3', 'h3d3',
|
| 109 |
+
'h3e3', 'h3f3', 'h3g3', 'h3f4', 'h3g4', 'h3h4', 'h3f5',
|
| 110 |
+
'h3g5', 'h3h5', 'h3e6', 'h3h6', 'h3d7', 'h3h7', 'h3c8',
|
| 111 |
+
'h3h8', 'a4a1', 'a4d1', 'a4a2', 'a4b2', 'a4c2', 'a4a3',
|
| 112 |
+
'a4b3', 'a4c3', 'a4b4', 'a4c4', 'a4d4', 'a4e4', 'a4f4',
|
| 113 |
+
'a4g4', 'a4h4', 'a4a5', 'a4b5', 'a4c5', 'a4a6', 'a4b6',
|
| 114 |
+
'a4c6', 'a4a7', 'a4d7', 'a4a8', 'a4e8', 'b4b1', 'b4e1',
|
| 115 |
+
'b4a2', 'b4b2', 'b4c2', 'b4d2', 'b4a3', 'b4b3', 'b4c3',
|
| 116 |
+
'b4d3', 'b4a4', 'b4c4', 'b4d4', 'b4e4', 'b4f4', 'b4g4',
|
| 117 |
+
'b4h4', 'b4a5', 'b4b5', 'b4c5', 'b4d5', 'b4a6', 'b4b6',
|
| 118 |
+
'b4c6', 'b4d6', 'b4b7', 'b4e7', 'b4b8', 'b4f8', 'c4c1',
|
| 119 |
+
'c4f1', 'c4a2', 'c4b2', 'c4c2', 'c4d2', 'c4e2', 'c4a3',
|
| 120 |
+
'c4b3', 'c4c3', 'c4d3', 'c4e3', 'c4a4', 'c4b4', 'c4d4',
|
| 121 |
+
'c4e4', 'c4f4', 'c4g4', 'c4h4', 'c4a5', 'c4b5', 'c4c5',
|
| 122 |
+
'c4d5', 'c4e5', 'c4a6', 'c4b6', 'c4c6', 'c4d6', 'c4e6',
|
| 123 |
+
'c4c7', 'c4f7', 'c4c8', 'c4g8', 'd4a1', 'd4d1', 'd4g1',
|
| 124 |
+
'd4b2', 'd4c2', 'd4d2', 'd4e2', 'd4f2', 'd4b3', 'd4c3',
|
| 125 |
+
'd4d3', 'd4e3', 'd4f3', 'd4a4', 'd4b4', 'd4c4', 'd4e4',
|
| 126 |
+
'd4f4', 'd4g4', 'd4h4', 'd4b5', 'd4c5', 'd4d5', 'd4e5',
|
| 127 |
+
'd4f5', 'd4b6', 'd4c6', 'd4d6', 'd4e6', 'd4f6', 'd4a7',
|
| 128 |
+
'd4d7', 'd4g7', 'd4d8', 'd4h8', 'e4b1', 'e4e1', 'e4h1',
|
| 129 |
+
'e4c2', 'e4d2', 'e4e2', 'e4f2', 'e4g2', 'e4c3', 'e4d3',
|
| 130 |
+
'e4e3', 'e4f3', 'e4g3', 'e4a4', 'e4b4', 'e4c4', 'e4d4',
|
| 131 |
+
'e4f4', 'e4g4', 'e4h4', 'e4c5', 'e4d5', 'e4e5', 'e4f5',
|
| 132 |
+
'e4g5', 'e4c6', 'e4d6', 'e4e6', 'e4f6', 'e4g6', 'e4b7',
|
| 133 |
+
'e4e7', 'e4h7', 'e4a8', 'e4e8', 'f4c1', 'f4f1', 'f4d2',
|
| 134 |
+
'f4e2', 'f4f2', 'f4g2', 'f4h2', 'f4d3', 'f4e3', 'f4f3',
|
| 135 |
+
'f4g3', 'f4h3', 'f4a4', 'f4b4', 'f4c4', 'f4d4', 'f4e4',
|
| 136 |
+
'f4g4', 'f4h4', 'f4d5', 'f4e5', 'f4f5', 'f4g5', 'f4h5',
|
| 137 |
+
'f4d6', 'f4e6', 'f4f6', 'f4g6', 'f4h6', 'f4c7', 'f4f7',
|
| 138 |
+
'f4b8', 'f4f8', 'g4d1', 'g4g1', 'g4e2', 'g4f2', 'g4g2',
|
| 139 |
+
'g4h2', 'g4e3', 'g4f3', 'g4g3', 'g4h3', 'g4a4', 'g4b4',
|
| 140 |
+
'g4c4', 'g4d4', 'g4e4', 'g4f4', 'g4h4', 'g4e5', 'g4f5',
|
| 141 |
+
'g4g5', 'g4h5', 'g4e6', 'g4f6', 'g4g6', 'g4h6', 'g4d7',
|
| 142 |
+
'g4g7', 'g4c8', 'g4g8', 'h4e1', 'h4h1', 'h4f2', 'h4g2',
|
| 143 |
+
'h4h2', 'h4f3', 'h4g3', 'h4h3', 'h4a4', 'h4b4', 'h4c4',
|
| 144 |
+
'h4d4', 'h4e4', 'h4f4', 'h4g4', 'h4f5', 'h4g5', 'h4h5',
|
| 145 |
+
'h4f6', 'h4g6', 'h4h6', 'h4e7', 'h4h7', 'h4d8', 'h4h8',
|
| 146 |
+
'a5a1', 'a5e1', 'a5a2', 'a5d2', 'a5a3', 'a5b3', 'a5c3',
|
| 147 |
+
'a5a4', 'a5b4', 'a5c4', 'a5b5', 'a5c5', 'a5d5', 'a5e5',
|
| 148 |
+
'a5f5', 'a5g5', 'a5h5', 'a5a6', 'a5b6', 'a5c6', 'a5a7',
|
| 149 |
+
'a5b7', 'a5c7', 'a5a8', 'a5d8', 'b5b1', 'b5f1', 'b5b2',
|
| 150 |
+
'b5e2', 'b5a3', 'b5b3', 'b5c3', 'b5d3', 'b5a4', 'b5b4',
|
| 151 |
+
'b5c4', 'b5d4', 'b5a5', 'b5c5', 'b5d5', 'b5e5', 'b5f5',
|
| 152 |
+
'b5g5', 'b5h5', 'b5a6', 'b5b6', 'b5c6', 'b5d6', 'b5a7',
|
| 153 |
+
'b5b7', 'b5c7', 'b5d7', 'b5b8', 'b5e8', 'c5c1', 'c5g1',
|
| 154 |
+
'c5c2', 'c5f2', 'c5a3', 'c5b3', 'c5c3', 'c5d3', 'c5e3',
|
| 155 |
+
'c5a4', 'c5b4', 'c5c4', 'c5d4', 'c5e4', 'c5a5', 'c5b5',
|
| 156 |
+
'c5d5', 'c5e5', 'c5f5', 'c5g5', 'c5h5', 'c5a6', 'c5b6',
|
| 157 |
+
'c5c6', 'c5d6', 'c5e6', 'c5a7', 'c5b7', 'c5c7', 'c5d7',
|
| 158 |
+
'c5e7', 'c5c8', 'c5f8', 'd5d1', 'd5h1', 'd5a2', 'd5d2',
|
| 159 |
+
'd5g2', 'd5b3', 'd5c3', 'd5d3', 'd5e3', 'd5f3', 'd5b4',
|
| 160 |
+
'd5c4', 'd5d4', 'd5e4', 'd5f4', 'd5a5', 'd5b5', 'd5c5',
|
| 161 |
+
'd5e5', 'd5f5', 'd5g5', 'd5h5', 'd5b6', 'd5c6', 'd5d6',
|
| 162 |
+
'd5e6', 'd5f6', 'd5b7', 'd5c7', 'd5d7', 'd5e7', 'd5f7',
|
| 163 |
+
'd5a8', 'd5d8', 'd5g8', 'e5a1', 'e5e1', 'e5b2', 'e5e2',
|
| 164 |
+
'e5h2', 'e5c3', 'e5d3', 'e5e3', 'e5f3', 'e5g3', 'e5c4',
|
| 165 |
+
'e5d4', 'e5e4', 'e5f4', 'e5g4', 'e5a5', 'e5b5', 'e5c5',
|
| 166 |
+
'e5d5', 'e5f5', 'e5g5', 'e5h5', 'e5c6', 'e5d6', 'e5e6',
|
| 167 |
+
'e5f6', 'e5g6', 'e5c7', 'e5d7', 'e5e7', 'e5f7', 'e5g7',
|
| 168 |
+
'e5b8', 'e5e8', 'e5h8', 'f5b1', 'f5f1', 'f5c2', 'f5f2',
|
| 169 |
+
'f5d3', 'f5e3', 'f5f3', 'f5g3', 'f5h3', 'f5d4', 'f5e4',
|
| 170 |
+
'f5f4', 'f5g4', 'f5h4', 'f5a5', 'f5b5', 'f5c5', 'f5d5',
|
| 171 |
+
'f5e5', 'f5g5', 'f5h5', 'f5d6', 'f5e6', 'f5f6', 'f5g6',
|
| 172 |
+
'f5h6', 'f5d7', 'f5e7', 'f5f7', 'f5g7', 'f5h7', 'f5c8',
|
| 173 |
+
'f5f8', 'g5c1', 'g5g1', 'g5d2', 'g5g2', 'g5e3', 'g5f3',
|
| 174 |
+
'g5g3', 'g5h3', 'g5e4', 'g5f4', 'g5g4', 'g5h4', 'g5a5',
|
| 175 |
+
'g5b5', 'g5c5', 'g5d5', 'g5e5', 'g5f5', 'g5h5', 'g5e6',
|
| 176 |
+
'g5f6', 'g5g6', 'g5h6', 'g5e7', 'g5f7', 'g5g7', 'g5h7',
|
| 177 |
+
'g5d8', 'g5g8', 'h5d1', 'h5h1', 'h5e2', 'h5h2', 'h5f3',
|
| 178 |
+
'h5g3', 'h5h3', 'h5f4', 'h5g4', 'h5h4', 'h5a5', 'h5b5',
|
| 179 |
+
'h5c5', 'h5d5', 'h5e5', 'h5f5', 'h5g5', 'h5f6', 'h5g6',
|
| 180 |
+
'h5h6', 'h5f7', 'h5g7', 'h5h7', 'h5e8', 'h5h8', 'a6a1',
|
| 181 |
+
'a6f1', 'a6a2', 'a6e2', 'a6a3', 'a6d3', 'a6a4', 'a6b4',
|
| 182 |
+
'a6c4', 'a6a5', 'a6b5', 'a6c5', 'a6b6', 'a6c6', 'a6d6',
|
| 183 |
+
'a6e6', 'a6f6', 'a6g6', 'a6h6', 'a6a7', 'a6b7', 'a6c7',
|
| 184 |
+
'a6a8', 'a6b8', 'a6c8', 'b6b1', 'b6g1', 'b6b2', 'b6f2',
|
| 185 |
+
'b6b3', 'b6e3', 'b6a4', 'b6b4', 'b6c4', 'b6d4', 'b6a5',
|
| 186 |
+
'b6b5', 'b6c5', 'b6d5', 'b6a6', 'b6c6', 'b6d6', 'b6e6',
|
| 187 |
+
'b6f6', 'b6g6', 'b6h6', 'b6a7', 'b6b7', 'b6c7', 'b6d7',
|
| 188 |
+
'b6a8', 'b6b8', 'b6c8', 'b6d8', 'c6c1', 'c6h1', 'c6c2',
|
| 189 |
+
'c6g2', 'c6c3', 'c6f3', 'c6a4', 'c6b4', 'c6c4', 'c6d4',
|
| 190 |
+
'c6e4', 'c6a5', 'c6b5', 'c6c5', 'c6d5', 'c6e5', 'c6a6',
|
| 191 |
+
'c6b6', 'c6d6', 'c6e6', 'c6f6', 'c6g6', 'c6h6', 'c6a7',
|
| 192 |
+
'c6b7', 'c6c7', 'c6d7', 'c6e7', 'c6a8', 'c6b8', 'c6c8',
|
| 193 |
+
'c6d8', 'c6e8', 'd6d1', 'd6d2', 'd6h2', 'd6a3', 'd6d3',
|
| 194 |
+
'd6g3', 'd6b4', 'd6c4', 'd6d4', 'd6e4', 'd6f4', 'd6b5',
|
| 195 |
+
'd6c5', 'd6d5', 'd6e5', 'd6f5', 'd6a6', 'd6b6', 'd6c6',
|
| 196 |
+
'd6e6', 'd6f6', 'd6g6', 'd6h6', 'd6b7', 'd6c7', 'd6d7',
|
| 197 |
+
'd6e7', 'd6f7', 'd6b8', 'd6c8', 'd6d8', 'd6e8', 'd6f8',
|
| 198 |
+
'e6e1', 'e6a2', 'e6e2', 'e6b3', 'e6e3', 'e6h3', 'e6c4',
|
| 199 |
+
'e6d4', 'e6e4', 'e6f4', 'e6g4', 'e6c5', 'e6d5', 'e6e5',
|
| 200 |
+
'e6f5', 'e6g5', 'e6a6', 'e6b6', 'e6c6', 'e6d6', 'e6f6',
|
| 201 |
+
'e6g6', 'e6h6', 'e6c7', 'e6d7', 'e6e7', 'e6f7', 'e6g7',
|
| 202 |
+
'e6c8', 'e6d8', 'e6e8', 'e6f8', 'e6g8', 'f6a1', 'f6f1',
|
| 203 |
+
'f6b2', 'f6f2', 'f6c3', 'f6f3', 'f6d4', 'f6e4', 'f6f4',
|
| 204 |
+
'f6g4', 'f6h4', 'f6d5', 'f6e5', 'f6f5', 'f6g5', 'f6h5',
|
| 205 |
+
'f6a6', 'f6b6', 'f6c6', 'f6d6', 'f6e6', 'f6g6', 'f6h6',
|
| 206 |
+
'f6d7', 'f6e7', 'f6f7', 'f6g7', 'f6h7', 'f6d8', 'f6e8',
|
| 207 |
+
'f6f8', 'f6g8', 'f6h8', 'g6b1', 'g6g1', 'g6c2', 'g6g2',
|
| 208 |
+
'g6d3', 'g6g3', 'g6e4', 'g6f4', 'g6g4', 'g6h4', 'g6e5',
|
| 209 |
+
'g6f5', 'g6g5', 'g6h5', 'g6a6', 'g6b6', 'g6c6', 'g6d6',
|
| 210 |
+
'g6e6', 'g6f6', 'g6h6', 'g6e7', 'g6f7', 'g6g7', 'g6h7',
|
| 211 |
+
'g6e8', 'g6f8', 'g6g8', 'g6h8', 'h6c1', 'h6h1', 'h6d2',
|
| 212 |
+
'h6h2', 'h6e3', 'h6h3', 'h6f4', 'h6g4', 'h6h4', 'h6f5',
|
| 213 |
+
'h6g5', 'h6h5', 'h6a6', 'h6b6', 'h6c6', 'h6d6', 'h6e6',
|
| 214 |
+
'h6f6', 'h6g6', 'h6f7', 'h6g7', 'h6h7', 'h6f8', 'h6g8',
|
| 215 |
+
'h6h8', 'a7a1', 'a7g1', 'a7a2', 'a7f2', 'a7a3', 'a7e3',
|
| 216 |
+
'a7a4', 'a7d4', 'a7a5', 'a7b5', 'a7c5', 'a7a6', 'a7b6',
|
| 217 |
+
'a7c6', 'a7b7', 'a7c7', 'a7d7', 'a7e7', 'a7f7', 'a7g7',
|
| 218 |
+
'a7h7', 'a7a8', 'a7b8', 'a7c8', 'b7b1', 'b7h1', 'b7b2',
|
| 219 |
+
'b7g2', 'b7b3', 'b7f3', 'b7b4', 'b7e4', 'b7a5', 'b7b5',
|
| 220 |
+
'b7c5', 'b7d5', 'b7a6', 'b7b6', 'b7c6', 'b7d6', 'b7a7',
|
| 221 |
+
'b7c7', 'b7d7', 'b7e7', 'b7f7', 'b7g7', 'b7h7', 'b7a8',
|
| 222 |
+
'b7b8', 'b7c8', 'b7d8', 'c7c1', 'c7c2', 'c7h2', 'c7c3',
|
| 223 |
+
'c7g3', 'c7c4', 'c7f4', 'c7a5', 'c7b5', 'c7c5', 'c7d5',
|
| 224 |
+
'c7e5', 'c7a6', 'c7b6', 'c7c6', 'c7d6', 'c7e6', 'c7a7',
|
| 225 |
+
'c7b7', 'c7d7', 'c7e7', 'c7f7', 'c7g7', 'c7h7', 'c7a8',
|
| 226 |
+
'c7b8', 'c7c8', 'c7d8', 'c7e8', 'd7d1', 'd7d2', 'd7d3',
|
| 227 |
+
'd7h3', 'd7a4', 'd7d4', 'd7g4', 'd7b5', 'd7c5', 'd7d5',
|
| 228 |
+
'd7e5', 'd7f5', 'd7b6', 'd7c6', 'd7d6', 'd7e6', 'd7f6',
|
| 229 |
+
'd7a7', 'd7b7', 'd7c7', 'd7e7', 'd7f7', 'd7g7', 'd7h7',
|
| 230 |
+
'd7b8', 'd7c8', 'd7d8', 'd7e8', 'd7f8', 'e7e1', 'e7e2',
|
| 231 |
+
'e7a3', 'e7e3', 'e7b4', 'e7e4', 'e7h4', 'e7c5', 'e7d5',
|
| 232 |
+
'e7e5', 'e7f5', 'e7g5', 'e7c6', 'e7d6', 'e7e6', 'e7f6',
|
| 233 |
+
'e7g6', 'e7a7', 'e7b7', 'e7c7', 'e7d7', 'e7f7', 'e7g7',
|
| 234 |
+
'e7h7', 'e7c8', 'e7d8', 'e7e8', 'e7f8', 'e7g8', 'f7f1',
|
| 235 |
+
'f7a2', 'f7f2', 'f7b3', 'f7f3', 'f7c4', 'f7f4', 'f7d5',
|
| 236 |
+
'f7e5', 'f7f5', 'f7g5', 'f7h5', 'f7d6', 'f7e6', 'f7f6',
|
| 237 |
+
'f7g6', 'f7h6', 'f7a7', 'f7b7', 'f7c7', 'f7d7', 'f7e7',
|
| 238 |
+
'f7g7', 'f7h7', 'f7d8', 'f7e8', 'f7f8', 'f7g8', 'f7h8',
|
| 239 |
+
'g7a1', 'g7g1', 'g7b2', 'g7g2', 'g7c3', 'g7g3', 'g7d4',
|
| 240 |
+
'g7g4', 'g7e5', 'g7f5', 'g7g5', 'g7h5', 'g7e6', 'g7f6',
|
| 241 |
+
'g7g6', 'g7h6', 'g7a7', 'g7b7', 'g7c7', 'g7d7', 'g7e7',
|
| 242 |
+
'g7f7', 'g7h7', 'g7e8', 'g7f8', 'g7g8', 'g7h8', 'h7b1',
|
| 243 |
+
'h7h1', 'h7c2', 'h7h2', 'h7d3', 'h7h3', 'h7e4', 'h7h4',
|
| 244 |
+
'h7f5', 'h7g5', 'h7h5', 'h7f6', 'h7g6', 'h7h6', 'h7a7',
|
| 245 |
+
'h7b7', 'h7c7', 'h7d7', 'h7e7', 'h7f7', 'h7g7', 'h7f8',
|
| 246 |
+
'h7g8', 'h7h8', 'a8a1', 'a8h1', 'a8a2', 'a8g2', 'a8a3',
|
| 247 |
+
'a8f3', 'a8a4', 'a8e4', 'a8a5', 'a8d5', 'a8a6', 'a8b6',
|
| 248 |
+
'a8c6', 'a8a7', 'a8b7', 'a8c7', 'a8b8', 'a8c8', 'a8d8',
|
| 249 |
+
'a8e8', 'a8f8', 'a8g8', 'a8h8', 'b8b1', 'b8b2', 'b8h2',
|
| 250 |
+
'b8b3', 'b8g3', 'b8b4', 'b8f4', 'b8b5', 'b8e5', 'b8a6',
|
| 251 |
+
'b8b6', 'b8c6', 'b8d6', 'b8a7', 'b8b7', 'b8c7', 'b8d7',
|
| 252 |
+
'b8a8', 'b8c8', 'b8d8', 'b8e8', 'b8f8', 'b8g8', 'b8h8',
|
| 253 |
+
'c8c1', 'c8c2', 'c8c3', 'c8h3', 'c8c4', 'c8g4', 'c8c5',
|
| 254 |
+
'c8f5', 'c8a6', 'c8b6', 'c8c6', 'c8d6', 'c8e6', 'c8a7',
|
| 255 |
+
'c8b7', 'c8c7', 'c8d7', 'c8e7', 'c8a8', 'c8b8', 'c8d8',
|
| 256 |
+
'c8e8', 'c8f8', 'c8g8', 'c8h8', 'd8d1', 'd8d2', 'd8d3',
|
| 257 |
+
'd8d4', 'd8h4', 'd8a5', 'd8d5', 'd8g5', 'd8b6', 'd8c6',
|
| 258 |
+
'd8d6', 'd8e6', 'd8f6', 'd8b7', 'd8c7', 'd8d7', 'd8e7',
|
| 259 |
+
'd8f7', 'd8a8', 'd8b8', 'd8c8', 'd8e8', 'd8f8', 'd8g8',
|
| 260 |
+
'd8h8', 'e8e1', 'e8e2', 'e8e3', 'e8a4', 'e8e4', 'e8b5',
|
| 261 |
+
'e8e5', 'e8h5', 'e8c6', 'e8d6', 'e8e6', 'e8f6', 'e8g6',
|
| 262 |
+
'e8c7', 'e8d7', 'e8e7', 'e8f7', 'e8g7', 'e8a8', 'e8b8',
|
| 263 |
+
'e8c8', 'e8d8', 'e8f8', 'e8g8', 'e8h8', 'f8f1', 'f8f2',
|
| 264 |
+
'f8a3', 'f8f3', 'f8b4', 'f8f4', 'f8c5', 'f8f5', 'f8d6',
|
| 265 |
+
'f8e6', 'f8f6', 'f8g6', 'f8h6', 'f8d7', 'f8e7', 'f8f7',
|
| 266 |
+
'f8g7', 'f8h7', 'f8a8', 'f8b8', 'f8c8', 'f8d8', 'f8e8',
|
| 267 |
+
'f8g8', 'f8h8', 'g8g1', 'g8a2', 'g8g2', 'g8b3', 'g8g3',
|
| 268 |
+
'g8c4', 'g8g4', 'g8d5', 'g8g5', 'g8e6', 'g8f6', 'g8g6',
|
| 269 |
+
'g8h6', 'g8e7', 'g8f7', 'g8g7', 'g8h7', 'g8a8', 'g8b8',
|
| 270 |
+
'g8c8', 'g8d8', 'g8e8', 'g8f8', 'g8h8', 'h8a1', 'h8h1',
|
| 271 |
+
'h8b2', 'h8h2', 'h8c3', 'h8h3', 'h8d4', 'h8h4', 'h8e5',
|
| 272 |
+
'h8h5', 'h8f6', 'h8g6', 'h8h6', 'h8f7', 'h8g7', 'h8h7',
|
| 273 |
+
'h8a8', 'h8b8', 'h8c8', 'h8d8', 'h8e8', 'h8f8', 'h8g8',
|
| 274 |
+
'a7a8q', 'a7a8r', 'a7a8b', 'a7b8q', 'a7b8r', 'a7b8b', 'b7a8q',
|
| 275 |
+
'b7a8r', 'b7a8b', 'b7b8q', 'b7b8r', 'b7b8b', 'b7c8q', 'b7c8r',
|
| 276 |
+
'b7c8b', 'c7b8q', 'c7b8r', 'c7b8b', 'c7c8q', 'c7c8r', 'c7c8b',
|
| 277 |
+
'c7d8q', 'c7d8r', 'c7d8b', 'd7c8q', 'd7c8r', 'd7c8b', 'd7d8q',
|
| 278 |
+
'd7d8r', 'd7d8b', 'd7e8q', 'd7e8r', 'd7e8b', 'e7d8q', 'e7d8r',
|
| 279 |
+
'e7d8b', 'e7e8q', 'e7e8r', 'e7e8b', 'e7f8q', 'e7f8r', 'e7f8b',
|
| 280 |
+
'f7e8q', 'f7e8r', 'f7e8b', 'f7f8q', 'f7f8r', 'f7f8b', 'f7g8q',
|
| 281 |
+
'f7g8r', 'f7g8b', 'g7f8q', 'g7f8r', 'g7f8b', 'g7g8q', 'g7g8r',
|
| 282 |
+
'g7g8b', 'g7h8q', 'g7h8r', 'g7h8b', 'h7g8q', 'h7g8r', 'h7g8b',
|
| 283 |
+
'h7h8q', 'h7h8r', 'h7h8b'
|
| 284 |
+
]
|
| 285 |
+
|
| 286 |
+
# White, no castling
|
| 287 |
+
_uci_to_idx_wn = dict((uci, idx) for idx, uci in enumerate(_idx_to_move_wn))
|
| 288 |
+
|
| 289 |
+
# White, castling
|
| 290 |
+
_uci_to_idx_wc = dict((uci, idx) for idx, uci in enumerate(_idx_to_move_wn))
|
| 291 |
+
_uci_to_idx_wc['e1g1'], _uci_to_idx_wc['e1h1'] = _uci_to_idx_wc['e1h1'], _uci_to_idx_wc['e1g1']
|
| 292 |
+
_uci_to_idx_wc['e1c1'], _uci_to_idx_wc['e1a1'] = _uci_to_idx_wc['e1a1'], _uci_to_idx_wc['e1c1']
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# Black, no castling
|
| 296 |
+
_idx_to_move_bn = []
|
| 297 |
+
for move in _idx_to_move_wn:
|
| 298 |
+
c0,r0,c1,r1,p = move[0],int(move[1]),move[2],int(move[3]),move[4:]
|
| 299 |
+
r0 = 9 - r0
|
| 300 |
+
r1 = 9 - r1
|
| 301 |
+
_idx_to_move_bn.append('{}{}{}{}{}'.format(c0,r0,c1,r1,p))
|
| 302 |
+
_uci_to_idx_bn = dict((uci, idx) for idx, uci in enumerate(_idx_to_move_bn))
|
| 303 |
+
|
| 304 |
+
# Black, castling
|
| 305 |
+
_uci_to_idx_bc = dict((uci, idx) for idx, uci in enumerate(_idx_to_move_bn))
|
| 306 |
+
_uci_to_idx_bc['e8g8'], _uci_to_idx_bc['e8h8'] = _uci_to_idx_bc['e8h8'], _uci_to_idx_bc['e8g8']
|
| 307 |
+
_uci_to_idx_bc['e8c8'], _uci_to_idx_bc['e8a8'] = _uci_to_idx_bc['e8a8'], _uci_to_idx_bc['e8c8']
|
| 308 |
+
|
| 309 |
+
uci_to_idx = [_uci_to_idx_wn, _uci_to_idx_wc, _uci_to_idx_bn, _uci_to_idx_bc]
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
import collections
|
| 313 |
+
import struct
|
| 314 |
+
import zlib
|
| 315 |
+
|
| 316 |
+
import chess
|
| 317 |
+
import numpy as np
|
| 318 |
+
from chess import Move
|
| 319 |
+
|
| 320 |
+
flat_planes = []
|
| 321 |
+
for i in range(256):
|
| 322 |
+
flat_planes.append(np.ones((8,8), dtype=np.uint8)*i)
|
| 323 |
+
|
| 324 |
+
LeelaBoardData = collections.namedtuple('LeelaBoardData',
|
| 325 |
+
'plane_bytes repetition '
|
| 326 |
+
'transposition_key us_ooo us_oo them_ooo them_oo '
|
| 327 |
+
'side_to_move rule50_count')
|
| 328 |
+
|
| 329 |
+
def pc_board_property(propertyname):
|
| 330 |
+
'''Create a property based on self.pc_board'''
|
| 331 |
+
def prop(self):
|
| 332 |
+
return getattr(self.pc_board, propertyname)
|
| 333 |
+
return property(prop)
|
| 334 |
+
|
| 335 |
+
class LeelaBoard:
|
| 336 |
+
turn = pc_board_property('turn')
|
| 337 |
+
move_stack = pc_board_property('move_stack')
|
| 338 |
+
_plane_bytes_struct = struct.Struct('>Q')
|
| 339 |
+
|
| 340 |
+
def __init__(self, leela_board = None, *args, **kwargs):
|
| 341 |
+
'''If leela_board is passed as an argument, return a copy'''
|
| 342 |
+
self.pc_board = chess.Board(*args, **kwargs)
|
| 343 |
+
self.lcz_stack = []
|
| 344 |
+
self._lcz_transposition_counter = collections.Counter()
|
| 345 |
+
self._lcz_push()
|
| 346 |
+
self.is_game_over = self.pc_method('is_game_over')
|
| 347 |
+
self.can_claim_draw = self.pc_method('can_claim_draw')
|
| 348 |
+
self.generate_legal_moves = self.pc_method('generate_legal_moves')
|
| 349 |
+
|
| 350 |
+
def copy(self, history=7):
|
| 351 |
+
"""Note! Currently the copy constructor uses pc_board.copy(stack=False), which makes pops impossible"""
|
| 352 |
+
cls = type(self)
|
| 353 |
+
copied = cls.__new__(cls)
|
| 354 |
+
copied.pc_board = self.pc_board.copy(stack=False)
|
| 355 |
+
copied.pc_board.stack[:] = self.pc_board.stack[-history:]
|
| 356 |
+
copied.pc_board.move_stack[:] = self.pc_board.move_stack[-history:]
|
| 357 |
+
copied.lcz_stack = self.lcz_stack[-history:]
|
| 358 |
+
copied._lcz_transposition_counter = self._lcz_transposition_counter.copy()
|
| 359 |
+
copied.is_game_over = copied.pc_method('is_game_over')
|
| 360 |
+
copied.can_claim_draw = copied.pc_method('can_claim_draw')
|
| 361 |
+
copied.generate_legal_moves = copied.pc_method('generate_legal_moves')
|
| 362 |
+
return copied
|
| 363 |
+
|
| 364 |
+
def pc_method(self, methodname):
|
| 365 |
+
'''Return attribute of self.pc_board, useful for copying method bindings'''
|
| 366 |
+
return getattr(self.pc_board, methodname)
|
| 367 |
+
|
| 368 |
+
def is_threefold(self):
|
| 369 |
+
transposition_key = self.pc_board._transposition_key()
|
| 370 |
+
return self._lcz_transposition_counter[transposition_key] >= 3
|
| 371 |
+
|
| 372 |
+
def is_fifty_moves(self):
|
| 373 |
+
return self.pc_board.halfmove_clock >= 100
|
| 374 |
+
|
| 375 |
+
def is_draw(self):
|
| 376 |
+
return self.is_threefold() or self.is_fifty_moves()
|
| 377 |
+
|
| 378 |
+
def push(self, move):
|
| 379 |
+
self.pc_board.push(move)
|
| 380 |
+
self._lcz_push()
|
| 381 |
+
|
| 382 |
+
def push_uci(self, uci):
|
| 383 |
+
# don't check for legality - it takes much longer to run...
|
| 384 |
+
# self.pc_board.push_uci(uci)
|
| 385 |
+
self.pc_board.push(Move.from_uci(uci))
|
| 386 |
+
self._lcz_push()
|
| 387 |
+
|
| 388 |
+
def push_san(self, san):
|
| 389 |
+
self.pc_board.push_san(san)
|
| 390 |
+
self._lcz_push()
|
| 391 |
+
|
| 392 |
+
def pop(self):
|
| 393 |
+
result = self.pc_board.pop()
|
| 394 |
+
_lcz_data = self.lcz_stack.pop()
|
| 395 |
+
self._lcz_transposition_counter.subtract((_lcz_data.transposition_key,))
|
| 396 |
+
return result
|
| 397 |
+
|
| 398 |
+
def _plane_bytes_iter(self):
|
| 399 |
+
"""Get plane bytes... used for _lcz_push"""
|
| 400 |
+
pack = self._plane_bytes_struct.pack
|
| 401 |
+
pieces_mask = self.pc_board.pieces_mask
|
| 402 |
+
for color in (True, False):
|
| 403 |
+
for piece_type in range(1,7):
|
| 404 |
+
byts = pack(pieces_mask(piece_type, color))
|
| 405 |
+
yield byts
|
| 406 |
+
|
| 407 |
+
def _lcz_push(self):
|
| 408 |
+
"""Push data onto the lcz data stack after pushing board moves"""
|
| 409 |
+
transposition_key = self.pc_board._transposition_key()
|
| 410 |
+
self._lcz_transposition_counter.update((transposition_key,))
|
| 411 |
+
repetitions = self._lcz_transposition_counter[transposition_key] - 1
|
| 412 |
+
# side_to_move = 0 if we're white, 1 if we're black
|
| 413 |
+
side_to_move = 0 if self.pc_board.turn else 1
|
| 414 |
+
rule50_count = self.pc_board.halfmove_clock
|
| 415 |
+
# Figure out castling rights
|
| 416 |
+
if not side_to_move:
|
| 417 |
+
# we're white
|
| 418 |
+
_c = self.pc_board.castling_rights
|
| 419 |
+
us_ooo, us_oo = (_c>>chess.A1) & 1, (_c>>chess.H1) & 1
|
| 420 |
+
them_ooo, them_oo = (_c>>chess.A8) & 1, (_c>>chess.H8) & 1
|
| 421 |
+
else:
|
| 422 |
+
# We're black
|
| 423 |
+
_c = self.pc_board.castling_rights
|
| 424 |
+
us_ooo, us_oo = (_c>>chess.A8) & 1, (_c>>chess.H8) & 1
|
| 425 |
+
them_ooo, them_oo = (_c>>chess.A1) & 1, (_c>>chess.H1) & 1
|
| 426 |
+
# Create 13 planes... 6 us, 6 them, repetitions>=1
|
| 427 |
+
plane_bytes = b''.join(self._plane_bytes_iter())
|
| 428 |
+
repetition = (repetitions>=1)
|
| 429 |
+
lcz_data = LeelaBoardData(
|
| 430 |
+
plane_bytes, repetition=repetition,
|
| 431 |
+
transposition_key=transposition_key,
|
| 432 |
+
us_ooo=us_ooo, us_oo=us_oo, them_ooo=them_ooo, them_oo=them_oo,
|
| 433 |
+
side_to_move=side_to_move, rule50_count=rule50_count
|
| 434 |
+
)
|
| 435 |
+
self.lcz_stack.append(lcz_data)
|
| 436 |
+
|
| 437 |
+
def serialize_features(self):
|
| 438 |
+
'''Get compacted bytes representation of input planes'''
|
| 439 |
+
planes = []
|
| 440 |
+
curdata = self.lcz_stack[-1]
|
| 441 |
+
bytes_false_true = bytes([False]), bytes([True])
|
| 442 |
+
bytes_per_history = 97
|
| 443 |
+
total_plane_bytes = bytes_per_history * 8
|
| 444 |
+
def bytes_iter():
|
| 445 |
+
plane_bytes_yielded = 0
|
| 446 |
+
for data in self.lcz_stack[-1:-9:-1]:
|
| 447 |
+
yield data.plane_bytes
|
| 448 |
+
yield bytes_false_true[data.repetition]
|
| 449 |
+
plane_bytes_yielded += bytes_per_history
|
| 450 |
+
# 104 total piece planes... fill in missing with 0s
|
| 451 |
+
yield bytes(total_plane_bytes - plane_bytes_yielded)
|
| 452 |
+
# Yield the rest of the constant planes
|
| 453 |
+
yield np.packbits((curdata.us_ooo,
|
| 454 |
+
curdata.us_oo,
|
| 455 |
+
curdata.them_ooo,
|
| 456 |
+
curdata.them_oo,
|
| 457 |
+
curdata.side_to_move)).tobytes()
|
| 458 |
+
yield chr(curdata.rule50_count).encode()
|
| 459 |
+
return b''.join(bytes_iter())
|
| 460 |
+
|
| 461 |
+
@classmethod
|
| 462 |
+
def deserialize_features(cls, serialized):
|
| 463 |
+
planes_stack = []
|
| 464 |
+
rule50_count = serialized[-1] # last byte is rule 50
|
| 465 |
+
board_attrs = np.unpackbits(memoryview(serialized[-2:-1])) # second to last byte
|
| 466 |
+
us_ooo, us_oo, them_ooo, them_oo, side_to_move = board_attrs[:5]
|
| 467 |
+
bytes_per_history = 97
|
| 468 |
+
for history_idx in range(0, bytes_per_history*8, bytes_per_history):
|
| 469 |
+
plane_bytes = serialized[history_idx:history_idx+96]
|
| 470 |
+
repetition = serialized[history_idx+96]
|
| 471 |
+
if not side_to_move:
|
| 472 |
+
# we're white
|
| 473 |
+
planes = (np.unpackbits(memoryview(plane_bytes))[::-1]
|
| 474 |
+
.reshape(12, 8, 8)[::-1])
|
| 475 |
+
else:
|
| 476 |
+
# We're black
|
| 477 |
+
planes = (np.unpackbits(memoryview(plane_bytes))[::-1]
|
| 478 |
+
.reshape(12, 8, 8)[::-1]
|
| 479 |
+
.reshape(2,6,8,8)[::-1,:,::-1]
|
| 480 |
+
.reshape(12, 8,8))
|
| 481 |
+
planes_stack.append(planes)
|
| 482 |
+
planes_stack.append([flat_planes[repetition]])
|
| 483 |
+
planes_stack.append([flat_planes[us_ooo],
|
| 484 |
+
flat_planes[us_oo],
|
| 485 |
+
flat_planes[them_ooo],
|
| 486 |
+
flat_planes[them_oo],
|
| 487 |
+
flat_planes[side_to_move],
|
| 488 |
+
flat_planes[rule50_count],
|
| 489 |
+
flat_planes[0],
|
| 490 |
+
flat_planes[1]])
|
| 491 |
+
planes = np.concatenate(planes_stack)
|
| 492 |
+
return planes
|
| 493 |
+
|
| 494 |
+
def lcz_features(self):
|
| 495 |
+
'''Get neural network input planes as uint8'''
|
| 496 |
+
# print(list(self._planes_iter()))
|
| 497 |
+
planes_stack = []
|
| 498 |
+
curdata = self.lcz_stack[-1]
|
| 499 |
+
planes_yielded = 0
|
| 500 |
+
for data in self.lcz_stack[-1:-9:-1]:
|
| 501 |
+
plane_bytes = data.plane_bytes
|
| 502 |
+
if not curdata.side_to_move:
|
| 503 |
+
# we're white
|
| 504 |
+
planes = (np.unpackbits(memoryview(plane_bytes))[::-1]
|
| 505 |
+
.reshape(12, 8, 8)[::-1])
|
| 506 |
+
else:
|
| 507 |
+
# We're black
|
| 508 |
+
planes = (np.unpackbits(memoryview(plane_bytes))[::-1]
|
| 509 |
+
.reshape(12, 8, 8)[::-1]
|
| 510 |
+
.reshape(2,6,8,8)[::-1,:,::-1]
|
| 511 |
+
.reshape(12, 8,8))
|
| 512 |
+
planes_stack.append(planes)
|
| 513 |
+
planes_stack.append([flat_planes[data.repetition]])
|
| 514 |
+
planes_yielded += 13
|
| 515 |
+
empty_planes = [flat_planes[0] for _ in range(104-planes_yielded)]
|
| 516 |
+
if empty_planes:
|
| 517 |
+
planes_stack.append(empty_planes)
|
| 518 |
+
# Yield the rest of the constant planes
|
| 519 |
+
planes_stack.append([flat_planes[curdata.us_ooo],
|
| 520 |
+
flat_planes[curdata.us_oo],
|
| 521 |
+
flat_planes[curdata.them_ooo],
|
| 522 |
+
flat_planes[curdata.them_oo],
|
| 523 |
+
flat_planes[curdata.side_to_move],
|
| 524 |
+
flat_planes[curdata.rule50_count],
|
| 525 |
+
flat_planes[0],
|
| 526 |
+
flat_planes[1]])
|
| 527 |
+
planes = np.concatenate(planes_stack)
|
| 528 |
+
return planes
|
| 529 |
+
|
| 530 |
+
def lcz_uci_to_idx(self, uci_list):
|
| 531 |
+
# Return list of NN policy output indexes for this board position, given uci_list
|
| 532 |
+
|
| 533 |
+
# TODO: Perhaps it's possible to just add the uci knight promotion move to the index dict
|
| 534 |
+
# currently knight promotions are not in the dict
|
| 535 |
+
uci_list = [uci.rstrip('n') for uci in uci_list]
|
| 536 |
+
|
| 537 |
+
data = self.lcz_stack[-1]
|
| 538 |
+
# uci_to_idx_index =
|
| 539 |
+
# White, no-castling => 0
|
| 540 |
+
# White, castling => 1
|
| 541 |
+
# Black, no-castling => 2
|
| 542 |
+
# Black, castling => 3
|
| 543 |
+
uci_to_idx_index = (data.us_ooo | data.us_oo) + 2*data.side_to_move
|
| 544 |
+
uci_idx_dct = uci_to_idx[uci_to_idx_index]
|
| 545 |
+
return [uci_idx_dct[m] for m in uci_list]
|
| 546 |
+
|
| 547 |
+
@classmethod
|
| 548 |
+
def compress_features(cls, features):
|
| 549 |
+
"""Compress a features array as returned from lcz_features method"""
|
| 550 |
+
features_8 = features.astype(np.uint8)
|
| 551 |
+
# Simple compression would do this...
|
| 552 |
+
# return zlib.compress(features_8)
|
| 553 |
+
piece_plane_bytes = np.packbits(features_8[:-8]).tobytes()
|
| 554 |
+
scalar_bytes = features_8[-8:][:,0,0].tobytes()
|
| 555 |
+
compressed = zlib.compress(piece_plane_bytes + scalar_bytes)
|
| 556 |
+
return compressed
|
| 557 |
+
|
| 558 |
+
@classmethod
|
| 559 |
+
def decompress_features(cls, compressed_features):
|
| 560 |
+
"""Decompress a compressed features array from compress_features"""
|
| 561 |
+
decompressed = zlib.decompress(compressed_features)
|
| 562 |
+
# Simple decompression would do this
|
| 563 |
+
# return np.frombuffer(decompressed, dtype=np.uint8).astype(np.float32).reshape(-1,8,8)
|
| 564 |
+
piece_plane_bytes = decompressed[:-8]
|
| 565 |
+
scalar_bytes = decompressed[-8:]
|
| 566 |
+
piece_plane_arr = np.unpackbits(memoryview(piece_plane_bytes))
|
| 567 |
+
scalar_arr = np.frombuffer(scalar_bytes, dtype=np.uint8).repeat(64)
|
| 568 |
+
result = np.concatenate((piece_plane_arr, scalar_arr)).astype(np.float32).reshape(-1,8,8)
|
| 569 |
+
return result
|
| 570 |
+
|
| 571 |
+
def unicode(self):
|
| 572 |
+
if self.pc_board.is_game_over() or self.is_draw():
|
| 573 |
+
result = self.pc_board.result(claim_draw=True)
|
| 574 |
+
turnstring = 'Result: {}'.format(result)
|
| 575 |
+
else:
|
| 576 |
+
turnstring = 'Turn: {}'.format('White' if self.pc_board.turn else 'Black')
|
| 577 |
+
boardstr = self.pc_board.unicode() + "\n" + turnstring
|
| 578 |
+
return boardstr
|
| 579 |
+
|
| 580 |
+
def __repr__(self):
|
| 581 |
+
return "LeelaBoard('{}')".format(self.pc_board.fen())
|
| 582 |
+
|
| 583 |
+
def _repr_svg_(self):
|
| 584 |
+
return self.pc_board._repr_svg_()
|
| 585 |
+
|
| 586 |
+
def __str__(self):
|
| 587 |
+
if self.pc_board.is_game_over() or self.is_draw():
|
| 588 |
+
result = self.pc_board.result(claim_draw=True)
|
| 589 |
+
turnstring = 'Result: {}'.format(result)
|
| 590 |
+
else:
|
| 591 |
+
turnstring = 'Turn: {}'.format('White' if self.pc_board.turn else 'Black')
|
| 592 |
+
boardstr = self.pc_board.__str__() + "\n" + turnstring
|
| 593 |
+
return boardstr
|
| 594 |
+
|
| 595 |
+
def __eq__(self, other):
|
| 596 |
+
return self.get_hash_key() == other.get_hash_key()
|
| 597 |
+
|
| 598 |
+
def __hash__(self):
|
| 599 |
+
return hash(self.get_hash_key())
|
| 600 |
+
|
| 601 |
+
def get_hash_key(self):
|
| 602 |
+
transposition_key = self.pc_board._transposition_key()
|
| 603 |
+
return (transposition_key +
|
| 604 |
+
(self._lcz_transposition_counter[transposition_key], self.pc_board.halfmove_clock) +
|
| 605 |
+
tuple(self.pc_board.move_stack[-7:])
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
# lb = LeelaBoard()
|
| 609 |
+
# lb.push_uci('c2c4')
|
| 610 |
+
#lb.push_uci('c7c5')
|
| 611 |
+
#lb.push_uci('d2d3')
|
| 612 |
+
#lb.push_uci('c2c4')
|
| 613 |
+
#lb.push_uci('b8c6')
|
| 614 |
+
# saved_planes = planes
|
| 615 |
+
# planes = lb.features()
|
| 616 |
+
# output = leela_net(torch.from_numpy(planes).unsqueeze(0))
|
| 617 |
+
# output
|
leela_utils.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
models/leela-large-official.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a439603c5668cd8b05cc9273266a905ac4db689990ca9fe72d0e247102a7c9c3
|
| 3 |
+
size 741143739
|
python_chess_customized_svg.py
ADDED
|
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file has been copied and slightly modified from python-chess library,
|
| 2 |
+
# Copyright (C) 2016-2020 Niklas Fiekas <niklas.fiekas@backscattering.de>.
|
| 3 |
+
|
| 4 |
+
# This program is free software: you can redistribute it and/or modify
|
| 5 |
+
# it under the terms of the GNU General Public License as published by
|
| 6 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 7 |
+
# (at your option) any later version.
|
| 8 |
+
#
|
| 9 |
+
# This program is distributed in the hope that it will be useful,
|
| 10 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 11 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 12 |
+
# GNU General Public License for more details.
|
| 13 |
+
#
|
| 14 |
+
# You should have received a copy of the GNU General Public License
|
| 15 |
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
| 16 |
+
|
| 17 |
+
# Piece vector graphics are copyright (C) Colin M.L. Burnett
|
| 18 |
+
# <https://en.wikipedia.org/wiki/User:Cburnett> and also licensed under the
|
| 19 |
+
# GNU General Public License.
|
| 20 |
+
|
| 21 |
+
import chess
|
| 22 |
+
import math
|
| 23 |
+
|
| 24 |
+
import xml.etree.ElementTree as ET
|
| 25 |
+
|
| 26 |
+
from typing import Iterable, Optional, Tuple, Union
|
| 27 |
+
|
| 28 |
+
SQUARE_SIZE = 45
|
| 29 |
+
MARGIN = 20
|
| 30 |
+
|
| 31 |
+
PIECES = {
|
| 32 |
+
"b": """<g id="black-bishop" class="black bishop" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M9 36c3.39-.97 10.11.43 13.5-2 3.39 2.43 10.11 1.03 13.5 2 0 0 1.65.54 3 2-.68.97-1.65.99-3 .5-3.39-.97-10.11.46-13.5-1-3.39 1.46-10.11.03-13.5 1-1.354.49-2.323.47-3-.5 1.354-1.94 3-2 3-2zm6-4c2.5 2.5 12.5 2.5 15 0 .5-1.5 0-2 0-2 0-2.5-2.5-4-2.5-4 5.5-1.5 6-11.5-5-15.5-11 4-10.5 14-5 15.5 0 0-2.5 1.5-2.5 4 0 0-.5.5 0 2zM25 8a2.5 2.5 0 1 1-5 0 2.5 2.5 0 1 1 5 0z" fill="#000" stroke-linecap="butt"/><path d="M17.5 26h10M15 30h15m-7.5-14.5v5M20 18h5" stroke="#fff" stroke-linejoin="miter"/></g>""",
|
| 33 |
+
# noqa: E501
|
| 34 |
+
"k": """<g id="black-king" class="black king" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M22.5 11.63V6" stroke-linejoin="miter"/><path d="M22.5 25s4.5-7.5 3-10.5c0 0-1-2.5-3-2.5s-3 2.5-3 2.5c-1.5 3 3 10.5 3 10.5" fill="#000" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M11.5 37c5.5 3.5 15.5 3.5 21 0v-7s9-4.5 6-10.5c-4-6.5-13.5-3.5-16 4V27v-3.5c-3.5-7.5-13-10.5-16-4-3 6 5 10 5 10V37z" fill="#000"/><path d="M20 8h5" stroke-linejoin="miter"/><path d="M32 29.5s8.5-4 6.03-9.65C34.15 14 25 18 22.5 24.5l.01 2.1-.01-2.1C20 18 9.906 14 6.997 19.85c-2.497 5.65 4.853 9 4.853 9M11.5 30c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0" stroke="#fff"/></g>""",
|
| 35 |
+
# noqa: E501
|
| 36 |
+
"n": """<g id="black-knight" class="black knight" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M 22,10 C 32.5,11 38.5,18 38,39 L 15,39 C 15,30 25,32.5 23,18" style="fill:#000000; stroke:#000000;"/><path d="M 24,18 C 24.38,20.91 18.45,25.37 16,27 C 13,29 13.18,31.34 11,31 C 9.958,30.06 12.41,27.96 11,28 C 10,28 11.19,29.23 10,30 C 9,30 5.997,31 6,26 C 6,24 12,14 12,14 C 12,14 13.89,12.1 14,10.5 C 13.27,9.506 13.5,8.5 13.5,7.5 C 14.5,6.5 16.5,10 16.5,10 L 18.5,10 C 18.5,10 19.28,8.008 21,7 C 22,7 22,10 22,10" style="fill:#000000; stroke:#000000;"/><path d="M 9.5 25.5 A 0.5 0.5 0 1 1 8.5,25.5 A 0.5 0.5 0 1 1 9.5 25.5 z" style="fill:#ececec; stroke:#ececec;"/><path d="M 15 15.5 A 0.5 1.5 0 1 1 14,15.5 A 0.5 1.5 0 1 1 15 15.5 z" transform="matrix(0.866,0.5,-0.5,0.866,9.693,-5.173)" style="fill:#ececec; stroke:#ececec;"/><path d="M 24.55,10.4 L 24.1,11.85 L 24.6,12 C 27.75,13 30.25,14.49 32.5,18.75 C 34.75,23.01 35.75,29.06 35.25,39 L 35.2,39.5 L 37.45,39.5 L 37.5,39 C 38,28.94 36.62,22.15 34.25,17.66 C 31.88,13.17 28.46,11.02 25.06,10.5 L 24.55,10.4 z " style="fill:#ececec; stroke:none;"/></g>""",
|
| 37 |
+
# noqa: E501
|
| 38 |
+
"p": """<g id="black-pawn" class="black pawn"><path d="M22 9c-2.21 0-4 1.79-4 4 0 .89.29 1.71.78 2.38-1.95 1.12-3.28 3.21-3.28 5.62 0 2.03.94 3.84 2.41 5.03-3 1.06-7.41 5.55-7.41 13.47h23c0-7.92-4.41-12.41-7.41-13.47 1.47-1.19 2.41-3 2.41-5.03 0-2.41-1.33-4.5-3.28-5.62.49-.67.78-1.49.78-2.38 0-2.21-1.79-4-4-4z" stroke="#000" stroke-width="1.5" stroke-linecap="round"/></g>""",
|
| 39 |
+
# noqa: E501
|
| 40 |
+
"q": """<g id="black-queen" class="black queen" fill="#000" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><g fill="#000" stroke="none"><circle cx="6" cy="12" r="2.75"/><circle cx="14" cy="9" r="2.75"/><circle cx="22.5" cy="8" r="2.75"/><circle cx="31" cy="9" r="2.75"/><circle cx="39" cy="12" r="2.75"/></g><path d="M9 26c8.5-1.5 21-1.5 27 0l2.5-12.5L31 25l-.3-14.1-5.2 13.6-3-14.5-3 14.5-5.2-13.6L14 25 6.5 13.5 9 26zM9 26c0 2 1.5 2 2.5 4 1 1.5 1 1 .5 3.5-1.5 1-1.5 2.5-1.5 2.5-1.5 1.5.5 2.5.5 2.5 6.5 1 16.5 1 23 0 0 0 1.5-1 0-2.5 0 0 .5-1.5-1-2.5-.5-2.5-.5-2 .5-3.5 1-2 2.5-2 2.5-4-8.5-1.5-18.5-1.5-27 0z" stroke-linecap="butt"/><path d="M11 38.5a35 35 1 0 0 23 0" fill="none" stroke-linecap="butt"/><path d="M11 29a35 35 1 0 1 23 0M12.5 31.5h20M11.5 34.5a35 35 1 0 0 22 0M10.5 37.5a35 35 1 0 0 24 0" fill="none" stroke="#fff"/></g>""",
|
| 41 |
+
# noqa: E501
|
| 42 |
+
"r": """<g id="black-rook" class="black rook" fill="#000" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M9 39h27v-3H9v3zM12.5 32l1.5-2.5h17l1.5 2.5h-20zM12 36v-4h21v4H12z" stroke-linecap="butt"/><path d="M14 29.5v-13h17v13H14z" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M14 16.5L11 14h23l-3 2.5H14zM11 14V9h4v2h5V9h5v2h5V9h4v5H11z" stroke-linecap="butt"/><path d="M12 35.5h21M13 31.5h19M14 29.5h17M14 16.5h17M11 14h23" fill="none" stroke="#fff" stroke-width="1" stroke-linejoin="miter"/></g>""",
|
| 43 |
+
# noqa: E501
|
| 44 |
+
"B": """<g id="white-bishop" class="white bishop" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><g fill="#fff" stroke-linecap="butt"><path d="M9 36c3.39-.97 10.11.43 13.5-2 3.39 2.43 10.11 1.03 13.5 2 0 0 1.65.54 3 2-.68.97-1.65.99-3 .5-3.39-.97-10.11.46-13.5-1-3.39 1.46-10.11.03-13.5 1-1.354.49-2.323.47-3-.5 1.354-1.94 3-2 3-2zM15 32c2.5 2.5 12.5 2.5 15 0 .5-1.5 0-2 0-2 0-2.5-2.5-4-2.5-4 5.5-1.5 6-11.5-5-15.5-11 4-10.5 14-5 15.5 0 0-2.5 1.5-2.5 4 0 0-.5.5 0 2zM25 8a2.5 2.5 0 1 1-5 0 2.5 2.5 0 1 1 5 0z"/></g><path d="M17.5 26h10M15 30h15m-7.5-14.5v5M20 18h5" stroke-linejoin="miter"/></g>""",
|
| 45 |
+
# noqa: E501
|
| 46 |
+
"K": """<g id="white-king" class="white king" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M22.5 11.63V6M20 8h5" stroke-linejoin="miter"/><path d="M22.5 25s4.5-7.5 3-10.5c0 0-1-2.5-3-2.5s-3 2.5-3 2.5c-1.5 3 3 10.5 3 10.5" fill="#fff" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M11.5 37c5.5 3.5 15.5 3.5 21 0v-7s9-4.5 6-10.5c-4-6.5-13.5-3.5-16 4V27v-3.5c-3.5-7.5-13-10.5-16-4-3 6 5 10 5 10V37z" fill="#fff"/><path d="M11.5 30c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0"/></g>""",
|
| 47 |
+
# noqa: E501
|
| 48 |
+
"N": """<g id="white-knight" class="white knight" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M 22,10 C 32.5,11 38.5,18 38,39 L 15,39 C 15,30 25,32.5 23,18" style="fill:#ffffff; stroke:#000000;"/><path d="M 24,18 C 24.38,20.91 18.45,25.37 16,27 C 13,29 13.18,31.34 11,31 C 9.958,30.06 12.41,27.96 11,28 C 10,28 11.19,29.23 10,30 C 9,30 5.997,31 6,26 C 6,24 12,14 12,14 C 12,14 13.89,12.1 14,10.5 C 13.27,9.506 13.5,8.5 13.5,7.5 C 14.5,6.5 16.5,10 16.5,10 L 18.5,10 C 18.5,10 19.28,8.008 21,7 C 22,7 22,10 22,10" style="fill:#ffffff; stroke:#000000;"/><path d="M 9.5 25.5 A 0.5 0.5 0 1 1 8.5,25.5 A 0.5 0.5 0 1 1 9.5 25.5 z" style="fill:#000000; stroke:#000000;"/><path d="M 15 15.5 A 0.5 1.5 0 1 1 14,15.5 A 0.5 1.5 0 1 1 15 15.5 z" transform="matrix(0.866,0.5,-0.5,0.866,9.693,-5.173)" style="fill:#000000; stroke:#000000;"/></g>""",
|
| 49 |
+
# noqa: E501
|
| 50 |
+
"P": """<g id="white-pawn" class="white pawn"><path d="M22 9c-2.21 0-4 1.79-4 4 0 .89.29 1.71.78 2.38-1.95 1.12-3.28 3.21-3.28 5.62 0 2.03.94 3.84 2.41 5.03-3 1.06-7.41 5.55-7.41 13.47h23c0-7.92-4.41-12.41-7.41-13.47 1.47-1.19 2.41-3 2.41-5.03 0-2.41-1.33-4.5-3.28-5.62.49-.67.78-1.49.78-2.38 0-2.21-1.79-4-4-4z" fill="#fff" stroke="#000" stroke-width="1.5" stroke-linecap="round"/></g>""",
|
| 51 |
+
# noqa: E501
|
| 52 |
+
"Q": """<g id="white-queen" class="white queen" fill="#fff" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M8 12a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM24.5 7.5a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM41 12a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM16 8.5a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM33 9a2 2 0 1 1-4 0 2 2 0 1 1 4 0z"/><path d="M9 26c8.5-1.5 21-1.5 27 0l2-12-7 11V11l-5.5 13.5-3-15-3 15-5.5-14V25L7 14l2 12zM9 26c0 2 1.5 2 2.5 4 1 1.5 1 1 .5 3.5-1.5 1-1.5 2.5-1.5 2.5-1.5 1.5.5 2.5.5 2.5 6.5 1 16.5 1 23 0 0 0 1.5-1 0-2.5 0 0 .5-1.5-1-2.5-.5-2.5-.5-2 .5-3.5 1-2 2.5-2 2.5-4-8.5-1.5-18.5-1.5-27 0z" stroke-linecap="butt"/><path d="M11.5 30c3.5-1 18.5-1 22 0M12 33.5c6-1 15-1 21 0" fill="none"/></g>""",
|
| 53 |
+
# noqa: E501
|
| 54 |
+
"R": """<g id="white-rook" class="white rook" fill="#fff" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M9 39h27v-3H9v3zM12 36v-4h21v4H12zM11 14V9h4v2h5V9h5v2h5V9h4v5" stroke-linecap="butt"/><path d="M34 14l-3 3H14l-3-3"/><path d="M31 17v12.5H14V17" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M31 29.5l1.5 2.5h-20l1.5-2.5"/><path d="M11 14h23" fill="none" stroke-linejoin="miter"/></g>""",
|
| 55 |
+
# noqa: E501
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
PIECES = {
|
| 59 |
+
"b": """<g id="black-bishop" class="black bishop" fill="none" fill-rule="evenodd" stroke="#fff" stroke-width="1.0" stroke-linecap="round" stroke-linejoin="round"><path d="M9 36c3.39-.97 10.11.43 13.5-2 3.39 2.43 10.11 1.03 13.5 2 0 0 1.65.54 3 2-.68.97-1.65.99-3 .5-3.39-.97-10.11.46-13.5-1-3.39 1.46-10.11.03-13.5 1-1.354.49-2.323.47-3-.5 1.354-1.94 3-2 3-2zm6-4c2.5 2.5 12.5 2.5 15 0 .5-1.5 0-2 0-2 0-2.5-2.5-4-2.5-4 5.5-1.5 6-11.5-5-15.5-11 4-10.5 14-5 15.5 0 0-2.5 1.5-2.5 4 0 0-.5.5 0 2zM25 8a2.5 2.5 0 1 1-5 0 2.5 2.5 0 1 1 5 0z" fill="#000" stroke-linecap="butt"/><path d="M17.5 26h10M15 30h15m-7.5-14.5v5M20 18h5" stroke="#fff" stroke-linejoin="miter"/></g>""",
|
| 60 |
+
# noqa: E501
|
| 61 |
+
"k": """<g id="black-king" class="black king" fill="none" fill-rule="evenodd" stroke="#fff" stroke-width="1.0" stroke-linecap="round" stroke-linejoin="round"><path d="M22.5 11.63V6" stroke-linejoin="miter"/><path d="M22.5 25s4.5-7.5 3-10.5c0 0-1-2.5-3-2.5s-3 2.5-3 2.5c-1.5 3 3 10.5 3 10.5" fill="#000" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M11.5 37c5.5 3.5 15.5 3.5 21 0v-7s9-4.5 6-10.5c-4-6.5-13.5-3.5-16 4V27v-3.5c-3.5-7.5-13-10.5-16-4-3 6 5 10 5 10V37z" fill="#000"/><path d="M20 8h5" stroke-linejoin="miter"/><path d="M32 29.5s8.5-4 6.03-9.65C34.15 14 25 18 22.5 24.5l.01 2.1-.01-2.1C20 18 9.906 14 6.997 19.85c-2.497 5.65 4.853 9 4.853 9M11.5 30c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0" stroke="#fff"/></g>""",
|
| 62 |
+
# noqa: E501
|
| 63 |
+
"n": """<g id="black-knight" class="black knight" fill="none" fill-rule="evenodd" stroke="#fff" stroke-width="1.0" stroke-linecap="round" stroke-linejoin="round"><path d="M 22,10 C 32.5,11 38.5,18 38,39 L 15,39 C 15,30 25,32.5 23,18" style="fill:#000000; stroke:#000000;"/><path d="M 24,18 C 24.38,20.91 18.45,25.37 16,27 C 13,29 13.18,31.34 11,31 C 9.958,30.06 12.41,27.96 11,28 C 10,28 11.19,29.23 10,30 C 9,30 5.997,31 6,26 C 6,24 12,14 12,14 C 12,14 13.89,12.1 14,10.5 C 13.27,9.506 13.5,8.5 13.5,7.5 C 14.5,6.5 16.5,10 16.5,10 L 18.5,10 C 18.5,10 19.28,8.008 21,7 C 22,7 22,10 22,10" style="fill:#000000; stroke:#000000;"/><path d="M 9.5 25.5 A 0.5 0.5 0 1 1 8.5,25.5 A 0.5 0.5 0 1 1 9.5 25.5 z" style="fill:#ececec; stroke:#ececec;"/><path d="M 15 15.5 A 0.5 1.5 0 1 1 14,15.5 A 0.5 1.5 0 1 1 15 15.5 z" transform="matrix(0.866,0.5,-0.5,0.866,9.693,-5.173)" style="fill:#ececec; stroke:#ececec;"/><path d="M 24.55,10.4 L 24.1,11.85 L 24.6,12 C 27.75,13 30.25,14.49 32.5,18.75 C 34.75,23.01 35.75,29.06 35.25,39 L 35.2,39.5 L 37.45,39.5 L 37.5,39 C 38,28.94 36.62,22.15 34.25,17.66 C 31.88,13.17 28.46,11.02 25.06,10.5 L 24.55,10.4 z " style="fill:#ececec; stroke:none;"/></g>""",
|
| 64 |
+
# noqa: E501
|
| 65 |
+
"p": """<g id="black-pawn" class="black pawn"><path d="M22 9c-2.21 0-4 1.79-4 4 0 .89.29 1.71.78 2.38-1.95 1.12-3.28 3.21-3.28 5.62 0 2.03.94 3.84 2.41 5.03-3 1.06-7.41 5.55-7.41 13.47h23c0-7.92-4.41-12.41-7.41-13.47 1.47-1.19 2.41-3 2.41-5.03 0-2.41-1.33-4.5-3.28-5.62.49-.67.78-1.49.78-2.38 0-2.21-1.79-4-4-4z" stroke="#fff" stroke-width="1.0" stroke-linecap="round"/></g>""",
|
| 66 |
+
# noqa: E501
|
| 67 |
+
"q": """<g id="black-queen" class="black queen" fill="#000" fill-rule="evenodd" stroke="#000" stroke-width="1.0" stroke-linecap="round" stroke-linejoin="round"><g fill="#000" stroke="none"><circle cx="6" cy="12" r="2.75"/><circle cx="14" cy="9" r="2.75"/><circle cx="22.5" cy="8" r="2.75"/><circle cx="31" cy="9" r="2.75"/><circle cx="39" cy="12" r="2.75"/></g><path d="M9 26c8.5-1.5 21-1.5 27 0l2.5-12.5L31 25l-.3-14.1-5.2 13.6-3-14.5-3 14.5-5.2-13.6L14 25 6.5 13.5 9 26zM9 26c0 2 1.5 2 2.5 4 1 1.5 1 1 .5 3.5-1.5 1-1.5 2.5-1.5 2.5-1.5 1.5.5 2.5.5 2.5 6.5 1 16.5 1 23 0 0 0 1.5-1 0-2.5 0 0 .5-1.5-1-2.5-.5-2.5-.5-2 .5-3.5 1-2 2.5-2 2.5-4-8.5-1.5-18.5-1.5-27 0z" stroke-linecap="butt"/><path d="M11 38.5a35 35 1 0 0 23 0" fill="none" stroke-linecap="butt"/><path d="M11 29a35 35 1 0 1 23 0M12.5 31.5h20M11.5 34.5a35 35 1 0 0 22 0M10.5 37.5a35 35 1 0 0 24 0" fill="none" stroke="#fff"/></g>""",
|
| 68 |
+
# noqa: E501
|
| 69 |
+
"r": """<g id="black-rook" class="black rook" fill="#000" fill-rule="evenodd" stroke="#fff" stroke-width="0.5" stroke-linecap="round" stroke-linejoin="round"><path d="M9 39h27v-3H9v3zM12.5 32l1.5-2.5h17l1.5 2.5h-20zM12 36v-4h21v4H12z" stroke-linecap="butt"/><path d="M14 29.5v-13h17v13H14z" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M14 16.5L11 14h23l-3 2.5H14zM11 14V9h4v2h5V9h5v2h5V9h4v5H11z" stroke-linecap="butt"/><path d="M12 35.5h21M13 31.5h19M14 29.5h17M14 16.5h17M11 14h23" fill="none" stroke="#fff" stroke-width="1" stroke-linejoin="miter"/></g>""",
|
| 70 |
+
# noqa: E501
|
| 71 |
+
"B": """<g id="white-bishop" class="white bishop" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><g fill="#fff" stroke-linecap="butt"><path d="M9 36c3.39-.97 10.11.43 13.5-2 3.39 2.43 10.11 1.03 13.5 2 0 0 1.65.54 3 2-.68.97-1.65.99-3 .5-3.39-.97-10.11.46-13.5-1-3.39 1.46-10.11.03-13.5 1-1.354.49-2.323.47-3-.5 1.354-1.94 3-2 3-2zM15 32c2.5 2.5 12.5 2.5 15 0 .5-1.5 0-2 0-2 0-2.5-2.5-4-2.5-4 5.5-1.5 6-11.5-5-15.5-11 4-10.5 14-5 15.5 0 0-2.5 1.5-2.5 4 0 0-.5.5 0 2zM25 8a2.5 2.5 0 1 1-5 0 2.5 2.5 0 1 1 5 0z"/></g><path d="M17.5 26h10M15 30h15m-7.5-14.5v5M20 18h5" stroke-linejoin="miter"/></g>""",
|
| 72 |
+
# noqa: E501
|
| 73 |
+
"K": """<g id="white-king" class="white king" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M22.5 11.63V6M20 8h5" stroke-linejoin="miter"/><path d="M22.5 25s4.5-7.5 3-10.5c0 0-1-2.5-3-2.5s-3 2.5-3 2.5c-1.5 3 3 10.5 3 10.5" fill="#fff" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M11.5 37c5.5 3.5 15.5 3.5 21 0v-7s9-4.5 6-10.5c-4-6.5-13.5-3.5-16 4V27v-3.5c-3.5-7.5-13-10.5-16-4-3 6 5 10 5 10V37z" fill="#fff"/><path d="M11.5 30c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0"/></g>""",
|
| 74 |
+
# noqa: E501
|
| 75 |
+
"N": """<g id="white-knight" class="white knight" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M 22,10 C 32.5,11 38.5,18 38,39 L 15,39 C 15,30 25,32.5 23,18" style="fill:#ffffff; stroke:#000000;"/><path d="M 24,18 C 24.38,20.91 18.45,25.37 16,27 C 13,29 13.18,31.34 11,31 C 9.958,30.06 12.41,27.96 11,28 C 10,28 11.19,29.23 10,30 C 9,30 5.997,31 6,26 C 6,24 12,14 12,14 C 12,14 13.89,12.1 14,10.5 C 13.27,9.506 13.5,8.5 13.5,7.5 C 14.5,6.5 16.5,10 16.5,10 L 18.5,10 C 18.5,10 19.28,8.008 21,7 C 22,7 22,10 22,10" style="fill:#ffffff; stroke:#000000;"/><path d="M 9.5 25.5 A 0.5 0.5 0 1 1 8.5,25.5 A 0.5 0.5 0 1 1 9.5 25.5 z" style="fill:#000000; stroke:#000000;"/><path d="M 15 15.5 A 0.5 1.5 0 1 1 14,15.5 A 0.5 1.5 0 1 1 15 15.5 z" transform="matrix(0.866,0.5,-0.5,0.866,9.693,-5.173)" style="fill:#000000; stroke:#000000;"/></g>""",
|
| 76 |
+
# noqa: E501
|
| 77 |
+
"P": """<g id="white-pawn" class="white pawn"><path d="M22 9c-2.21 0-4 1.79-4 4 0 .89.29 1.71.78 2.38-1.95 1.12-3.28 3.21-3.28 5.62 0 2.03.94 3.84 2.41 5.03-3 1.06-7.41 5.55-7.41 13.47h23c0-7.92-4.41-12.41-7.41-13.47 1.47-1.19 2.41-3 2.41-5.03 0-2.41-1.33-4.5-3.28-5.62.49-.67.78-1.49.78-2.38 0-2.21-1.79-4-4-4z" fill="#fff" stroke="#000" stroke-width="1.5" stroke-linecap="round"/></g>""",
|
| 78 |
+
# noqa: E501
|
| 79 |
+
"Q": """<g id="white-queen" class="white queen" fill="#fff" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M8 12a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM24.5 7.5a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM41 12a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM16 8.5a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM33 9a2 2 0 1 1-4 0 2 2 0 1 1 4 0z"/><path d="M9 26c8.5-1.5 21-1.5 27 0l2-12-7 11V11l-5.5 13.5-3-15-3 15-5.5-14V25L7 14l2 12zM9 26c0 2 1.5 2 2.5 4 1 1.5 1 1 .5 3.5-1.5 1-1.5 2.5-1.5 2.5-1.5 1.5.5 2.5.5 2.5 6.5 1 16.5 1 23 0 0 0 1.5-1 0-2.5 0 0 .5-1.5-1-2.5-.5-2.5-.5-2 .5-3.5 1-2 2.5-2 2.5-4-8.5-1.5-18.5-1.5-27 0z" stroke-linecap="butt"/><path d="M11.5 30c3.5-1 18.5-1 22 0M12 33.5c6-1 15-1 21 0" fill="none"/></g>""",
|
| 80 |
+
# noqa: E501
|
| 81 |
+
"R": """<g id="white-rook" class="white rook" fill="#fff" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M9 39h27v-3H9v3zM12 36v-4h21v4H12zM11 14V9h4v2h5V9h5v2h5V9h4v5" stroke-linecap="butt"/><path d="M34 14l-3 3H14l-3-3"/><path d="M31 17v12.5H14V17" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M31 29.5l1.5 2.5h-20l1.5-2.5"/><path d="M11 14h23" fill="none" stroke-linejoin="miter"/></g>""",
|
| 82 |
+
# noqa: E501
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
XX = """<g id="xx"><path d="M35.865 9.135a1.89 1.89 0 0 1 0 2.673L25.173 22.5l10.692 10.692a1.89 1.89 0 0 1 0 2.673 1.89 1.89 0 0 1-2.673 0L22.5 25.173 11.808 35.865a1.89 1.89 0 0 1-2.673 0 1.89 1.89 0 0 1 0-2.673L19.827 22.5 9.135 11.808a1.89 1.89 0 0 1 0-2.673 1.89 1.89 0 0 1 2.673 0L22.5 19.827 33.192 9.135a1.89 1.89 0 0 1 2.673 0z" fill="#000" stroke="#fff" stroke-width="1.688"/></g>""" # noqa: E501
|
| 86 |
+
|
| 87 |
+
CHECK_GRADIENT = """<radialGradient id="check_gradient"><stop offset="0%" stop-color="#ff0000" stop-opacity="1.0" /><stop offset="50%" stop-color="#e70000" stop-opacity="1.0" /><stop offset="100%" stop-color="#9e0000" stop-opacity="0.0" /></radialGradient>""" # noqa: E501
|
| 88 |
+
|
| 89 |
+
DEFAULT_COLORS = {
|
| 90 |
+
"square light": "#ffce9e",
|
| 91 |
+
"square dark": "#d18b47",
|
| 92 |
+
"square dark lastmove": "#aaa23b",
|
| 93 |
+
"square light lastmove": "#cdd16a",
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
class Arrow:
|
| 97 |
+
"""Details of an arrow to be drawn."""
|
| 98 |
+
|
| 99 |
+
def __init__(self, tail: chess.Square, head: chess.Square, *, color: str = "#888", annotation: str = '') -> None:
|
| 100 |
+
self.tail = tail
|
| 101 |
+
self.head = head
|
| 102 |
+
self.color = color
|
| 103 |
+
self.annotation = annotation
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class SvgWrapper(str):
|
| 107 |
+
def _repr_svg_(self) -> "SvgWrapper":
|
| 108 |
+
return self
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _svg(viewbox: int, size: Optional[int]) -> ET.Element:
|
| 112 |
+
svg = ET.Element("svg", {
|
| 113 |
+
"xmlns": "http://www.w3.org/2000/svg",
|
| 114 |
+
"version": "1.1",
|
| 115 |
+
"xmlns:xlink": "http://www.w3.org/1999/xlink",
|
| 116 |
+
"viewBox": f"0 0 {viewbox:d} {viewbox:d}",
|
| 117 |
+
})
|
| 118 |
+
|
| 119 |
+
if size is not None:
|
| 120 |
+
svg.set("width", str(size))
|
| 121 |
+
svg.set("height", str(size))
|
| 122 |
+
|
| 123 |
+
return svg
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _text(content: str, x: int, y: int, width: int, height: int) -> ET.Element:
|
| 127 |
+
t = ET.Element("text", {
|
| 128 |
+
"x": str(x + width // 2),
|
| 129 |
+
"y": str(y + height // 2),
|
| 130 |
+
"font-size": str(max(1, int(min(width, height) * 0.7))),
|
| 131 |
+
"text-anchor": "middle",
|
| 132 |
+
"alignment-baseline": "middle",
|
| 133 |
+
})
|
| 134 |
+
t.text = content
|
| 135 |
+
return t
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def piece(piece: chess.Piece, size: Optional[int] = None) -> str:
|
| 139 |
+
"""
|
| 140 |
+
Renders the given :class:`chess.Piece` as an SVG image.
|
| 141 |
+
>>> import chess
|
| 142 |
+
>>> import chess.svg
|
| 143 |
+
>>>
|
| 144 |
+
>>> chess.svg.piece(chess.Piece.from_symbol("R")) # doctest: +SKIP
|
| 145 |
+
.. image:: ../docs/wR.svg
|
| 146 |
+
"""
|
| 147 |
+
svg = _svg(SQUARE_SIZE, size)
|
| 148 |
+
svg.append(ET.fromstring(PIECES[piece.symbol()]))
|
| 149 |
+
return SvgWrapper(ET.tostring(svg).decode("utf-8"))
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def board(board: Optional[chess.BaseBoard] = None, *,
|
| 153 |
+
squares: Optional[chess.IntoSquareSet] = None,
|
| 154 |
+
flipped: bool = False,
|
| 155 |
+
coordinates: bool = True,
|
| 156 |
+
lastmove: Optional[chess.Move] = None,
|
| 157 |
+
check: Optional[chess.Square] = None,
|
| 158 |
+
arrows: Iterable[Union[Arrow, Tuple[chess.Square, chess.Square]]] = (),
|
| 159 |
+
size: Optional[int] = None,
|
| 160 |
+
style: Optional[str] = None,
|
| 161 |
+
square_colors: Iterable[str] = (), #TODO: remove as it is not needed anymore
|
| 162 |
+
only_pieces: bool = False) -> str:
|
| 163 |
+
"""
|
| 164 |
+
Renders a board with pieces and/or selected squares as an SVG image.
|
| 165 |
+
:param board: A :class:`chess.BaseBoard` for a chessboard with pieces or
|
| 166 |
+
``None`` (the default) for a chessboard without pieces.
|
| 167 |
+
:param squares: A :class:`chess.SquareSet` with selected squares.
|
| 168 |
+
:param flipped: Pass ``True`` to flip the board.
|
| 169 |
+
:param coordinates: Pass ``False`` to disable coordinates in the margin.
|
| 170 |
+
:param lastmove: A :class:`chess.Move` to be highlighted.
|
| 171 |
+
:param check: A square to be marked as check.
|
| 172 |
+
:param arrows: A list of :class:`~chess.svg.Arrow` objects like
|
| 173 |
+
``[chess.svg.Arrow(chess.E2, chess.E4)]`` or a list of tuples like
|
| 174 |
+
``[(chess.E2, chess.E4)]``. An arrow from a square pointing to the same
|
| 175 |
+
square is drawn as a circle, like ``[(chess.E2, chess.E2)]``.
|
| 176 |
+
:param size: The size of the image in pixels (e.g., ``400`` for a 400 by
|
| 177 |
+
400 board) or ``None`` (the default) for no size limit.
|
| 178 |
+
:param style: A CSS stylesheet to include in the SVG image.
|
| 179 |
+
>>> import chess
|
| 180 |
+
>>> import chess.svg
|
| 181 |
+
>>>
|
| 182 |
+
>>> board = chess.Board("8/8/8/8/4N3/8/8/8 w - - 0 1")
|
| 183 |
+
>>> squares = board.attacks(chess.E4)
|
| 184 |
+
>>> chess.svg.board(board=board, squares=squares) # doctest: +SKIP
|
| 185 |
+
.. image:: ../docs/Ne4.svg
|
| 186 |
+
"""
|
| 187 |
+
margin = MARGIN if coordinates else 0
|
| 188 |
+
svg = _svg(8 * SQUARE_SIZE + 2 * margin, size)
|
| 189 |
+
|
| 190 |
+
if style:
|
| 191 |
+
ET.SubElement(svg, "style").text = style
|
| 192 |
+
|
| 193 |
+
defs = ET.SubElement(svg, "defs")
|
| 194 |
+
if board:
|
| 195 |
+
for piece_color in chess.COLORS:
|
| 196 |
+
for piece_type in chess.PIECE_TYPES:
|
| 197 |
+
if board.pieces_mask(piece_type, piece_color):
|
| 198 |
+
defs.append(ET.fromstring(PIECES[chess.Piece(piece_type, piece_color).symbol()]))
|
| 199 |
+
|
| 200 |
+
squares = chess.SquareSet(squares) if squares else chess.SquareSet()
|
| 201 |
+
if squares:
|
| 202 |
+
defs.append(ET.fromstring(XX))
|
| 203 |
+
|
| 204 |
+
if check is not None and not only_pieces:
|
| 205 |
+
defs.append(ET.fromstring(CHECK_GRADIENT))
|
| 206 |
+
|
| 207 |
+
for square, bb in enumerate(chess.BB_SQUARES):
|
| 208 |
+
file_index = chess.square_file(square)
|
| 209 |
+
rank_index = chess.square_rank(square)
|
| 210 |
+
|
| 211 |
+
x = (file_index if not flipped else 7 - file_index) * SQUARE_SIZE + margin
|
| 212 |
+
y = (7 - rank_index if not flipped else rank_index) * SQUARE_SIZE + margin
|
| 213 |
+
|
| 214 |
+
cls = ["square", "light" if chess.BB_LIGHT_SQUARES & bb else "dark"]
|
| 215 |
+
if lastmove and square in [lastmove.from_square, lastmove.to_square]:
|
| 216 |
+
cls.append("lastmove")
|
| 217 |
+
if square_colors == ():
|
| 218 |
+
fill_color = DEFAULT_COLORS[" ".join(cls)]
|
| 219 |
+
else:
|
| 220 |
+
fill_color = square_colors[square]
|
| 221 |
+
|
| 222 |
+
cls.append(chess.SQUARE_NAMES[square])
|
| 223 |
+
if not only_pieces:
|
| 224 |
+
ET.SubElement(svg, "rect", {
|
| 225 |
+
"x": str(x),
|
| 226 |
+
"y": str(y),
|
| 227 |
+
"width": str(SQUARE_SIZE),
|
| 228 |
+
"height": str(SQUARE_SIZE),
|
| 229 |
+
"class": " ".join(cls),
|
| 230 |
+
"stroke": "none",
|
| 231 |
+
"fill": fill_color,
|
| 232 |
+
})
|
| 233 |
+
|
| 234 |
+
if square == check:
|
| 235 |
+
ET.SubElement(svg, "rect", {
|
| 236 |
+
"x": str(x),
|
| 237 |
+
"y": str(y),
|
| 238 |
+
"width": str(SQUARE_SIZE),
|
| 239 |
+
"height": str(SQUARE_SIZE),
|
| 240 |
+
"class": "check",
|
| 241 |
+
"fill": "url(#check_gradient)",
|
| 242 |
+
})
|
| 243 |
+
|
| 244 |
+
# Render pieces.
|
| 245 |
+
if board is not None:
|
| 246 |
+
piece = board.piece_at(square)
|
| 247 |
+
if piece:
|
| 248 |
+
ET.SubElement(svg, "use", {
|
| 249 |
+
"xlink:href": f"#{chess.COLOR_NAMES[piece.color]}-{chess.PIECE_NAMES[piece.piece_type]}",
|
| 250 |
+
"transform": f"translate({x:d}, {y:d})",
|
| 251 |
+
})
|
| 252 |
+
|
| 253 |
+
# Render selected squares.
|
| 254 |
+
if squares is not None and square in squares:
|
| 255 |
+
#ET.SubElement(svg, "use", {
|
| 256 |
+
# "xlink:href": "#xx",
|
| 257 |
+
# "x": str(x),
|
| 258 |
+
# "y": str(y),
|
| 259 |
+
#})
|
| 260 |
+
ET.SubElement(svg, "rect", {
|
| 261 |
+
"x": str(x),
|
| 262 |
+
"y": str(y),
|
| 263 |
+
"width": str(SQUARE_SIZE),
|
| 264 |
+
"height": str(SQUARE_SIZE),
|
| 265 |
+
"class": "check",
|
| 266 |
+
"fill": "none",
|
| 267 |
+
"stroke": "#FF0000",
|
| 268 |
+
"stroke-width": "5.0",
|
| 269 |
+
"rx": "2.5",
|
| 270 |
+
"opacity": "0.60"
|
| 271 |
+
})
|
| 272 |
+
|
| 273 |
+
if coordinates:
|
| 274 |
+
for file_index, file_name in enumerate(chess.FILE_NAMES):
|
| 275 |
+
x = (file_index if not flipped else 7 - file_index) * SQUARE_SIZE + margin
|
| 276 |
+
svg.append(_text(file_name, x, 0, SQUARE_SIZE, margin))
|
| 277 |
+
svg.append(_text(file_name, x, margin + 8 * SQUARE_SIZE, SQUARE_SIZE, margin))
|
| 278 |
+
for rank_index, rank_name in enumerate(chess.RANK_NAMES):
|
| 279 |
+
y = (7 - rank_index if not flipped else rank_index) * SQUARE_SIZE + margin
|
| 280 |
+
svg.append(_text(rank_name, 0, y, margin, SQUARE_SIZE))
|
| 281 |
+
svg.append(_text(rank_name, margin + 8 * SQUARE_SIZE, y, margin, SQUARE_SIZE))
|
| 282 |
+
|
| 283 |
+
for arrow in arrows:
|
| 284 |
+
try:
|
| 285 |
+
tail, head, color, annotation = arrow.tail, arrow.head, arrow.color, arrow.annotation # type: ignore
|
| 286 |
+
except AttributeError:
|
| 287 |
+
tail, head = arrow # type: ignore
|
| 288 |
+
color = "#888"
|
| 289 |
+
annotation = ''
|
| 290 |
+
|
| 291 |
+
tail_file = chess.square_file(tail)
|
| 292 |
+
tail_rank = chess.square_rank(tail)
|
| 293 |
+
head_file = chess.square_file(head)
|
| 294 |
+
head_rank = chess.square_rank(head)
|
| 295 |
+
|
| 296 |
+
xtail = margin + (tail_file + 0.5 if not flipped else 7.5 - tail_file) * SQUARE_SIZE
|
| 297 |
+
ytail = margin + (7.5 - tail_rank if not flipped else tail_rank + 0.5) * SQUARE_SIZE
|
| 298 |
+
xhead = margin + (head_file + 0.5 if not flipped else 7.5 - head_file) * SQUARE_SIZE
|
| 299 |
+
yhead = margin + (7.5 - head_rank if not flipped else head_rank + 0.5) * SQUARE_SIZE
|
| 300 |
+
|
| 301 |
+
if (head_file, head_rank) == (tail_file, tail_rank):
|
| 302 |
+
ET.SubElement(svg, "circle", {
|
| 303 |
+
"cx": str(xhead),
|
| 304 |
+
"cy": str(yhead),
|
| 305 |
+
"r": str(SQUARE_SIZE * 0.9 / 2),
|
| 306 |
+
"stroke-width": str(SQUARE_SIZE * 0.1),
|
| 307 |
+
"stroke": color,
|
| 308 |
+
"fill": "none",
|
| 309 |
+
"opacity": "0.5",
|
| 310 |
+
"class": "circle",
|
| 311 |
+
})
|
| 312 |
+
else:
|
| 313 |
+
# marker_size = 0.75 * SQUARE_SIZE
|
| 314 |
+
# marker_margin = 0.1 * SQUARE_SIZE
|
| 315 |
+
marker_size = 0.5 * SQUARE_SIZE
|
| 316 |
+
marker_margin = 0.05 * SQUARE_SIZE
|
| 317 |
+
|
| 318 |
+
dx, dy = xhead - xtail, yhead - ytail
|
| 319 |
+
hypot = math.hypot(dx, dy)
|
| 320 |
+
|
| 321 |
+
shaft_x = xhead - dx * (marker_size + marker_margin) / hypot
|
| 322 |
+
shaft_y = yhead - dy * (marker_size + marker_margin) / hypot
|
| 323 |
+
|
| 324 |
+
xtip = xhead - dx * marker_margin / hypot
|
| 325 |
+
ytip = yhead - dy * marker_margin / hypot
|
| 326 |
+
|
| 327 |
+
x_annot = xtail + (shaft_x - xtail) / 2
|
| 328 |
+
y_annot = ytail + (shaft_y - ytail) / 2
|
| 329 |
+
|
| 330 |
+
x_annot = xhead - dx * 0.74 * SQUARE_SIZE / hypot # - (xtip - xtail)*(SQUARE_SIZE/2)
|
| 331 |
+
y_annot = yhead - dy * 0.74 * SQUARE_SIZE / hypot # - (ytip - ytail)*(SQUARE_SIZE/2)
|
| 332 |
+
|
| 333 |
+
ET.SubElement(svg, "line", {
|
| 334 |
+
"x1": str(xtail),
|
| 335 |
+
"y1": str(ytail),
|
| 336 |
+
"x2": str(shaft_x),
|
| 337 |
+
"y2": str(shaft_y),
|
| 338 |
+
"stroke": color,
|
| 339 |
+
"stroke-width": str(SQUARE_SIZE * 0.15),
|
| 340 |
+
"opacity": "0.5",
|
| 341 |
+
"stroke-linecap": "butt",
|
| 342 |
+
"class": "arrow",
|
| 343 |
+
})
|
| 344 |
+
|
| 345 |
+
marker = [(xtip, ytip),
|
| 346 |
+
(shaft_x + dy * 0.5 * marker_size / hypot,
|
| 347 |
+
shaft_y - dx * 0.5 * marker_size / hypot),
|
| 348 |
+
(shaft_x - dy * 0.5 * marker_size / hypot,
|
| 349 |
+
shaft_y + dx * 0.5 * marker_size / hypot)]
|
| 350 |
+
|
| 351 |
+
ET.SubElement(svg, "polygon", {
|
| 352 |
+
"points": " ".join(str(x) + "," + str(y) for x, y in marker),
|
| 353 |
+
"fill": color,
|
| 354 |
+
"opacity": "0.5",
|
| 355 |
+
"class": "arrow",
|
| 356 |
+
})
|
| 357 |
+
|
| 358 |
+
for arrow in arrows:
|
| 359 |
+
try:
|
| 360 |
+
tail, head, color, annotation = arrow.tail, arrow.head, arrow.color, arrow.annotation # type: ignore
|
| 361 |
+
except AttributeError:
|
| 362 |
+
tail, head = arrow # type: ignore
|
| 363 |
+
color = "#888"
|
| 364 |
+
annotation = ''
|
| 365 |
+
|
| 366 |
+
tail_file = chess.square_file(tail)
|
| 367 |
+
tail_rank = chess.square_rank(tail)
|
| 368 |
+
head_file = chess.square_file(head)
|
| 369 |
+
head_rank = chess.square_rank(head)
|
| 370 |
+
|
| 371 |
+
xtail = margin + (tail_file + 0.5 if not flipped else 7.5 - tail_file) * SQUARE_SIZE
|
| 372 |
+
ytail = margin + (7.5 - tail_rank if not flipped else tail_rank + 0.5) * SQUARE_SIZE
|
| 373 |
+
xhead = margin + (head_file + 0.5 if not flipped else 7.5 - head_file) * SQUARE_SIZE
|
| 374 |
+
yhead = margin + (7.5 - head_rank if not flipped else head_rank + 0.5) * SQUARE_SIZE
|
| 375 |
+
|
| 376 |
+
marker_size = 0.5 * SQUARE_SIZE
|
| 377 |
+
marker_margin = 0.05 * SQUARE_SIZE
|
| 378 |
+
|
| 379 |
+
dx, dy = xhead - xtail, yhead - ytail
|
| 380 |
+
hypot = math.hypot(dx, dy)
|
| 381 |
+
|
| 382 |
+
shaft_x = xhead - dx * (marker_size + marker_margin) / hypot
|
| 383 |
+
shaft_y = yhead - dy * (marker_size + marker_margin) / hypot
|
| 384 |
+
|
| 385 |
+
xtip = xhead - dx * marker_margin / hypot
|
| 386 |
+
ytip = yhead - dy * marker_margin / hypot
|
| 387 |
+
|
| 388 |
+
x_annot = xhead - dx * 0.74 * SQUARE_SIZE / hypot
|
| 389 |
+
y_annot = yhead - dy * 0.74 * SQUARE_SIZE / hypot
|
| 390 |
+
|
| 391 |
+
if annotation != '':
|
| 392 |
+
ET.SubElement(svg, "circle", {
|
| 393 |
+
"cx": str(x_annot),
|
| 394 |
+
"cy": str(y_annot),
|
| 395 |
+
"r": str(SQUARE_SIZE * 0.175),
|
| 396 |
+
#"r": str(SQUARE_SIZE * 0.2),
|
| 397 |
+
"stroke-width": str(SQUARE_SIZE * 0.01),
|
| 398 |
+
"stroke": '#000000',
|
| 399 |
+
"fill": color,
|
| 400 |
+
"opacity": "1.0",
|
| 401 |
+
"class": "circle",
|
| 402 |
+
})
|
| 403 |
+
#style = get_style("'BundledDejavuSans'", str(SQUARE_SIZE * 0.1))
|
| 404 |
+
annot = ET.SubElement(svg, "text", {
|
| 405 |
+
"x": str(x_annot),
|
| 406 |
+
"y": str(y_annot),
|
| 407 |
+
"font-size": str(SQUARE_SIZE * 0.2), # max(1, int(min(SQUARE_SIZE, SQUARE_SIZE) * 0.3))),
|
| 408 |
+
"text-anchor": "middle",
|
| 409 |
+
"dominant-baseline": "middle"
|
| 410 |
+
#"alignment-baseline": "middle"
|
| 411 |
+
})
|
| 412 |
+
annot.text = annotation
|
| 413 |
+
|
| 414 |
+
return SvgWrapper(ET.tostring(svg).decode("utf-8"))
|
svg_pieces.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from python_chess_customized_svg import piece
|
| 2 |
+
import python_chess_customized_svg as svg
|
| 3 |
+
import chess
|
| 4 |
+
import base64
|
| 5 |
+
|
| 6 |
+
#def get_svg_piece(symbol):
|
| 7 |
+
# img = piece(chess.Piece.from_symbol(symbol))
|
| 8 |
+
# svg_str = str(img)
|
| 9 |
+
# svg_byte = svg_str.encode()
|
| 10 |
+
# encoded = base64.b64encode(svg_byte)
|
| 11 |
+
# svg_piece = 'data:image/svg+xml;base64,{}'.format(encoded.decode())
|
| 12 |
+
# return svg_piece
|
| 13 |
+
|
| 14 |
+
def get_svg_board(board, focused_square_ind, only_pieces):
|
| 15 |
+
if focused_square_ind is not None:
|
| 16 |
+
squares = [focused_square_ind]
|
| 17 |
+
else:
|
| 18 |
+
squares = []
|
| 19 |
+
if board.move_stack:
|
| 20 |
+
print('board stack YES')
|
| 21 |
+
lastmove = board.peek()
|
| 22 |
+
else:
|
| 23 |
+
print('board stack NO')
|
| 24 |
+
lastmove = None
|
| 25 |
+
svg_str = str(svg.board(board, squares=squares, arrows=[], lastmove=lastmove, coordinates=False, only_pieces=only_pieces))
|
| 26 |
+
svg_byte = svg_str.encode()
|
| 27 |
+
encoded = base64.b64encode(svg_byte)
|
| 28 |
+
svg_board = 'data:image/svg+xml;base64,{}'.format(encoded.decode())
|
| 29 |
+
return svg_board
|
| 30 |
+
|
| 31 |
+
#SVG_PIECES = {piece: get_svg_piece(piece) for piece in ('b', 'k', 'n', 'p', 'q', 'r', 'B', 'K', 'N', 'P', 'Q', 'R')}
|
utils.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from leela_board import LeelaBoard
|
| 2 |
+
import chess
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def flip_move(move):
|
| 7 |
+
from_square = chess.square_mirror(chess.parse_square(move[:2]))
|
| 8 |
+
to_square = chess.square_mirror(chess.parse_square(move[2:4]))
|
| 9 |
+
promotion = move[4:] if len(move) > 4 else ""
|
| 10 |
+
return chess.square_name(from_square) + chess.square_name(to_square) + promotion
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def flip_board(fen, moves):
|
| 14 |
+
temp_board = chess.Board(fen=fen)
|
| 15 |
+
return temp_board.mirror().fen(), [flip_move(move) for move in moves]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Helper functions
|
| 19 |
+
class ChessBoard:
|
| 20 |
+
def __init__(self, fen): # Create new board from fen
|
| 21 |
+
self.board = LeelaBoard(fen=fen)
|
| 22 |
+
self.t = self.__t()
|
| 23 |
+
|
| 24 |
+
def move(self, move): # Move piece on board ("e2e3")
|
| 25 |
+
self.board.push_uci(move)
|
| 26 |
+
self.t = self.__t()
|
| 27 |
+
|
| 28 |
+
def __t(self): # Set board tensor (private method)
|
| 29 |
+
return torch.from_numpy(self.board.lcz_features()).float()
|
| 30 |
+
|
| 31 |
+
def __str__(self): # Prints board state
|
| 32 |
+
return str(self.board)
|
visualization_demo.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import marimo
|
| 2 |
+
|
| 3 |
+
__generated_with = "0.8.22"
|
| 4 |
+
app = marimo.App(width="medium")
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@app.cell
|
| 8 |
+
def __():
|
| 9 |
+
import marimo as mo
|
| 10 |
+
return (mo,)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@app.cell
|
| 14 |
+
def __():
|
| 15 |
+
import pandas as pd
|
| 16 |
+
df = pd.read_csv("our_visualization/datasets/test_set.csv")
|
| 17 |
+
df.head()
|
| 18 |
+
return df, pd
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@app.cell
|
| 22 |
+
def __():
|
| 23 |
+
import pickle
|
| 24 |
+
from utils import ChessBoard
|
| 25 |
+
import onnxruntime as ort
|
| 26 |
+
from leela_board import _idx_to_move_bn, _idx_to_move_wn
|
| 27 |
+
import numpy as np
|
| 28 |
+
from onnx2torch import convert
|
| 29 |
+
import onnx
|
| 30 |
+
import torch
|
| 31 |
+
import os
|
| 32 |
+
|
| 33 |
+
def get_models(root="/Users/sereda/Documents/chessXAI/our_visualization/models"):
|
| 34 |
+
paths = os.listdir(root)
|
| 35 |
+
model_paths = []
|
| 36 |
+
for path in paths:
|
| 37 |
+
if ".onnx" in path: model_paths.append(os.path.join(root, path))
|
| 38 |
+
return model_paths
|
| 39 |
+
|
| 40 |
+
def get_activations_from_model(model_path, pattern, fen):
|
| 41 |
+
# Write hooks for selected model path
|
| 42 |
+
def register_hooks_for_capture(model, pattern):
|
| 43 |
+
activations = {}
|
| 44 |
+
def get_activation(name):
|
| 45 |
+
def hook(module, input, output):
|
| 46 |
+
activations[name] = output.detach().numpy()
|
| 47 |
+
return hook
|
| 48 |
+
|
| 49 |
+
handles = []
|
| 50 |
+
for n, m in model.named_modules():
|
| 51 |
+
if pattern in n:
|
| 52 |
+
handle = m.register_forward_hook(get_activation(n))
|
| 53 |
+
handles.append(handle)
|
| 54 |
+
return activations, handles
|
| 55 |
+
|
| 56 |
+
# Load model and register hooks for it
|
| 57 |
+
model = convert(onnx.load(model_path))
|
| 58 |
+
act, handles = register_hooks_for_capture(model, pattern)
|
| 59 |
+
|
| 60 |
+
# Get fen and pass it through model to generate activations
|
| 61 |
+
board = ChessBoard(fen)
|
| 62 |
+
inputs = board.t
|
| 63 |
+
_, _, _ = model(inputs.unsqueeze(dim=0))
|
| 64 |
+
|
| 65 |
+
# Remove handles
|
| 66 |
+
[h.remove() for h in handles]
|
| 67 |
+
return act
|
| 68 |
+
return (
|
| 69 |
+
ChessBoard,
|
| 70 |
+
convert,
|
| 71 |
+
get_activations_from_model,
|
| 72 |
+
get_models,
|
| 73 |
+
np,
|
| 74 |
+
onnx,
|
| 75 |
+
ort,
|
| 76 |
+
os,
|
| 77 |
+
pickle,
|
| 78 |
+
torch,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@app.cell
|
| 83 |
+
def __(df, mo):
|
| 84 |
+
min_elo, max_elo = df["Rating"].min() // 100 * 100, df["Rating"].max() // 100 * 100
|
| 85 |
+
elo_list = [f"{elo}" for elo in range(min_elo, max_elo + 100, 100)]
|
| 86 |
+
dropdown_elo = mo.ui.dropdown(value = "1000", options=elo_list, label=f"Select rating in range of {min_elo} - {max_elo}")
|
| 87 |
+
dropdown_elo
|
| 88 |
+
return dropdown_elo, elo_list, max_elo, min_elo
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@app.cell
|
| 92 |
+
def __(df, dropdown_elo, mo):
|
| 93 |
+
unique_themes = set()
|
| 94 |
+
df_rated = df[(df["Rating"] >= int(dropdown_elo.value)) & (df["Rating"] <= int(dropdown_elo.value) + 100)]
|
| 95 |
+
for i in range(len(df_rated)):
|
| 96 |
+
themes = df_rated.iloc[i]["Themes"].split(" ")
|
| 97 |
+
for theme in themes: unique_themes.add(theme)
|
| 98 |
+
unique_themes_list = list(unique_themes)
|
| 99 |
+
unique_themes_list.sort()
|
| 100 |
+
|
| 101 |
+
dropdown_themes = mo.ui.dropdown(value=unique_themes_list[0], options=unique_themes_list, label=f"Select puzzle theme")
|
| 102 |
+
dropdown_themes
|
| 103 |
+
return (
|
| 104 |
+
df_rated,
|
| 105 |
+
dropdown_themes,
|
| 106 |
+
i,
|
| 107 |
+
theme,
|
| 108 |
+
themes,
|
| 109 |
+
unique_themes,
|
| 110 |
+
unique_themes_list,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@app.cell
|
| 115 |
+
def __(df_rated, dropdown_themes):
|
| 116 |
+
themes_mask = []
|
| 117 |
+
def _(themes_mask):
|
| 118 |
+
for i in range(len(df_rated)):
|
| 119 |
+
themes_new = df_rated.iloc[i]["Themes"].split(" ")
|
| 120 |
+
if dropdown_themes.value in themes_new: themes_mask.append(i)
|
| 121 |
+
_(themes_mask)
|
| 122 |
+
fens = list(df_rated.iloc[themes_mask]["FEN"])
|
| 123 |
+
df_rated.iloc[themes_mask][["FEN", "Moves", "Themes", "Rating"]]
|
| 124 |
+
return fens, themes_mask
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
@app.cell
|
| 128 |
+
def __(fens, mo):
|
| 129 |
+
dropdown_fen = mo.ui.dropdown(value = fens[0], options=fens, label="Select FEN")
|
| 130 |
+
dropdown_fen
|
| 131 |
+
return (dropdown_fen,)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@app.cell
|
| 135 |
+
def __(df_rated, dropdown_fen, mo):
|
| 136 |
+
moves = df_rated[df_rated["FEN"] == dropdown_fen.value]["Moves"].iloc[0].split(" ")
|
| 137 |
+
player_moves = moves[1::2]
|
| 138 |
+
board_moves = []
|
| 139 |
+
def _(board_moves):
|
| 140 |
+
for i in range(len(player_moves)):
|
| 141 |
+
board_moves.append(moves[:2 * i + 1])
|
| 142 |
+
_(board_moves)
|
| 143 |
+
moves_dict = {pm: om for pm, om in zip(player_moves, board_moves)}
|
| 144 |
+
dropdown_moves = mo.ui.dropdown(options=moves_dict, value=player_moves[0], label="Select which player move to look at")
|
| 145 |
+
# print(moves)
|
| 146 |
+
dropdown_moves
|
| 147 |
+
return board_moves, dropdown_moves, moves, moves_dict, player_moves
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
@app.cell
|
| 151 |
+
def __(dropdown_moves, mo):
|
| 152 |
+
dropdown_layer = mo.ui.dropdown(value="0", options=[f"{i}" for i in range(15)], label="Select layer (smaller - closer to input)")
|
| 153 |
+
focus_square = mo.ui.text_area(value=dropdown_moves.selected_key[:2], placeholder="Input square to look at (e.g. a1, b8, ...")
|
| 154 |
+
mo.vstack([dropdown_layer, focus_square])
|
| 155 |
+
return dropdown_layer, focus_square
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@app.cell
|
| 159 |
+
def __(ChessBoard, dropdown_fen, dropdown_moves):
|
| 160 |
+
def _():
|
| 161 |
+
board = ChessBoard(dropdown_fen.value)
|
| 162 |
+
for move in dropdown_moves.value:
|
| 163 |
+
print(move)
|
| 164 |
+
# board.move(move)
|
| 165 |
+
return board.board.pc_board.fen()
|
| 166 |
+
FEN = _()
|
| 167 |
+
return (FEN,)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
@app.cell
|
| 171 |
+
def __(focus_square):
|
| 172 |
+
import chess
|
| 173 |
+
from global_data import global_data
|
| 174 |
+
|
| 175 |
+
focus_square_ind = 8 * (int(focus_square.value[1]) - 1) + ord(focus_square.value[0]) - ord("a")
|
| 176 |
+
|
| 177 |
+
def set_plotting_parameters(act, layer_number, fen):
|
| 178 |
+
layer_key = [k for k in act.keys() if "0" in k][0].replace("0", f"{layer_number}")
|
| 179 |
+
print(act.keys())
|
| 180 |
+
global_data.model = 'test'
|
| 181 |
+
global_data.activations = act[layer_key][0, :, ::-1 , :]
|
| 182 |
+
print(global_data.activations.shape)
|
| 183 |
+
global_data.subplot_rows = 8
|
| 184 |
+
global_data.subplot_cols = 4
|
| 185 |
+
global_data.board = chess.Board(fen)
|
| 186 |
+
global_data.show_all_heads = True
|
| 187 |
+
# global_data.selected_head = 1
|
| 188 |
+
global_data.visualization_mode = 'ROW'
|
| 189 |
+
global_data.focused_square_ind = focus_square_ind
|
| 190 |
+
# global_data.heatmap_horizontal_gap = 0.001
|
| 191 |
+
|
| 192 |
+
global_data.visualization_mode_is_64x64 = False
|
| 193 |
+
global_data.colorscale_mode = "mode1"
|
| 194 |
+
global_data.show_colorscale = False
|
| 195 |
+
return chess, focus_square_ind, global_data, set_plotting_parameters
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@app.cell
|
| 199 |
+
def __(
|
| 200 |
+
FEN,
|
| 201 |
+
dropdown_layer,
|
| 202 |
+
get_activations_from_model,
|
| 203 |
+
get_models,
|
| 204 |
+
set_plotting_parameters,
|
| 205 |
+
):
|
| 206 |
+
# FEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
| 207 |
+
# board = ChessBoard("r1b2rk1/pp2pp1p/6p1/3Qb2q/1P4n1/2P1BN2/P2N1PPP/R4RK1 w - - 0 14")
|
| 208 |
+
# board.move("f3e5")
|
| 209 |
+
# FEN = board.board.pc_board.fen()
|
| 210 |
+
PATTERN = "mha/QK/softmax"
|
| 211 |
+
# PATTERN = "smolgen_weights"
|
| 212 |
+
MODEL = get_models()[-1]
|
| 213 |
+
ACTIVATIONS = get_activations_from_model(MODEL, PATTERN, FEN)
|
| 214 |
+
set_plotting_parameters(ACTIVATIONS, int(dropdown_layer.value), FEN)
|
| 215 |
+
from activation_heatmap import heatmap_figure
|
| 216 |
+
fig = heatmap_figure()
|
| 217 |
+
fig.update_layout(height=1500, width=1200)
|
| 218 |
+
fig
|
| 219 |
+
return ACTIVATIONS, MODEL, PATTERN, fig, heatmap_figure
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
@app.cell
|
| 223 |
+
def __():
|
| 224 |
+
# Add fens after opponents moves
|
| 225 |
+
# Default squares of interest
|
| 226 |
+
return
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
if __name__ == "__main__":
|
| 230 |
+
app.run()
|