seredapj commited on
Commit
0c536cc
·
verified ·
1 Parent(s): ec03afc

Upload 19 files

Browse files
.gitattributes CHANGED
@@ -36,3 +36,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  our_visualization/assets/DejaVuSans-Bold.ttf filter=lfs diff=lfs merge=lfs -text
37
  our_visualization/assets/DejaVuSans.ttf filter=lfs diff=lfs merge=lfs -text
38
  our_visualization/assets/DejavuSansMono-5m7L.ttf filter=lfs diff=lfs merge=lfs -text
 
 
 
 
36
  our_visualization/assets/DejaVuSans-Bold.ttf filter=lfs diff=lfs merge=lfs -text
37
  our_visualization/assets/DejaVuSans.ttf filter=lfs diff=lfs merge=lfs -text
38
  our_visualization/assets/DejavuSansMono-5m7L.ttf filter=lfs diff=lfs merge=lfs -text
39
+ assets/DejaVuSans-Bold.ttf filter=lfs diff=lfs merge=lfs -text
40
+ assets/DejaVuSans.ttf filter=lfs diff=lfs merge=lfs -text
41
+ assets/DejavuSansMono-5m7L.ttf filter=lfs diff=lfs merge=lfs -text
activation_heatmap.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chess
2
+ import dash
3
+ import plotly.graph_objs as go
4
+ from plotly.subplots import make_subplots
5
+ from global_data import global_data
6
+
7
+ from svg_pieces import get_svg_board
8
+ from dash import dcc, html, Input, Output, State
9
+
10
+
11
+ import time
12
+
13
+ import numpy as np
14
+
15
+ from plotly.io import to_json
16
+
17
+ import pickle
18
+
19
+ V_GAP = 0.15
20
+ LAYOUT_MARGIN_V = 40
21
+
22
+ def heatmap_data(head):
23
+ data = global_data.get_head_data(head)
24
+ return data
25
+
26
+
27
+ def heatmap_figure():
28
+ if global_data.model is None:
29
+ return {}
30
+ start = time.time()
31
+ fig = make_figure()
32
+ print('make fig:', time.time() - start)
33
+
34
+ start = time.time()
35
+ fig = add_heatmap_traces(fig)
36
+ print('add traces:', time.time() - start)
37
+ with open("./test_activations_starting.pkl", 'wb') as f:
38
+ print("saving activations")
39
+ pickle.dump(global_data.activations, f)
40
+ start = time.time()
41
+ fig = add_layout(fig)
42
+ print('add layout total:', time.time() - start)
43
+
44
+ start = time.time()
45
+
46
+ if global_data.selected_layer == 'Smolgen':
47
+ with open('fig_as_json_no_pieces.json', 'w') as f:
48
+ f.write(to_json(fig, pretty=True))
49
+
50
+ if not global_data.visualization_mode_is_64x64:# and global_data.selected_layer != 'Smolgen':
51
+ fig = add_pieces(fig)
52
+ print('add pieces:', time.time() - start)
53
+
54
+ if global_data.selected_layer == 'Smolgen':
55
+ with open('fig_as_json.json', 'w') as f:
56
+ f.write(to_json(fig, pretty=True))
57
+
58
+ return fig
59
+
60
+
61
+ def heatmap():
62
+ start = time.time()
63
+ # We need to recalculate graph when grid size changes, other wise layout is a mess (Dash bug?). Use hidden Div's children to trigger callback for graph recalc.
64
+ # Otherwise, we can just recalculate figure part and frontend rendering will be much faster
65
+ #
66
+ graph = html.Div(id='graph-container', children=[heatmap_graph()],
67
+ style={'height': '100%', 'width': '100%', "overflow": "auto"#, 'textAlign': 'center'#, "display": "flex", "justifyContent":"center"
68
+ })
69
+ print('GRAPH CREATION:', time.time() - start)
70
+ return graph
71
+
72
+
73
+ def heatmap_graph():
74
+ fig = heatmap_figure()
75
+
76
+ config = {
77
+ 'displaylogo': False,
78
+ 'displayModeBar': True,
79
+ 'modeBarButtonsToRemove': ['zoom', 'pan', 'select', 'zoomIn', 'zoomOut', 'autoScale', 'resetScale'],
80
+ 'toImageButtonOptions': {
81
+ 'format': global_data.export_format,
82
+ 'scale': global_data.export_scale
83
+ }}
84
+
85
+ style = {'height': global_data.figure_container_height, 'width': '100%'}#, 'margin': '0 auto'}
86
+
87
+ graph = dcc.Graph(figure=fig, id='graph', style=style,
88
+ responsive='auto',#True, # True,
89
+ config=config
90
+ )
91
+
92
+ # graph = html.Div(id='graph-container', children=[graph], style={'height': '100%', 'width': '100%', "overflow": "auto"
93
+ # })
94
+ # graph_component.children = [graph]
95
+
96
+ global_data.cache_figure(fig)
97
+
98
+ return graph
99
+
100
+
101
+ def make_figure():
102
+ #print('assumed key', global_data.subplot_rows, global_data.subplot_cols, global_data.visualization_mode_is_64x64, global_data.selected_head if not global_data.show_all_heads else -1)
103
+ #print('key', global_data.get_figure_cache_key())
104
+ #print('all keys', global_data.figure_cache.keys())
105
+ fig = global_data.get_cached_figure()
106
+ if fig is None:
107
+ if global_data.show_all_heads:
108
+ titles = [f"Head {i + 1}" for i in range(global_data.number_of_heads)]
109
+ print('MAKING SUBPLOTS', 'rows:', global_data.subplot_rows, 'cols:', global_data.subplot_cols)
110
+ print('NUMBER OF HEADS:', global_data.number_of_heads)
111
+ fig = make_subplots(rows=global_data.subplot_rows, cols=global_data.subplot_cols, subplot_titles=titles,
112
+ horizontal_spacing=global_data.heatmap_horizontal_gap / global_data.subplot_cols,
113
+ vertical_spacing=V_GAP / global_data.subplot_rows,
114
+ )
115
+ else:
116
+ print('CREATING 1x1')
117
+ titles = [f"head {global_data.selected_head +1}"]
118
+ fig = make_subplots(rows=1, cols=1, subplot_titles=titles)#go.Figure()#make_subplots(rows=1, cols=1, subplot_titles=titles)
119
+
120
+ return fig
121
+
122
+
123
+ def add_layout(fig):
124
+ start = time.time()
125
+
126
+ #coloraxis1 = None
127
+ #if global_data.visualization_mode_is_64x64:
128
+ # if global_data.colorscale_mode == '3':
129
+ # coloraxis1 = {'colorscale': 'Viridis'}
130
+
131
+ coloraxis = None
132
+ if global_data.colorscale_mode == 'mode3':
133
+ cmin = np.amin(global_data.activations[:, :, :])
134
+ cmax = np.amax(global_data.activations[:, :, :])
135
+ coloraxis = {'colorscale': 'Viridis', 'colorbar': {'ypad': 0} , 'cmin': cmin, 'cmax': cmax, 'showscale': global_data.show_colorscale}
136
+
137
+ if global_data.check_if_figure_is_cached():
138
+ print('Using existing layout')
139
+ fig.update_layout({'coloraxis1': coloraxis}, overwrite=True)
140
+ return fig
141
+
142
+ layout = go.Layout(
143
+ # title='Plot title goes here',
144
+ margin={'t': LAYOUT_MARGIN_V, 'b': LAYOUT_MARGIN_V, 'r': 40, 'l': 40},
145
+ coloraxis1=coloraxis,
146
+ modebar={'orientation': 'v'}
147
+ #coloraxis={'colorscale': 'Viridis'}
148
+ #pa
149
+ #plot_bgcolor='rgb(0,0,0)',
150
+ #paper_bgcolor="black"
151
+ )
152
+
153
+ fig.update_layout(layout)
154
+ # fig['layout'].update(layout)
155
+
156
+ print('update layout:', time.time() - start)
157
+
158
+ start = time.time()
159
+ fig = update_axis(fig)
160
+ print('update axis:', time.time() - start)
161
+ # print(fig)
162
+ return fig
163
+
164
+
165
+ def update_axis(fig):
166
+ if global_data.visualization_mode_is_64x64:
167
+ tickvals_x = list(range(0, 64, 4))
168
+ tickvals_y = list(range(3, 67, 4))#list(range(0, 64, 4))#list(range(3, 67, 4))
169
+ if global_data.board.turn or global_data.selected_layer == 'Smolgen':
170
+ ticktext_x = [x + y for x, y in zip('ae' * 8, '1122334455667788')]
171
+ #tickvals = list(range(0, 64))
172
+ #ticktext_x = [x + y for x, y in zip('abcdefg' * 8, '1'*8 + '2'*8 + '3'*8 + '4'*8 + '5'*8 + '6'*8 + '7'*8 + '8'*8)]
173
+ ticktext_y = ticktext_x[::-1]
174
+ else:
175
+ ticktext_x = [x + y for x, y in zip('ae' * 8, '1122334455667788'[::-1])]
176
+ ticktext_y = ticktext_x[::-1]
177
+ showticklabels = True
178
+ #ticklabelstep = 4
179
+ val_range = [-0.5, 63.5]
180
+ ticks = 'outside'
181
+ title_x = {'text': "Keys ('to' square)", 'standoff': 1}
182
+ title_y = {'text': "Queries ('from' square)", 'standoff': 1}
183
+ else:
184
+ title_x = None
185
+ title_y = None
186
+ tickvals_x = list(range(8)) # [0, 1, 2, 3, 4, 5, 6, 7]
187
+ tickvals_y = tickvals_x
188
+ ticktext_x = [letter for letter in 'abcdefgh']
189
+ ticktext_y = [letter for letter in '12345678']
190
+ showticklabels = True
191
+ #ticklabelstep = 1
192
+ val_range = [-0.5, 7.5]
193
+ ticks = ''
194
+
195
+ if not global_data.show_all_heads or (global_data.subplot_cols == 1 and global_data.subplot_rows == 1):
196
+ constraintowards_x = 'center'
197
+ else:
198
+ constraintowards_x = 'right'
199
+
200
+
201
+ fig.update_xaxes(title=title_x,
202
+ range=val_range,
203
+ # ticklen=50,
204
+ zeroline=False,
205
+ showgrid=False,
206
+ scaleanchor='y',
207
+ constrain='domain',
208
+ constraintoward=constraintowards_x,
209
+ ticks=ticks, # ticks,
210
+ ticktext=ticktext_x,
211
+ tickvals=tickvals_x,
212
+ showticklabels=showticklabels,
213
+ # mirror='ticks',
214
+ fixedrange=True,
215
+ #ticklabelstep=ticklabelstep,
216
+ )
217
+
218
+ fig.update_yaxes(title=title_y,
219
+ range=val_range,
220
+ zeroline=False,
221
+ showgrid=False,
222
+ scaleanchor='x',
223
+ constrain='domain',
224
+ constraintoward='top',
225
+ ticks=ticks, # ticks,
226
+ ticktext=ticktext_y,
227
+ tickvals=tickvals_y,
228
+ showticklabels=showticklabels,
229
+ # mirror='allticks',
230
+ # side='top',
231
+ fixedrange=True,
232
+ #ticklabelstep=ticklabelstep
233
+ )
234
+ return fig
235
+
236
+
237
+ def calc_colorbar(row, col):
238
+ row = global_data.subplot_rows - row + 1 #invert
239
+ #row = global_data.subplot_rows - row - 1 #invert
240
+
241
+ dy = (1/global_data.subplot_rows)
242
+ dx = (1/global_data.subplot_cols)
243
+
244
+ offset = 1/global_data.subplot_cols - 2*(global_data.heatmap_horizontal_gap/(global_data.subplot_cols))/4#global_data.colorscale_x_offset#(494.1125)/2239.2#1/global_data.subplot_cols - 3*(global_data.heatmap_horizontal_gap/(global_data.subplot_cols))/4
245
+
246
+ if global_data.heatmap_h == 0:
247
+ len = (1 - V_GAP/global_data.subplot_rows) / global_data.subplot_rows #- #V_GAP/global_data.subplot_rows
248
+ lenmode = 'fraction'
249
+ offset2 = len / 2
250
+ else:
251
+ #total_h = global_data.heatmap_fig_h * global_data.heatmap_h + (global_data.subplot_rows - 1)
252
+ len = global_data.heatmap_h/(global_data.heatmap_fig_h - 2*LAYOUT_MARGIN_V)
253
+ lenmode = 'fraction'
254
+ #offset2 = len / 2
255
+ #lenmode = 'pixels'
256
+ #offset2 = 1 - len/(global_data.subplot_rows*len + V_GAP) #1/global_data.subplot_rows - (V_GAP/global_data.subplot_rows)
257
+
258
+ #tot_h = global_data.heatmap_fig_h - 2*LAYOUT_MARGIN_V
259
+ #max_h = ((1 - V_GAP)) / global_data.subplot_rows
260
+ #cur_h = len
261
+ offset2 = 1 - (global_data.subplot_rows-1)*(dy + (V_GAP / global_data.subplot_rows)/global_data.subplot_rows) - len/2 #0#len/2 #+ (max_h - cur_h)
262
+
263
+
264
+ #offset = global_data.colorscale_x_offset
265
+ #shift = (global_data.heatmap_w + 20 + 20 + global_data.heatmap_gap)/global_data.heatmap_fig_w
266
+ #cx = (col - 1) * shift + offset
267
+
268
+
269
+ cx = (col-1)*(dx + (global_data.heatmap_horizontal_gap / global_data.subplot_cols)/global_data.subplot_cols) + offset
270
+ cy = (row-1)*(dy + (V_GAP / global_data.subplot_rows)/global_data.subplot_rows) + offset2
271
+ #cy = (global_data.subplot_rows - 1 - row) * (dy + (V_GAP / global_data.subplot_rows) / global_data.subplot_rows) + offset2
272
+
273
+ #######################
274
+
275
+
276
+ colorbar=dict(len=len, y=cy, x=cx, ypad=0, xpad=0, ticklabelposition='inside', ticks='inside', ticklen=3, lenmode=lenmode,
277
+ tickfont=dict(color='#7e807f'))
278
+
279
+ return colorbar
280
+
281
+ def add_heatmap_trace(fig, row, col, head=None):
282
+ # print('ADDING heatmap', row, col)
283
+ if head is None:
284
+ head = (row - 1) * global_data.subplot_cols + (col - 1)
285
+ data = heatmap_data(head)
286
+
287
+ if data is None:
288
+ return fig
289
+
290
+ if global_data.visualization_mode_is_64x64:
291
+ xgap, ygap = 0, 0
292
+ #hovertemplate = 'Query: <b>%{y}</b> <br> Key: <b>%{x}</b> <br> value: <b>%{z}</b><extra></extra>'
293
+ hovertemplate = 'Query: <b>%{customdata[0]}</b> <br>Key: <b>%{customdata[1]}</b> <br>value: <b>%{z:.5f}</b><extra></extra>'
294
+ if global_data.board.turn or global_data.selected_layer == 'Smolgen':
295
+ customdata_x = [[letter + ind for ind in '12345678' for letter in 'abcdefgh'] for _ in range(64)]
296
+ customdata_y = [[letter + ind for _ in range(64)] for ind in '12345678'[::-1] for letter in 'abcdefgh'[::-1]]
297
+ else:
298
+ customdata_x = [[letter + ind for ind in '12345678'[::-1] for letter in 'abcdefgh'] for _ in range(64)]
299
+ customdata_y = [[letter + ind for _ in range(64)] for ind in '12345678' for letter in 'abcdefgh'[::-1]]
300
+
301
+ #customdata_y = [[letter + ind for _ in range(64)] for ind in '12345678' for letter in 'abcdefgh']
302
+ customdata = np.moveaxis([customdata_y, customdata_x], 0, -1)#[customdata_x, customdata_y]
303
+
304
+ else:
305
+ xgap, ygap = 2, 2
306
+ hovertemplate = '<b>%{x}%{y}</b>: <b>%{z}</b><extra></extra>'
307
+ customdata = None
308
+
309
+ coloraxis = None
310
+ #Colorscale
311
+
312
+ #if global_data.visualization_mode_is_64x64:
313
+ # if global_data.colorscale_mode == '3':
314
+ # coloraxis = 'coloraxis1'
315
+
316
+ coloraxis = None
317
+ colorscale = 'Viridis'
318
+ colorbar = None
319
+ #if global_data.show_colorscale and global_data.colorscale_mode == 'mode3':
320
+ if global_data.colorscale_mode == 'mode3':
321
+ coloraxis = 'coloraxis1'
322
+ colorscale = None
323
+
324
+ elif global_data.show_colorscale and not global_data.colorscale_mode == 'mode3' and not (not global_data.show_all_heads or (global_data.subplot_cols == 1 and global_data.subplot_rows == 1)):
325
+ colorbar = calc_colorbar(row, col)
326
+
327
+ zmin, zmax = None, None
328
+
329
+ if global_data.colorscale_mode == 'mode2':
330
+ pass
331
+ zmin = np.amin(global_data.activations[head, :, :])
332
+ zmax = np.amax(global_data.activations[head, :, :])
333
+ #print('ZMINMAX M2', head, zmin, zmax)
334
+
335
+ elif global_data.colorscale_mode == 'mode1':
336
+ zmin = np.amin(data)
337
+ zmax = np.amax(data)
338
+
339
+ #print('Trace data shape', data.shape)
340
+ trace = go.Heatmap(
341
+ z=data,
342
+ colorscale=colorscale,
343
+ showscale=global_data.show_colorscale,#True,
344
+ colorbar=colorbar,
345
+ #colorbar=dict(len=len, y=cy, x=cx, ypad=0, ticklabelposition='inside', ticks='inside', ticklen=3,
346
+ # tickfont=dict(color='#7e807f')),
347
+ xgap=xgap,
348
+ ygap=ygap,
349
+ hovertemplate=hovertemplate,
350
+ customdata=customdata,
351
+ zmin=zmin,
352
+ zmax=zmax,
353
+ coloraxis=coloraxis
354
+ #zmin=zmin,
355
+ #zmax=zmax
356
+ )
357
+ fig.add_trace(trace, row=row, col=col)
358
+ return fig
359
+
360
+
361
+ def add_heatmap_traces(fig):
362
+ print('adding traces, rows:', global_data.subplot_rows, 'cols:', global_data.subplot_cols)
363
+ #adding traces is quick so we don't bother using cached values. Wipe old traces and add new.
364
+ fig.data = []
365
+ if global_data.show_all_heads:
366
+ for row in range(global_data.subplot_rows):
367
+ for col in range(global_data.subplot_cols):
368
+ fig = add_heatmap_trace(fig, row + 1, col + 1)
369
+ else:
370
+ fig = add_heatmap_trace(fig, 1, 1, global_data.selected_head)
371
+ return fig
372
+
373
+
374
+ def add_pieces(fig):
375
+ if global_data.selected_layer != 'Smolgen':
376
+ board = global_data.board
377
+
378
+ else:
379
+ board = chess.Board(fen=None) #Empty board, we want to draw only the focused square
380
+ board_svg = get_svg_board(board, global_data.focused_square_ind, True)
381
+
382
+ images = [dict(
383
+ source=board_svg,
384
+ xref="x"+str(i),
385
+ yref="y"+str(i),
386
+ x=3.5,
387
+ y=3.5,
388
+ sizex=8,
389
+ sizey=8,
390
+ xanchor='center',
391
+ yanchor='middle',
392
+ sizing="stretch",
393
+ )
394
+ for i in range(2, 2+255)
395
+ ]
396
+ images = [dict(
397
+ source=board_svg,
398
+ xref="x",
399
+ yref="y",
400
+ x=3.5,
401
+ y=3.5,
402
+ sizex=8,
403
+ sizey=8,
404
+ xanchor='center',
405
+ yanchor='middle',
406
+ sizing="stretch",
407
+ )] + images
408
+
409
+ fig.layout.images = images
410
+ return fig
411
+ board_svg = get_svg_board(board, global_data.focused_square_ind, True)
412
+ if global_data.check_if_figure_is_cached():
413
+ print('USING CACHED')
414
+ for img in fig.layout.images:
415
+ img['source'] = board_svg
416
+ else:
417
+ fig.add_layout_image(
418
+ dict(
419
+ source=board_svg,
420
+ xref="x",
421
+ yref="y",
422
+ x=3.5,
423
+ y=3.5,
424
+ sizex=8,
425
+ sizey=8,
426
+ xanchor='center',
427
+ yanchor='middle',
428
+ sizing="stretch",
429
+ ),
430
+ row='all',
431
+ col='all',
432
+ exclude_empty_subplots=True,
433
+ )
434
+
435
+ return fig
436
+
assets/DejaVuSans-Bold.ttf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6476c1b80502924294eed40894c5b18e06c181444ca953e5334262df9c27724
3
+ size 705684
assets/DejaVuSans.ttf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7da195a74c55bef988d0d48f9508bd5d849425c1770dba5d7bfc6ce9ed848954
3
+ size 757076
assets/DejavuSansMono-5m7L.ttf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dfbac4c793ca4b896c34af5c7ccbbaf37924b46cd521a5ecff9130b9c331f575
3
+ size 333636
assets/custom.css ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *[class="model-loading"][data-dash-is-loading="true"]::before{
2
+ content: "Loading model...";
3
+ display: inline-block;
4
+ color: red;
5
+ visibility: visible;
6
+ font-size: 16px;
7
+ /*transition-delay: 1s;
8
+ transition-property: font-size;
9
+ transition-duration: 200ms;*/
10
+ }
11
+
12
+ .completely-hidden{
13
+ display: none;
14
+ }
15
+
16
+ .hidden-but-reserve-space{
17
+ visibility: hidden;
18
+ }
19
+
20
+ #link{
21
+ text-decoration: underline;
22
+ cursor: pointer;
23
+ }
24
+
25
+ .header-container {
26
+ border-bottom: thin lightgrey solid;
27
+ box-sizing: border-box;
28
+ white-space: nowrap;
29
+ overflow-y: auto;
30
+ }
31
+
32
+ .header-control-container {
33
+ margin-left: 20px;
34
+ margin-right: 20px;
35
+ }
36
+
37
+ body {
38
+ margin: 0;
39
+ padding: 0;
40
+ font-family: 'BundledDejavuSans';
41
+ font-size: 14px;
42
+ -moz-user-select: none;
43
+ }
44
+
45
+ @font-face {
46
+ font-family: 'BundledDejavuSansMono';
47
+ src: url('DejavuSansMono-5m7L.ttf');
48
+ }
49
+
50
+ @font-face {
51
+ font-family: 'BundledDejavuSans';
52
+ src: url('DejaVuSans.ttf');
53
+ }
54
+
55
+ @font-face {
56
+ font-family: 'BundledDejavuSans';
57
+ src: url('DejaVuSans-Bold.ttf');
58
+ font-weight: bold;
59
+ }
assets/favicon.ico ADDED
board2planes.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Author: https://github.com/Arcturai
2
+
3
+ import chess
4
+ import numpy as np
5
+ import re
6
+
7
+
8
+ WPAWN = chess.Piece(chess.PAWN, chess.WHITE)
9
+ WKNIGHT = chess.Piece(chess.KNIGHT, chess.WHITE)
10
+ WBISHOP = chess.Piece(chess.BISHOP, chess.WHITE)
11
+ WROOK = chess.Piece(chess.ROOK, chess.WHITE)
12
+ WQUEEN = chess.Piece(chess.QUEEN, chess.WHITE)
13
+ WKING = chess.Piece(chess.KING, chess.WHITE)
14
+ BPAWN = chess.Piece(chess.PAWN, chess.BLACK)
15
+ BKNIGHT = chess.Piece(chess.KNIGHT, chess.BLACK)
16
+ BBISHOP = chess.Piece(chess.BISHOP, chess.BLACK)
17
+ BROOK = chess.Piece(chess.ROOK, chess.BLACK)
18
+ BQUEEN = chess.Piece(chess.QUEEN, chess.BLACK)
19
+ BKING = chess.Piece(chess.KING, chess.BLACK)
20
+
21
+
22
+ def assign_piece2(planes, piece_step, row, col):
23
+ planes[piece_step][row][col] = 1
24
+
25
+
26
+ DISPATCH2 = {}
27
+
28
+ DISPATCH2[str(WPAWN)] = lambda retval, row, col: assign_piece2(retval, 0, row, col)
29
+ DISPATCH2[str(WKNIGHT)] = lambda retval, row, col: assign_piece2(retval, 1, row, col)
30
+ DISPATCH2[str(WBISHOP)] = lambda retval, row, col: assign_piece2(retval, 2, row, col)
31
+ DISPATCH2[str(WROOK)] = lambda retval, row, col: assign_piece2(retval, 3, row, col)
32
+ DISPATCH2[str(WQUEEN)] = lambda retval, row, col: assign_piece2(retval, 4, row, col)
33
+ DISPATCH2[str(WKING)] = lambda retval, row, col: assign_piece2(retval, 5, row, col)
34
+ DISPATCH2[str(BPAWN)] = lambda retval, row, col: assign_piece2(retval, 6, row, col)
35
+ DISPATCH2[str(BKNIGHT)] = lambda retval, row, col: assign_piece2(retval, 7, row, col)
36
+ DISPATCH2[str(BBISHOP)] = lambda retval, row, col: assign_piece2(retval, 8, row, col)
37
+ DISPATCH2[str(BROOK)] = lambda retval, row, col: assign_piece2(retval, 9, row, col)
38
+ DISPATCH2[str(BQUEEN)] = lambda retval, row, col: assign_piece2(retval, 10, row, col)
39
+ DISPATCH2[str(BKING)] = lambda retval, row, col: assign_piece2(retval, 11, row, col)
40
+
41
+
42
+ def append_plane(planes, ones):
43
+ if ones:
44
+ return np.append(planes, np.ones((1, 8, 8), dtype=np.float), axis=0)
45
+ else:
46
+ return np.append(planes, np.zeros((1, 8, 8), dtype=np.float), axis=0)
47
+
48
+
49
+ def fill_planes(board):
50
+ planes = np.zeros((12, 8, 8), dtype=np.float)
51
+ for row in range(8):
52
+ for col in range(8):
53
+ piece = str(board.piece_at(chess.SQUARES[row * 8 + col]))
54
+ if piece != "None":
55
+ DISPATCH2[piece](planes, row, col)
56
+ planes = append_plane(planes, board.is_repetition(2))
57
+ return planes
58
+
59
+
60
+ def board2planes(board_):
61
+ if not board_.turn:
62
+ board = board_.mirror()
63
+ else:
64
+ board = board_
65
+
66
+ retval = fill_planes(board)
67
+
68
+ s_board = board_.copy()
69
+ for i in range(7):
70
+ if s_board.move_stack.__len__() > 0:
71
+ s_board.pop()
72
+ b = s_board.mirror() if not board_.turn else s_board.copy()
73
+ retval = np.append(retval, fill_planes(b), axis=0)
74
+ else:
75
+ retval = np.append(retval, np.zeros((13, 8, 8), dtype=np.float), axis=0)
76
+
77
+ retval = append_plane(retval, bool(board.castling_rights & chess.BB_H1))
78
+ retval = append_plane(retval, bool(board.castling_rights & chess.BB_A1))
79
+ retval = append_plane(retval, bool(board.castling_rights & chess.BB_H8))
80
+ retval = append_plane(retval, bool(board.castling_rights & chess.BB_A8))
81
+ retval = append_plane(retval, not board_.turn)
82
+ retval = np.append(retval, np.full((1, 8, 8), fill_value=board_.halfmove_clock/99., dtype=np.float), axis=0)
83
+ retval = append_plane(retval, False)
84
+ retval = append_plane(retval, True)
85
+
86
+ return np.expand_dims(retval, axis=0)
87
+
88
+
89
+ def bulk_board2planes(boards):
90
+ planes = []
91
+ for b in boards:
92
+ temp = board2planes(b)
93
+ planes.append(temp)
94
+ pl = tuple(planes)
95
+ retval = np.concatenate(pl, axis=0)
96
+ return retval
97
+
constants.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+
5
+ def root_directory():
6
+ if getattr(sys, 'frozen', False):
7
+ # The application is frozen
8
+ root = os.path.dirname(sys.executable)
9
+ else:
10
+ # The application is not frozen
11
+ root = os.path.dirname(__file__) # os.path.dirname(os.path.abspath(__file__))#os.path.dirname(__file__)
12
+ return (root)
13
+
14
+
15
+ ROOT_DIR = root_directory()
16
+
17
+ LEFT_PANE_WIDTH = 90
18
+ RIGHT_PANE_WIDTH = 100 - LEFT_PANE_WIDTH
19
+ GRAPH_PANE_HEIGHT = 100
20
+ HEADER_HEIGHT = 11
21
+ CONTENT_HEIGHT = 100 - HEADER_HEIGHT
22
+
23
+ EXPORT_FORMAT = 'png' #one of png, svg, jpeg, webp
24
+ EXPORT_SCALE = 1.0 #When 1.0, the figure is exported as same size as currently in the browser. Use e.g. 0.5 to scale to half.
datasets/test_set.csv ADDED
The diff for this file is too large to render. See raw diff
 
global_data.py ADDED
@@ -0,0 +1,539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chess.engine
2
+ from constants import ROOT_DIR, CONTENT_HEIGHT, LEFT_PANE_WIDTH, EXPORT_FORMAT, EXPORT_SCALE
3
+ from time import sleep
4
+ # from test_array import activations_array
5
+
6
+ from copy import deepcopy
7
+
8
+ from board2planes import board2planes
9
+
10
+ import yaml
11
+
12
+ import os
13
+ from os.path import isdir, join
14
+ import sys
15
+
16
+ import numpy as np
17
+
18
+
19
+ SIMULATE_TF = False #TODO: Remove this option, deprecated
20
+ # turn off tensorflow importing and generate random data to speed up development
21
+ DEV_MODE = False
22
+ SIMULATED_LAYERS = 6
23
+ SIMULATED_HEADS = 64
24
+ FIXED_ROW = None # 1 #None to disable
25
+ FIXED_COL = None # 5 #None to disable
26
+ if DEV_MODE:
27
+ class DummyModel:
28
+ def __init__(self, layers, heads):
29
+ self.layers = layers
30
+ self.heads = heads
31
+
32
+ def __call__(self, *args, **kwargs):
33
+ data = [np.random.rand(1, self.heads, 64, 64) for i in range(self.layers)]
34
+ return [None, None, None, data]
35
+
36
+ else:
37
+ import tensorflow as tf
38
+ from tensorflow.compat.v1 import ConfigProto
39
+ from tensorflow.compat.v1 import InteractiveSession
40
+
41
+
42
+ # class to hold data, state and configurations
43
+ # Dash is stateless and in general it is very bad idea to store data in global variables on server side
44
+ # However, this application is ment to be run by single user on local machine, so it is safe to store data and state
45
+ # information on global object
46
+ class GlobalData:
47
+ def __init__(self):
48
+ import os
49
+ if not DEV_MODE:
50
+ # import tensorflow as tf
51
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
52
+
53
+ # from tensorflow.compat.v1 import ConfigProto
54
+ # from tensorflow.compat.v1 import InteractiveSession
55
+ # import chess
56
+ # import matplotlib.patheffects as path_effects
57
+
58
+ #config = ConfigProto()
59
+ #config.gpu_options.allow_growth = True
60
+ #session = InteractiveSession(config=config)
61
+ #tf.keras.backend.clear_session()
62
+
63
+ self.tmp = 0
64
+ self.export_format = EXPORT_FORMAT
65
+ self.export_scale = EXPORT_SCALE
66
+ self.fen = 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1' # '2kr3r/ppp2b2/2n4p/4p3/Q2Pq1pP/2P1N3/PP3PP1/R1B1KB1R w KQ - 3 18'#'6n1/1p1k4/3p4/pNp5/P1P4p/7P/1P4KP/r7 w - - 2 121'#
67
+ self.board = chess.Board(fen=self.fen)
68
+ self.focused_square_ind = 0
69
+ self.active_move_table_cell = None # tuple (row_ind, col_id), e.g. (12, 'White')
70
+
71
+ self.activations = None # activations_array
72
+ self.visualization_mode = 'ROW'
73
+ self.visualization_mode_is_64x64 = False
74
+ self.subplot_mode = 'big' #'fit' # big'#'fit'#, 'big'
75
+ self.subplot_cols = 0
76
+ self.subplot_rows = 0
77
+ self.number_of_heads = 0
78
+ self.selected_head = None
79
+ self.show_all_heads = True
80
+
81
+ self.show_colorscale = False
82
+ self.colorscale_mode = 'mode1'
83
+
84
+ self.figure_container_height = '100%' # '100%'
85
+
86
+ self.running_counter = 0 # used to pass new values to hidden indicator elements which will trigger follow-up callback
87
+ self.grid_has_changed = False
88
+
89
+ # self.has_subplot_grid_changed = True
90
+ # self.figure_layout_images = None #store layout and only recalculate when subplot grid has changed
91
+ # self.figure_layout_annotations = None
92
+ # self.need_update_axis = True
93
+
94
+ self.screen_w = 0
95
+ self.screen_h = 0
96
+ self.figure_w = 0
97
+ self.figure_h = 0
98
+ self.heatmap_w = 0
99
+ self.heatmap_h = 0
100
+ self.heatmap_fig_w = 0
101
+ self.heatmap_fig_h = 0
102
+ self.heatmap_gap = 0
103
+ self.colorscale_x_offset = 0
104
+
105
+ self.heatmap_horizontal_gap = 0.275
106
+
107
+ self.figure_cache = {}
108
+
109
+ self.update_grid_shape()
110
+
111
+ self.pgn_data = [] # list of boards in pgn
112
+ self.move_table_boards = {} # dict of boards in pgn, key is (move_table.row_ind, move_table.column_id)
113
+
114
+ if not SIMULATE_TF:
115
+ self.selected_layer = None
116
+ else:
117
+ self.selected_layer = 0
118
+
119
+ self.nr_of_layers_in_body = -1
120
+ self.has_attention_policy = False
121
+
122
+ self.model_paths = []
123
+ self.model_names = []
124
+ self.model_yamls = {} #key = model path, value = yaml of that model
125
+ self.model_cache = {}
126
+ self.find_models2()
127
+ self.model_path = None#self.model_paths[0] # '/home/jusufe/PycharmProjects/lc0-attention-visualizer/T12_saved_model_1M'
128
+ self.model = None
129
+ self.tfp = None #TensorflowProcess
130
+ if not SIMULATE_TF:
131
+ self.load_model()
132
+ self.activations_data = None
133
+
134
+ if self.model is not None or SIMULATE_TF:
135
+ self.update_activations_data()
136
+
137
+ if self.selected_layer is not None:
138
+ self.set_layer(self.selected_layer)
139
+
140
+ self.move_table_active_cell = None
141
+
142
+ self.force_update_graph = False
143
+
144
+ def set_subplot_mode(self, fit_to_page):
145
+ if fit_to_page == [True]:
146
+ self.subplot_mode = 'fit'
147
+ else:
148
+ self.subplot_mode = 'big'
149
+ self.update_grid_shape()
150
+
151
+ def set_screen_size(self, w, h):
152
+ self.screen_w = w
153
+ self.screen_h = h
154
+
155
+ self.figure_w = w*LEFT_PANE_WIDTH/100
156
+ self.figure_h = h*CONTENT_HEIGHT/100
157
+ print('GRAPH AREA', self.figure_w, self.figure_h)
158
+
159
+ def set_heatmap_size(self, size):
160
+ if size != '1':
161
+ #print('-----------------------HEATMAP SIZE', size)
162
+ # w, h = size
163
+ # print('TYETETETETEU', global_data.screen_w)
164
+ # global_data.set_screen_size(w, h)
165
+ #print('>>>>>: HEATMAP WIDTH', size[0])
166
+ #print('>>>>>: HEATMAP HEIGHT', size[1])
167
+ #print('>>>>>: FIG WIDTH', size[2])
168
+ #print('>>>>>: FIG HEIGHT', size[3])
169
+ #print('>>>>>: HEATMAP GAP', size[4])
170
+
171
+ self.heatmap_w = float(size[0])
172
+ self.heatmap_h = float(size[1])
173
+ self.heatmap_fig_w = float(size[2])
174
+ self.heatmap_fig_h = float(size[3])
175
+ self.heatmap_gap = round(float(size[4]), 2)
176
+
177
+ self.colorscale_x_offset = float(size[5])/self.heatmap_fig_w
178
+
179
+ if size[6] == 1:
180
+ self.force_update_graph = True
181
+ else:
182
+ self.force_update_graph = False
183
+
184
+
185
+ #if self.heatmap_gap < 30:
186
+ # self.heatmap_horizontal_gap += 0.025
187
+
188
+ # self.heatmap_horizontal_gap = min(0.25, self.heatmap_horizontal_gap)
189
+ #if self.heatmap_gap < 200:
190
+ # self.heatmap_horizontal_gap += -0.025
191
+ # self.heatmap_horizontal_gap = max(0.1, self.heatmap_horizontal_gap)
192
+
193
+ def set_colorscale_mode(self, mode, colorscale_mode, colorscale_mode_64x64, show):
194
+ if mode == '64x64':
195
+ self.colorscale_mode = colorscale_mode_64x64
196
+ else:
197
+ self.colorscale_mode = colorscale_mode
198
+ #print('SHOW value', show)
199
+ self.show_colorscale = show == [True]
200
+
201
+ def cache_figure(self, fig):
202
+ if not self.check_if_figure_is_cached() and fig != {}:
203
+ key = self.get_figure_cache_key()
204
+ cached_fig = deepcopy(fig)
205
+ cached_fig.update_layout({'coloraxis1': None}, overwrite=True)
206
+ #print('CACHING FIGURE:')
207
+ self.figure_cache[key] = cached_fig
208
+
209
+ def get_cached_figure(self):
210
+ if self.check_if_figure_is_cached():
211
+ key = self.get_figure_cache_key()
212
+ fig = deepcopy(self.figure_cache[key])
213
+ else:
214
+ fig = None
215
+ return fig
216
+
217
+ def check_if_figure_is_cached(self):
218
+ key = self.get_figure_cache_key()
219
+ return key in self.figure_cache
220
+
221
+ def get_figure_cache_key(self):
222
+ return (self.subplot_rows, self.subplot_cols, self.visualization_mode_is_64x64,
223
+ self.selected_head if not self.show_all_heads else -1, self.show_colorscale, self.colorscale_mode,
224
+ self.board.turn)
225
+ #return (self.subplot_rows, self.subplot_cols, self.visualization_mode_is_64x64, self.selected_head if not self.show_all_heads else -1, self.heatmap_horizontal_gap, self.heatmap_fig_h, self.heatmap_fig_w)
226
+ #return (self.subplot_rows, self.subplot_cols, self.visualization_mode_is_64x64, self.show_all_heads)
227
+
228
+ def get_side_to_move(self):
229
+ return ['Black', 'White'][self.board.turn]
230
+
231
+ def load_model(self):
232
+ if self.model_path in self.model_cache:
233
+ self.model, self.tfp = self.model_cache[self.model_path]
234
+
235
+ elif self.model_path is not None:
236
+ #net = '/home/jusufe/Projects/lc0/BT1024-3142c-swa-186000.pb.gz'
237
+ #yaml_path = '/home/jusufe/Downloads/cfg.yaml'
238
+ if not DEV_MODE:
239
+ net = self.model_path
240
+ yaml_path = self.model_yamls[self.model_path]
241
+ with open(yaml_path) as f:
242
+ cfg = f.read()
243
+ cfg = yaml.safe_load(cfg)
244
+
245
+ if 'dropout_rate' in cfg['model']:
246
+ print('Setting dropout_rate to 0.0')
247
+ cfg['model']['dropout_rate'] = 0.0
248
+
249
+ tfp = tfprocess.TFProcess(cfg)
250
+ tfp.init_net()
251
+ tfp.replace_weights(net, ignore_errors=True)
252
+ self.model = tfp.model
253
+ self.tfp = tfp
254
+ else:
255
+ self.model = DummyModel(SIMULATED_LAYERS, SIMULATED_HEADS)
256
+ self.tfp = None
257
+
258
+ else:
259
+ self.model = None
260
+ self.tfp = None
261
+
262
+ def find_models(self):
263
+ root = ROOT_DIR
264
+ models_root_folder = os.path.join(root, 'models')
265
+ model_folders = [f for f in os.listdir(models_root_folder) if isdir(join(models_root_folder, f))]
266
+ model_paths = [os.path.relpath(join(models_root_folder, f)) for f in os.listdir(models_root_folder) if
267
+ isdir(join(models_root_folder, f))]
268
+ self.model_names = model_folders
269
+ self.model_paths = model_paths
270
+
271
+ #print('MODELS:')
272
+ #print(self.model_names)
273
+ #print(self.model_paths)
274
+
275
+ def find_models2(self):
276
+ import os
277
+ from os.path import isdir, join
278
+ root = ROOT_DIR
279
+ models_root_folder = os.path.join(root, 'models')
280
+ model_folders = [f for f in os.listdir(models_root_folder) if isdir(join(models_root_folder, f))]
281
+ model_paths = [os.path.relpath(join(models_root_folder, f)) for f in os.listdir(models_root_folder) if
282
+ isdir(join(models_root_folder, f))]
283
+
284
+ models = []
285
+ paths = []
286
+ yamls = []
287
+ for path in model_paths:
288
+ yaml_files = [file for file in os.listdir(path) if file.endswith(".yaml")]
289
+ if len(yaml_files) != 1:
290
+ continue
291
+ model_files = [file for file in os.listdir(path) if file.endswith(".pb.gz")]
292
+ if len(model_files) == 0:
293
+ continue
294
+
295
+ models += model_files
296
+ paths += [os.path.relpath(join(path, f)) for f in model_files]
297
+ yaml_file = os.path.relpath(join(path, yaml_files[0]))
298
+ yamls += [yaml_file]*len(model_files)
299
+
300
+ self.model_yamls = {path: yaml_file for path, yaml_file in zip(paths, yamls)}
301
+ self.model_names = models
302
+ self.model_paths = paths#model_paths
303
+
304
+
305
+ def update_activations_data(self):
306
+
307
+ if self.model is not None and self.selected_layer is None:
308
+ self.selected_layer = 0
309
+
310
+ if not SIMULATE_TF:
311
+ if self.selected_layer is not None and self.model is not None and self.selected_layer != 'Smolgen':
312
+ if not DEV_MODE:
313
+ inputs = board2planes(self.board)
314
+ inputs = tf.reshape(tf.convert_to_tensor(inputs, dtype=tf.float32), [-1, 112, 8, 8])
315
+ else:
316
+ inputs = None
317
+
318
+ outputs = self.model(inputs)
319
+ self.activations_data = outputs[-1]
320
+ for i,x in enumerate(self.activations_data):
321
+ print( 'LAYERS', i, x.shape)
322
+
323
+ #smolgen = self.tfp.smol_weight_gen_dense.get_weights()[0].reshape((256, 64, 64))
324
+ #print('Smolgen')
325
+ #print(type(smolgen))
326
+ #print(smolgen.shape)
327
+ #print(type(smolgen[0]))
328
+ #print(smolgen[0].shape)
329
+ #_, _, _, self.activations_data = self.model(inputs)
330
+ elif self.selected_layer == 'Smolgen' and self.tfp is not None and self.tfp.use_smolgen:
331
+ weights = self.tfp.smol_weight_gen_dense.get_weights()[0]
332
+ self.activations_data = weights.reshape((weights.shape[0], 64, 64))
333
+ print('TYPEEEEE', type(self.activations_data))
334
+
335
+ else:
336
+ layers = SIMULATED_LAYERS
337
+ heads = SIMULATED_HEADS
338
+ self.activations_data = [np.random.rand(1, heads, 64, 64) for i in range(layers)]
339
+
340
+ if self.model is not None:
341
+
342
+ if self.model_path not in self.model_cache:
343
+ self.model_cache[self.model_path] = [self.model, self.tfp]
344
+
345
+ self.update_layers_in_body_count()
346
+
347
+ #TODO: figure out better way to determine if we have policy attention weights
348
+ #TODO: What happens if policy vis is selected and user switches to model without policy layer? Take care of this case.
349
+ if self.activations_data is not None and self.activations_data[-2].shape == (1, 8, 24):
350
+ self.has_attention_policy = True
351
+ else:
352
+ self.has_attention_policy = False
353
+ # self.update_selected_activation_data()
354
+ # self.activations = self.activations_data[self.selected_layer]
355
+
356
+ def update_grid_shape(self):
357
+ # TODO: add client side callback triggered by Interval component to save window or precise container dimensions to Div
358
+ # TODO: Trigger server side figure update callback when dimensions are recorded and store in global_data
359
+ # TODO: If needed, recalculate subplot rows and cols and container scaler based on the changed dimension
360
+
361
+ def calc_cols(heads, rows):
362
+ if heads % rows == 0:
363
+ cols = int(heads / rows)
364
+ else:
365
+ cols = int(1 + heads / rows)
366
+ return cols
367
+
368
+ if FIXED_ROW and FIXED_COL:
369
+ self.subplot_cols = FIXED_COL
370
+ self.subplot_rows = FIXED_ROW
371
+ return None
372
+
373
+ heads = self.number_of_heads
374
+ if self.subplot_mode == 'fit':
375
+ max_rows_in_screen = 4
376
+ if heads <= 4:
377
+ rows = 1
378
+ elif heads <= 8:
379
+ rows = 2
380
+ else:
381
+ rows = heads // 8 + int(heads % 8 != 0)
382
+
383
+ elif self.subplot_mode == 'big':
384
+ #print(heads)
385
+
386
+ max_rows_in_screen = 2
387
+ rows = heads // 4 + int(heads % 4 != 0)
388
+ #print(rows)
389
+
390
+ if rows > max_rows_in_screen:
391
+ container_height = f'{int((rows / max_rows_in_screen) * 100)}%'
392
+ else:
393
+ container_height = '100%'
394
+
395
+ if rows != 0:
396
+ cols = calc_cols(heads, rows)
397
+ else:
398
+ cols = 0
399
+
400
+ if self.subplot_rows != rows or self.subplot_cols != cols:
401
+ self.grid_has_changed = True
402
+
403
+ self.subplot_cols = cols
404
+ self.subplot_rows = rows
405
+
406
+ if self.show_all_heads:
407
+ self.figure_container_height = container_height
408
+ else:
409
+ self.figure_container_height = '100%'
410
+
411
+ def update_selected_activation_data(self):
412
+ # import numpy as np
413
+ # self.activations = activations_array + np.random.rand(8, 64, 64)
414
+ if self.activations_data is not None:
415
+ if self.selected_layer not in ('Policy', 'Smolgen'):
416
+ if not DEV_MODE:
417
+ activations = tf.squeeze(self.activations_data[self.selected_layer], axis=0).numpy()
418
+ #self.activations = activations[:, ::-1, :] #Flip along y-axis
419
+ else:
420
+ activations = np.squeeze(self.activations_data[self.selected_layer], axis=0)
421
+ elif self.selected_layer == 'Policy':
422
+ print('RAW POLICY SHAPE', self.activations_data[-1].shape)
423
+ activations = self.activations_data[-1].numpy()
424
+ #print('POLICY SHAPE', activations.shape)
425
+
426
+ #print('RAW POLICY SHAPE', self.activations_data[-1].shape)
427
+ #activations = np.squeeze(self.activations_data[-1].numpy(), axis=0) #shape 64,64
428
+ #promo = np.squeeze(self.activations_data[-2].numpy(), axis=0) #shape 8,24
429
+ #print('promo shape:', promo.shape)
430
+ #if self.board.turn:
431
+ # pad_shape = (48, 8)
432
+ #else:
433
+ # pad_shape = (8, 48)
434
+ #promo_padded = np.pad(promo, (pad_shape, (0, 0)), mode='constant', constant_values=None) #shape 64,24
435
+ #self.activations = np.expand_dims(np.concatenate((activations, promo_padded), axis=1), axis=0)#shape 1,64,88
436
+ #print('POLICY SHAPE', self.activations.shape)
437
+ elif self.selected_layer == 'Smolgen':
438
+ activations = self.tfp.smol_weight_gen_dense.get_weights()[0].reshape((256, 64, 64))
439
+
440
+ self.activations = activations[:, ::-1, :] # Flip along y-axis
441
+
442
+ def set_visualization_mode(self, mode):
443
+ self.visualization_mode = mode
444
+ self.visualization_mode_is_64x64 = mode == '64x64'
445
+
446
+ def set_layer(self, layer):
447
+ self.selected_layer = layer
448
+ self.update_selected_activation_data()
449
+ if layer not in ('Policy', 'Smolgen'):
450
+ self.number_of_heads = self.activations_data[self.selected_layer].shape[1]
451
+ elif layer == 'Policy':
452
+ self.number_of_heads = 1
453
+ elif layer == 'Smolgen':
454
+ self.number_of_heads = self.activations.shape[0]
455
+ self.set_head(0)
456
+ self.update_grid_shape()
457
+
458
+ def set_head(self, head):
459
+ self.selected_head = head
460
+
461
+ def set_model(self, model):
462
+ if model != self.model_path:
463
+ self.model_path = model
464
+ self.load_model()
465
+ self.update_activations_data()
466
+ self.update_selected_activation_data()
467
+ self.number_of_heads = self.activations_data[self.selected_layer].shape[1]
468
+ if self.selected_head is None:
469
+ self.selected_head = 0
470
+ else:
471
+ self.selected_head = min(self.selected_head, self.number_of_heads - 1)
472
+ self.update_grid_shape()
473
+ if SIMULATE_TF:
474
+ sleep(2)
475
+
476
+ def update_layers_in_body_count(self):
477
+ # TODO: figure out robust way to separate attention layers in body from the rest. UPDATE: Use yaml
478
+ heads = self.activations_data[0].shape[1]
479
+ for ind, layer in enumerate(self.activations_data):
480
+ if layer.shape[1] != heads or len(layer.shape) != 4:
481
+ ind = ind - 1
482
+ break
483
+ self.nr_of_layers_in_body = ind + 1
484
+ if self.selected_layer not in ('Policy', 'Smolgen'):
485
+ self.selected_layer = min(self.selected_layer, self.nr_of_layers_in_body - 1)
486
+
487
+ def get_head_data(self, head):
488
+
489
+ if self.activations.shape[0] <= head:
490
+ return None
491
+
492
+ if self.visualization_mode == '64x64':
493
+ # print('64x64 selection')
494
+ data = self.activations[head, :, :]
495
+
496
+ elif self.visualization_mode == 'ROW':
497
+ # print('ROW selection')
498
+ if self.board.turn or self.selected_layer == 'Smolgen': #White turn to move
499
+ row = 63 - self.focused_square_ind
500
+ data = self.activations[head, row, :].reshape((8, 8))
501
+ else:
502
+ #row = self.focused_square_ind
503
+ multiples = self.focused_square_ind // 8
504
+ remainder = self.focused_square_ind % 8
505
+
506
+ a = 7 - remainder
507
+ b = multiples * 8
508
+ row = a + b
509
+ data = self.activations[head, row, :].reshape((8, 8))[::-1, :]
510
+ else:
511
+ # print('COL selection')
512
+ if self.board.turn or self.selected_layer == 'Smolgen': #White turn to move
513
+ col = self.focused_square_ind
514
+ data = self.activations[head, :, col].reshape((8, 8))[::-1, ::-1]
515
+ else:
516
+ focused = 63 - self.focused_square_ind
517
+ multiples = focused // 8
518
+ remainder = focused % 8
519
+ a = 7 - remainder
520
+ b = multiples * 8
521
+ col = a + b
522
+ #print('COL!!!!!!!!!!!!!!!!!', col, a, b, focused, self.focused_square_ind)
523
+ data = self.activations[head, :, col].reshape((8, 8))[:, ::-1]
524
+ return data
525
+
526
+ def set_fen(self, fen):
527
+ self.board.set_fen(fen)
528
+ self.fen = fen
529
+ self.update_activations_data()
530
+ self.update_selected_activation_data()
531
+
532
+ def set_board(self, board):
533
+ self.board = deepcopy(board)
534
+ self.update_activations_data()
535
+ self.update_selected_activation_data()
536
+
537
+
538
+ global_data = GlobalData()
539
+ print('global data created')
layouts/visualization_demo.slides.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "type": "slides",
3
+ "data": {}
4
+ }
leela_board.py ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # uci_to_idx is a list of four dicts {uci -> NN policy index}
4
+ # 0 = white, no-castling
5
+ # 1 = white, castling
6
+ # 2 = black, no-castling
7
+ # 3 = black, castling
8
+ # Black moves are flipped, and castling moves are mapped to the e8a8, e8h8, e1a1, e1h1 indexes
9
+ # from their respective UCI names
10
+
11
+ uci_to_idx = []
12
+
13
+ # The index-to-uci list originates from here:
14
+ # https://github.com/glinscott/leela-chess/blob/master/lc0/src/chess/bitboard.cc
15
+
16
+ # White, no-castling
17
+ _idx_to_move_wn = [
18
+ 'a1b1', 'a1c1', 'a1d1', 'a1e1', 'a1f1', 'a1g1', 'a1h1',
19
+ 'a1a2', 'a1b2', 'a1c2', 'a1a3', 'a1b3', 'a1c3', 'a1a4',
20
+ 'a1d4', 'a1a5', 'a1e5', 'a1a6', 'a1f6', 'a1a7', 'a1g7',
21
+ 'a1a8', 'a1h8', 'b1a1', 'b1c1', 'b1d1', 'b1e1', 'b1f1',
22
+ 'b1g1', 'b1h1', 'b1a2', 'b1b2', 'b1c2', 'b1d2', 'b1a3',
23
+ 'b1b3', 'b1c3', 'b1d3', 'b1b4', 'b1e4', 'b1b5', 'b1f5',
24
+ 'b1b6', 'b1g6', 'b1b7', 'b1h7', 'b1b8', 'c1a1', 'c1b1',
25
+ 'c1d1', 'c1e1', 'c1f1', 'c1g1', 'c1h1', 'c1a2', 'c1b2',
26
+ 'c1c2', 'c1d2', 'c1e2', 'c1a3', 'c1b3', 'c1c3', 'c1d3',
27
+ 'c1e3', 'c1c4', 'c1f4', 'c1c5', 'c1g5', 'c1c6', 'c1h6',
28
+ 'c1c7', 'c1c8', 'd1a1', 'd1b1', 'd1c1', 'd1e1', 'd1f1',
29
+ 'd1g1', 'd1h1', 'd1b2', 'd1c2', 'd1d2', 'd1e2', 'd1f2',
30
+ 'd1b3', 'd1c3', 'd1d3', 'd1e3', 'd1f3', 'd1a4', 'd1d4',
31
+ 'd1g4', 'd1d5', 'd1h5', 'd1d6', 'd1d7', 'd1d8', 'e1a1',
32
+ 'e1b1', 'e1c1', 'e1d1', 'e1f1', 'e1g1', 'e1h1', 'e1c2',
33
+ 'e1d2', 'e1e2', 'e1f2', 'e1g2', 'e1c3', 'e1d3', 'e1e3',
34
+ 'e1f3', 'e1g3', 'e1b4', 'e1e4', 'e1h4', 'e1a5', 'e1e5',
35
+ 'e1e6', 'e1e7', 'e1e8', 'f1a1', 'f1b1', 'f1c1', 'f1d1',
36
+ 'f1e1', 'f1g1', 'f1h1', 'f1d2', 'f1e2', 'f1f2', 'f1g2',
37
+ 'f1h2', 'f1d3', 'f1e3', 'f1f3', 'f1g3', 'f1h3', 'f1c4',
38
+ 'f1f4', 'f1b5', 'f1f5', 'f1a6', 'f1f6', 'f1f7', 'f1f8',
39
+ 'g1a1', 'g1b1', 'g1c1', 'g1d1', 'g1e1', 'g1f1', 'g1h1',
40
+ 'g1e2', 'g1f2', 'g1g2', 'g1h2', 'g1e3', 'g1f3', 'g1g3',
41
+ 'g1h3', 'g1d4', 'g1g4', 'g1c5', 'g1g5', 'g1b6', 'g1g6',
42
+ 'g1a7', 'g1g7', 'g1g8', 'h1a1', 'h1b1', 'h1c1', 'h1d1',
43
+ 'h1e1', 'h1f1', 'h1g1', 'h1f2', 'h1g2', 'h1h2', 'h1f3',
44
+ 'h1g3', 'h1h3', 'h1e4', 'h1h4', 'h1d5', 'h1h5', 'h1c6',
45
+ 'h1h6', 'h1b7', 'h1h7', 'h1a8', 'h1h8', 'a2a1', 'a2b1',
46
+ 'a2c1', 'a2b2', 'a2c2', 'a2d2', 'a2e2', 'a2f2', 'a2g2',
47
+ 'a2h2', 'a2a3', 'a2b3', 'a2c3', 'a2a4', 'a2b4', 'a2c4',
48
+ 'a2a5', 'a2d5', 'a2a6', 'a2e6', 'a2a7', 'a2f7', 'a2a8',
49
+ 'a2g8', 'b2a1', 'b2b1', 'b2c1', 'b2d1', 'b2a2', 'b2c2',
50
+ 'b2d2', 'b2e2', 'b2f2', 'b2g2', 'b2h2', 'b2a3', 'b2b3',
51
+ 'b2c3', 'b2d3', 'b2a4', 'b2b4', 'b2c4', 'b2d4', 'b2b5',
52
+ 'b2e5', 'b2b6', 'b2f6', 'b2b7', 'b2g7', 'b2b8', 'b2h8',
53
+ 'c2a1', 'c2b1', 'c2c1', 'c2d1', 'c2e1', 'c2a2', 'c2b2',
54
+ 'c2d2', 'c2e2', 'c2f2', 'c2g2', 'c2h2', 'c2a3', 'c2b3',
55
+ 'c2c3', 'c2d3', 'c2e3', 'c2a4', 'c2b4', 'c2c4', 'c2d4',
56
+ 'c2e4', 'c2c5', 'c2f5', 'c2c6', 'c2g6', 'c2c7', 'c2h7',
57
+ 'c2c8', 'd2b1', 'd2c1', 'd2d1', 'd2e1', 'd2f1', 'd2a2',
58
+ 'd2b2', 'd2c2', 'd2e2', 'd2f2', 'd2g2', 'd2h2', 'd2b3',
59
+ 'd2c3', 'd2d3', 'd2e3', 'd2f3', 'd2b4', 'd2c4', 'd2d4',
60
+ 'd2e4', 'd2f4', 'd2a5', 'd2d5', 'd2g5', 'd2d6', 'd2h6',
61
+ 'd2d7', 'd2d8', 'e2c1', 'e2d1', 'e2e1', 'e2f1', 'e2g1',
62
+ 'e2a2', 'e2b2', 'e2c2', 'e2d2', 'e2f2', 'e2g2', 'e2h2',
63
+ 'e2c3', 'e2d3', 'e2e3', 'e2f3', 'e2g3', 'e2c4', 'e2d4',
64
+ 'e2e4', 'e2f4', 'e2g4', 'e2b5', 'e2e5', 'e2h5', 'e2a6',
65
+ 'e2e6', 'e2e7', 'e2e8', 'f2d1', 'f2e1', 'f2f1', 'f2g1',
66
+ 'f2h1', 'f2a2', 'f2b2', 'f2c2', 'f2d2', 'f2e2', 'f2g2',
67
+ 'f2h2', 'f2d3', 'f2e3', 'f2f3', 'f2g3', 'f2h3', 'f2d4',
68
+ 'f2e4', 'f2f4', 'f2g4', 'f2h4', 'f2c5', 'f2f5', 'f2b6',
69
+ 'f2f6', 'f2a7', 'f2f7', 'f2f8', 'g2e1', 'g2f1', 'g2g1',
70
+ 'g2h1', 'g2a2', 'g2b2', 'g2c2', 'g2d2', 'g2e2', 'g2f2',
71
+ 'g2h2', 'g2e3', 'g2f3', 'g2g3', 'g2h3', 'g2e4', 'g2f4',
72
+ 'g2g4', 'g2h4', 'g2d5', 'g2g5', 'g2c6', 'g2g6', 'g2b7',
73
+ 'g2g7', 'g2a8', 'g2g8', 'h2f1', 'h2g1', 'h2h1', 'h2a2',
74
+ 'h2b2', 'h2c2', 'h2d2', 'h2e2', 'h2f2', 'h2g2', 'h2f3',
75
+ 'h2g3', 'h2h3', 'h2f4', 'h2g4', 'h2h4', 'h2e5', 'h2h5',
76
+ 'h2d6', 'h2h6', 'h2c7', 'h2h7', 'h2b8', 'h2h8', 'a3a1',
77
+ 'a3b1', 'a3c1', 'a3a2', 'a3b2', 'a3c2', 'a3b3', 'a3c3',
78
+ 'a3d3', 'a3e3', 'a3f3', 'a3g3', 'a3h3', 'a3a4', 'a3b4',
79
+ 'a3c4', 'a3a5', 'a3b5', 'a3c5', 'a3a6', 'a3d6', 'a3a7',
80
+ 'a3e7', 'a3a8', 'a3f8', 'b3a1', 'b3b1', 'b3c1', 'b3d1',
81
+ 'b3a2', 'b3b2', 'b3c2', 'b3d2', 'b3a3', 'b3c3', 'b3d3',
82
+ 'b3e3', 'b3f3', 'b3g3', 'b3h3', 'b3a4', 'b3b4', 'b3c4',
83
+ 'b3d4', 'b3a5', 'b3b5', 'b3c5', 'b3d5', 'b3b6', 'b3e6',
84
+ 'b3b7', 'b3f7', 'b3b8', 'b3g8', 'c3a1', 'c3b1', 'c3c1',
85
+ 'c3d1', 'c3e1', 'c3a2', 'c3b2', 'c3c2', 'c3d2', 'c3e2',
86
+ 'c3a3', 'c3b3', 'c3d3', 'c3e3', 'c3f3', 'c3g3', 'c3h3',
87
+ 'c3a4', 'c3b4', 'c3c4', 'c3d4', 'c3e4', 'c3a5', 'c3b5',
88
+ 'c3c5', 'c3d5', 'c3e5', 'c3c6', 'c3f6', 'c3c7', 'c3g7',
89
+ 'c3c8', 'c3h8', 'd3b1', 'd3c1', 'd3d1', 'd3e1', 'd3f1',
90
+ 'd3b2', 'd3c2', 'd3d2', 'd3e2', 'd3f2', 'd3a3', 'd3b3',
91
+ 'd3c3', 'd3e3', 'd3f3', 'd3g3', 'd3h3', 'd3b4', 'd3c4',
92
+ 'd3d4', 'd3e4', 'd3f4', 'd3b5', 'd3c5', 'd3d5', 'd3e5',
93
+ 'd3f5', 'd3a6', 'd3d6', 'd3g6', 'd3d7', 'd3h7', 'd3d8',
94
+ 'e3c1', 'e3d1', 'e3e1', 'e3f1', 'e3g1', 'e3c2', 'e3d2',
95
+ 'e3e2', 'e3f2', 'e3g2', 'e3a3', 'e3b3', 'e3c3', 'e3d3',
96
+ 'e3f3', 'e3g3', 'e3h3', 'e3c4', 'e3d4', 'e3e4', 'e3f4',
97
+ 'e3g4', 'e3c5', 'e3d5', 'e3e5', 'e3f5', 'e3g5', 'e3b6',
98
+ 'e3e6', 'e3h6', 'e3a7', 'e3e7', 'e3e8', 'f3d1', 'f3e1',
99
+ 'f3f1', 'f3g1', 'f3h1', 'f3d2', 'f3e2', 'f3f2', 'f3g2',
100
+ 'f3h2', 'f3a3', 'f3b3', 'f3c3', 'f3d3', 'f3e3', 'f3g3',
101
+ 'f3h3', 'f3d4', 'f3e4', 'f3f4', 'f3g4', 'f3h4', 'f3d5',
102
+ 'f3e5', 'f3f5', 'f3g5', 'f3h5', 'f3c6', 'f3f6', 'f3b7',
103
+ 'f3f7', 'f3a8', 'f3f8', 'g3e1', 'g3f1', 'g3g1', 'g3h1',
104
+ 'g3e2', 'g3f2', 'g3g2', 'g3h2', 'g3a3', 'g3b3', 'g3c3',
105
+ 'g3d3', 'g3e3', 'g3f3', 'g3h3', 'g3e4', 'g3f4', 'g3g4',
106
+ 'g3h4', 'g3e5', 'g3f5', 'g3g5', 'g3h5', 'g3d6', 'g3g6',
107
+ 'g3c7', 'g3g7', 'g3b8', 'g3g8', 'h3f1', 'h3g1', 'h3h1',
108
+ 'h3f2', 'h3g2', 'h3h2', 'h3a3', 'h3b3', 'h3c3', 'h3d3',
109
+ 'h3e3', 'h3f3', 'h3g3', 'h3f4', 'h3g4', 'h3h4', 'h3f5',
110
+ 'h3g5', 'h3h5', 'h3e6', 'h3h6', 'h3d7', 'h3h7', 'h3c8',
111
+ 'h3h8', 'a4a1', 'a4d1', 'a4a2', 'a4b2', 'a4c2', 'a4a3',
112
+ 'a4b3', 'a4c3', 'a4b4', 'a4c4', 'a4d4', 'a4e4', 'a4f4',
113
+ 'a4g4', 'a4h4', 'a4a5', 'a4b5', 'a4c5', 'a4a6', 'a4b6',
114
+ 'a4c6', 'a4a7', 'a4d7', 'a4a8', 'a4e8', 'b4b1', 'b4e1',
115
+ 'b4a2', 'b4b2', 'b4c2', 'b4d2', 'b4a3', 'b4b3', 'b4c3',
116
+ 'b4d3', 'b4a4', 'b4c4', 'b4d4', 'b4e4', 'b4f4', 'b4g4',
117
+ 'b4h4', 'b4a5', 'b4b5', 'b4c5', 'b4d5', 'b4a6', 'b4b6',
118
+ 'b4c6', 'b4d6', 'b4b7', 'b4e7', 'b4b8', 'b4f8', 'c4c1',
119
+ 'c4f1', 'c4a2', 'c4b2', 'c4c2', 'c4d2', 'c4e2', 'c4a3',
120
+ 'c4b3', 'c4c3', 'c4d3', 'c4e3', 'c4a4', 'c4b4', 'c4d4',
121
+ 'c4e4', 'c4f4', 'c4g4', 'c4h4', 'c4a5', 'c4b5', 'c4c5',
122
+ 'c4d5', 'c4e5', 'c4a6', 'c4b6', 'c4c6', 'c4d6', 'c4e6',
123
+ 'c4c7', 'c4f7', 'c4c8', 'c4g8', 'd4a1', 'd4d1', 'd4g1',
124
+ 'd4b2', 'd4c2', 'd4d2', 'd4e2', 'd4f2', 'd4b3', 'd4c3',
125
+ 'd4d3', 'd4e3', 'd4f3', 'd4a4', 'd4b4', 'd4c4', 'd4e4',
126
+ 'd4f4', 'd4g4', 'd4h4', 'd4b5', 'd4c5', 'd4d5', 'd4e5',
127
+ 'd4f5', 'd4b6', 'd4c6', 'd4d6', 'd4e6', 'd4f6', 'd4a7',
128
+ 'd4d7', 'd4g7', 'd4d8', 'd4h8', 'e4b1', 'e4e1', 'e4h1',
129
+ 'e4c2', 'e4d2', 'e4e2', 'e4f2', 'e4g2', 'e4c3', 'e4d3',
130
+ 'e4e3', 'e4f3', 'e4g3', 'e4a4', 'e4b4', 'e4c4', 'e4d4',
131
+ 'e4f4', 'e4g4', 'e4h4', 'e4c5', 'e4d5', 'e4e5', 'e4f5',
132
+ 'e4g5', 'e4c6', 'e4d6', 'e4e6', 'e4f6', 'e4g6', 'e4b7',
133
+ 'e4e7', 'e4h7', 'e4a8', 'e4e8', 'f4c1', 'f4f1', 'f4d2',
134
+ 'f4e2', 'f4f2', 'f4g2', 'f4h2', 'f4d3', 'f4e3', 'f4f3',
135
+ 'f4g3', 'f4h3', 'f4a4', 'f4b4', 'f4c4', 'f4d4', 'f4e4',
136
+ 'f4g4', 'f4h4', 'f4d5', 'f4e5', 'f4f5', 'f4g5', 'f4h5',
137
+ 'f4d6', 'f4e6', 'f4f6', 'f4g6', 'f4h6', 'f4c7', 'f4f7',
138
+ 'f4b8', 'f4f8', 'g4d1', 'g4g1', 'g4e2', 'g4f2', 'g4g2',
139
+ 'g4h2', 'g4e3', 'g4f3', 'g4g3', 'g4h3', 'g4a4', 'g4b4',
140
+ 'g4c4', 'g4d4', 'g4e4', 'g4f4', 'g4h4', 'g4e5', 'g4f5',
141
+ 'g4g5', 'g4h5', 'g4e6', 'g4f6', 'g4g6', 'g4h6', 'g4d7',
142
+ 'g4g7', 'g4c8', 'g4g8', 'h4e1', 'h4h1', 'h4f2', 'h4g2',
143
+ 'h4h2', 'h4f3', 'h4g3', 'h4h3', 'h4a4', 'h4b4', 'h4c4',
144
+ 'h4d4', 'h4e4', 'h4f4', 'h4g4', 'h4f5', 'h4g5', 'h4h5',
145
+ 'h4f6', 'h4g6', 'h4h6', 'h4e7', 'h4h7', 'h4d8', 'h4h8',
146
+ 'a5a1', 'a5e1', 'a5a2', 'a5d2', 'a5a3', 'a5b3', 'a5c3',
147
+ 'a5a4', 'a5b4', 'a5c4', 'a5b5', 'a5c5', 'a5d5', 'a5e5',
148
+ 'a5f5', 'a5g5', 'a5h5', 'a5a6', 'a5b6', 'a5c6', 'a5a7',
149
+ 'a5b7', 'a5c7', 'a5a8', 'a5d8', 'b5b1', 'b5f1', 'b5b2',
150
+ 'b5e2', 'b5a3', 'b5b3', 'b5c3', 'b5d3', 'b5a4', 'b5b4',
151
+ 'b5c4', 'b5d4', 'b5a5', 'b5c5', 'b5d5', 'b5e5', 'b5f5',
152
+ 'b5g5', 'b5h5', 'b5a6', 'b5b6', 'b5c6', 'b5d6', 'b5a7',
153
+ 'b5b7', 'b5c7', 'b5d7', 'b5b8', 'b5e8', 'c5c1', 'c5g1',
154
+ 'c5c2', 'c5f2', 'c5a3', 'c5b3', 'c5c3', 'c5d3', 'c5e3',
155
+ 'c5a4', 'c5b4', 'c5c4', 'c5d4', 'c5e4', 'c5a5', 'c5b5',
156
+ 'c5d5', 'c5e5', 'c5f5', 'c5g5', 'c5h5', 'c5a6', 'c5b6',
157
+ 'c5c6', 'c5d6', 'c5e6', 'c5a7', 'c5b7', 'c5c7', 'c5d7',
158
+ 'c5e7', 'c5c8', 'c5f8', 'd5d1', 'd5h1', 'd5a2', 'd5d2',
159
+ 'd5g2', 'd5b3', 'd5c3', 'd5d3', 'd5e3', 'd5f3', 'd5b4',
160
+ 'd5c4', 'd5d4', 'd5e4', 'd5f4', 'd5a5', 'd5b5', 'd5c5',
161
+ 'd5e5', 'd5f5', 'd5g5', 'd5h5', 'd5b6', 'd5c6', 'd5d6',
162
+ 'd5e6', 'd5f6', 'd5b7', 'd5c7', 'd5d7', 'd5e7', 'd5f7',
163
+ 'd5a8', 'd5d8', 'd5g8', 'e5a1', 'e5e1', 'e5b2', 'e5e2',
164
+ 'e5h2', 'e5c3', 'e5d3', 'e5e3', 'e5f3', 'e5g3', 'e5c4',
165
+ 'e5d4', 'e5e4', 'e5f4', 'e5g4', 'e5a5', 'e5b5', 'e5c5',
166
+ 'e5d5', 'e5f5', 'e5g5', 'e5h5', 'e5c6', 'e5d6', 'e5e6',
167
+ 'e5f6', 'e5g6', 'e5c7', 'e5d7', 'e5e7', 'e5f7', 'e5g7',
168
+ 'e5b8', 'e5e8', 'e5h8', 'f5b1', 'f5f1', 'f5c2', 'f5f2',
169
+ 'f5d3', 'f5e3', 'f5f3', 'f5g3', 'f5h3', 'f5d4', 'f5e4',
170
+ 'f5f4', 'f5g4', 'f5h4', 'f5a5', 'f5b5', 'f5c5', 'f5d5',
171
+ 'f5e5', 'f5g5', 'f5h5', 'f5d6', 'f5e6', 'f5f6', 'f5g6',
172
+ 'f5h6', 'f5d7', 'f5e7', 'f5f7', 'f5g7', 'f5h7', 'f5c8',
173
+ 'f5f8', 'g5c1', 'g5g1', 'g5d2', 'g5g2', 'g5e3', 'g5f3',
174
+ 'g5g3', 'g5h3', 'g5e4', 'g5f4', 'g5g4', 'g5h4', 'g5a5',
175
+ 'g5b5', 'g5c5', 'g5d5', 'g5e5', 'g5f5', 'g5h5', 'g5e6',
176
+ 'g5f6', 'g5g6', 'g5h6', 'g5e7', 'g5f7', 'g5g7', 'g5h7',
177
+ 'g5d8', 'g5g8', 'h5d1', 'h5h1', 'h5e2', 'h5h2', 'h5f3',
178
+ 'h5g3', 'h5h3', 'h5f4', 'h5g4', 'h5h4', 'h5a5', 'h5b5',
179
+ 'h5c5', 'h5d5', 'h5e5', 'h5f5', 'h5g5', 'h5f6', 'h5g6',
180
+ 'h5h6', 'h5f7', 'h5g7', 'h5h7', 'h5e8', 'h5h8', 'a6a1',
181
+ 'a6f1', 'a6a2', 'a6e2', 'a6a3', 'a6d3', 'a6a4', 'a6b4',
182
+ 'a6c4', 'a6a5', 'a6b5', 'a6c5', 'a6b6', 'a6c6', 'a6d6',
183
+ 'a6e6', 'a6f6', 'a6g6', 'a6h6', 'a6a7', 'a6b7', 'a6c7',
184
+ 'a6a8', 'a6b8', 'a6c8', 'b6b1', 'b6g1', 'b6b2', 'b6f2',
185
+ 'b6b3', 'b6e3', 'b6a4', 'b6b4', 'b6c4', 'b6d4', 'b6a5',
186
+ 'b6b5', 'b6c5', 'b6d5', 'b6a6', 'b6c6', 'b6d6', 'b6e6',
187
+ 'b6f6', 'b6g6', 'b6h6', 'b6a7', 'b6b7', 'b6c7', 'b6d7',
188
+ 'b6a8', 'b6b8', 'b6c8', 'b6d8', 'c6c1', 'c6h1', 'c6c2',
189
+ 'c6g2', 'c6c3', 'c6f3', 'c6a4', 'c6b4', 'c6c4', 'c6d4',
190
+ 'c6e4', 'c6a5', 'c6b5', 'c6c5', 'c6d5', 'c6e5', 'c6a6',
191
+ 'c6b6', 'c6d6', 'c6e6', 'c6f6', 'c6g6', 'c6h6', 'c6a7',
192
+ 'c6b7', 'c6c7', 'c6d7', 'c6e7', 'c6a8', 'c6b8', 'c6c8',
193
+ 'c6d8', 'c6e8', 'd6d1', 'd6d2', 'd6h2', 'd6a3', 'd6d3',
194
+ 'd6g3', 'd6b4', 'd6c4', 'd6d4', 'd6e4', 'd6f4', 'd6b5',
195
+ 'd6c5', 'd6d5', 'd6e5', 'd6f5', 'd6a6', 'd6b6', 'd6c6',
196
+ 'd6e6', 'd6f6', 'd6g6', 'd6h6', 'd6b7', 'd6c7', 'd6d7',
197
+ 'd6e7', 'd6f7', 'd6b8', 'd6c8', 'd6d8', 'd6e8', 'd6f8',
198
+ 'e6e1', 'e6a2', 'e6e2', 'e6b3', 'e6e3', 'e6h3', 'e6c4',
199
+ 'e6d4', 'e6e4', 'e6f4', 'e6g4', 'e6c5', 'e6d5', 'e6e5',
200
+ 'e6f5', 'e6g5', 'e6a6', 'e6b6', 'e6c6', 'e6d6', 'e6f6',
201
+ 'e6g6', 'e6h6', 'e6c7', 'e6d7', 'e6e7', 'e6f7', 'e6g7',
202
+ 'e6c8', 'e6d8', 'e6e8', 'e6f8', 'e6g8', 'f6a1', 'f6f1',
203
+ 'f6b2', 'f6f2', 'f6c3', 'f6f3', 'f6d4', 'f6e4', 'f6f4',
204
+ 'f6g4', 'f6h4', 'f6d5', 'f6e5', 'f6f5', 'f6g5', 'f6h5',
205
+ 'f6a6', 'f6b6', 'f6c6', 'f6d6', 'f6e6', 'f6g6', 'f6h6',
206
+ 'f6d7', 'f6e7', 'f6f7', 'f6g7', 'f6h7', 'f6d8', 'f6e8',
207
+ 'f6f8', 'f6g8', 'f6h8', 'g6b1', 'g6g1', 'g6c2', 'g6g2',
208
+ 'g6d3', 'g6g3', 'g6e4', 'g6f4', 'g6g4', 'g6h4', 'g6e5',
209
+ 'g6f5', 'g6g5', 'g6h5', 'g6a6', 'g6b6', 'g6c6', 'g6d6',
210
+ 'g6e6', 'g6f6', 'g6h6', 'g6e7', 'g6f7', 'g6g7', 'g6h7',
211
+ 'g6e8', 'g6f8', 'g6g8', 'g6h8', 'h6c1', 'h6h1', 'h6d2',
212
+ 'h6h2', 'h6e3', 'h6h3', 'h6f4', 'h6g4', 'h6h4', 'h6f5',
213
+ 'h6g5', 'h6h5', 'h6a6', 'h6b6', 'h6c6', 'h6d6', 'h6e6',
214
+ 'h6f6', 'h6g6', 'h6f7', 'h6g7', 'h6h7', 'h6f8', 'h6g8',
215
+ 'h6h8', 'a7a1', 'a7g1', 'a7a2', 'a7f2', 'a7a3', 'a7e3',
216
+ 'a7a4', 'a7d4', 'a7a5', 'a7b5', 'a7c5', 'a7a6', 'a7b6',
217
+ 'a7c6', 'a7b7', 'a7c7', 'a7d7', 'a7e7', 'a7f7', 'a7g7',
218
+ 'a7h7', 'a7a8', 'a7b8', 'a7c8', 'b7b1', 'b7h1', 'b7b2',
219
+ 'b7g2', 'b7b3', 'b7f3', 'b7b4', 'b7e4', 'b7a5', 'b7b5',
220
+ 'b7c5', 'b7d5', 'b7a6', 'b7b6', 'b7c6', 'b7d6', 'b7a7',
221
+ 'b7c7', 'b7d7', 'b7e7', 'b7f7', 'b7g7', 'b7h7', 'b7a8',
222
+ 'b7b8', 'b7c8', 'b7d8', 'c7c1', 'c7c2', 'c7h2', 'c7c3',
223
+ 'c7g3', 'c7c4', 'c7f4', 'c7a5', 'c7b5', 'c7c5', 'c7d5',
224
+ 'c7e5', 'c7a6', 'c7b6', 'c7c6', 'c7d6', 'c7e6', 'c7a7',
225
+ 'c7b7', 'c7d7', 'c7e7', 'c7f7', 'c7g7', 'c7h7', 'c7a8',
226
+ 'c7b8', 'c7c8', 'c7d8', 'c7e8', 'd7d1', 'd7d2', 'd7d3',
227
+ 'd7h3', 'd7a4', 'd7d4', 'd7g4', 'd7b5', 'd7c5', 'd7d5',
228
+ 'd7e5', 'd7f5', 'd7b6', 'd7c6', 'd7d6', 'd7e6', 'd7f6',
229
+ 'd7a7', 'd7b7', 'd7c7', 'd7e7', 'd7f7', 'd7g7', 'd7h7',
230
+ 'd7b8', 'd7c8', 'd7d8', 'd7e8', 'd7f8', 'e7e1', 'e7e2',
231
+ 'e7a3', 'e7e3', 'e7b4', 'e7e4', 'e7h4', 'e7c5', 'e7d5',
232
+ 'e7e5', 'e7f5', 'e7g5', 'e7c6', 'e7d6', 'e7e6', 'e7f6',
233
+ 'e7g6', 'e7a7', 'e7b7', 'e7c7', 'e7d7', 'e7f7', 'e7g7',
234
+ 'e7h7', 'e7c8', 'e7d8', 'e7e8', 'e7f8', 'e7g8', 'f7f1',
235
+ 'f7a2', 'f7f2', 'f7b3', 'f7f3', 'f7c4', 'f7f4', 'f7d5',
236
+ 'f7e5', 'f7f5', 'f7g5', 'f7h5', 'f7d6', 'f7e6', 'f7f6',
237
+ 'f7g6', 'f7h6', 'f7a7', 'f7b7', 'f7c7', 'f7d7', 'f7e7',
238
+ 'f7g7', 'f7h7', 'f7d8', 'f7e8', 'f7f8', 'f7g8', 'f7h8',
239
+ 'g7a1', 'g7g1', 'g7b2', 'g7g2', 'g7c3', 'g7g3', 'g7d4',
240
+ 'g7g4', 'g7e5', 'g7f5', 'g7g5', 'g7h5', 'g7e6', 'g7f6',
241
+ 'g7g6', 'g7h6', 'g7a7', 'g7b7', 'g7c7', 'g7d7', 'g7e7',
242
+ 'g7f7', 'g7h7', 'g7e8', 'g7f8', 'g7g8', 'g7h8', 'h7b1',
243
+ 'h7h1', 'h7c2', 'h7h2', 'h7d3', 'h7h3', 'h7e4', 'h7h4',
244
+ 'h7f5', 'h7g5', 'h7h5', 'h7f6', 'h7g6', 'h7h6', 'h7a7',
245
+ 'h7b7', 'h7c7', 'h7d7', 'h7e7', 'h7f7', 'h7g7', 'h7f8',
246
+ 'h7g8', 'h7h8', 'a8a1', 'a8h1', 'a8a2', 'a8g2', 'a8a3',
247
+ 'a8f3', 'a8a4', 'a8e4', 'a8a5', 'a8d5', 'a8a6', 'a8b6',
248
+ 'a8c6', 'a8a7', 'a8b7', 'a8c7', 'a8b8', 'a8c8', 'a8d8',
249
+ 'a8e8', 'a8f8', 'a8g8', 'a8h8', 'b8b1', 'b8b2', 'b8h2',
250
+ 'b8b3', 'b8g3', 'b8b4', 'b8f4', 'b8b5', 'b8e5', 'b8a6',
251
+ 'b8b6', 'b8c6', 'b8d6', 'b8a7', 'b8b7', 'b8c7', 'b8d7',
252
+ 'b8a8', 'b8c8', 'b8d8', 'b8e8', 'b8f8', 'b8g8', 'b8h8',
253
+ 'c8c1', 'c8c2', 'c8c3', 'c8h3', 'c8c4', 'c8g4', 'c8c5',
254
+ 'c8f5', 'c8a6', 'c8b6', 'c8c6', 'c8d6', 'c8e6', 'c8a7',
255
+ 'c8b7', 'c8c7', 'c8d7', 'c8e7', 'c8a8', 'c8b8', 'c8d8',
256
+ 'c8e8', 'c8f8', 'c8g8', 'c8h8', 'd8d1', 'd8d2', 'd8d3',
257
+ 'd8d4', 'd8h4', 'd8a5', 'd8d5', 'd8g5', 'd8b6', 'd8c6',
258
+ 'd8d6', 'd8e6', 'd8f6', 'd8b7', 'd8c7', 'd8d7', 'd8e7',
259
+ 'd8f7', 'd8a8', 'd8b8', 'd8c8', 'd8e8', 'd8f8', 'd8g8',
260
+ 'd8h8', 'e8e1', 'e8e2', 'e8e3', 'e8a4', 'e8e4', 'e8b5',
261
+ 'e8e5', 'e8h5', 'e8c6', 'e8d6', 'e8e6', 'e8f6', 'e8g6',
262
+ 'e8c7', 'e8d7', 'e8e7', 'e8f7', 'e8g7', 'e8a8', 'e8b8',
263
+ 'e8c8', 'e8d8', 'e8f8', 'e8g8', 'e8h8', 'f8f1', 'f8f2',
264
+ 'f8a3', 'f8f3', 'f8b4', 'f8f4', 'f8c5', 'f8f5', 'f8d6',
265
+ 'f8e6', 'f8f6', 'f8g6', 'f8h6', 'f8d7', 'f8e7', 'f8f7',
266
+ 'f8g7', 'f8h7', 'f8a8', 'f8b8', 'f8c8', 'f8d8', 'f8e8',
267
+ 'f8g8', 'f8h8', 'g8g1', 'g8a2', 'g8g2', 'g8b3', 'g8g3',
268
+ 'g8c4', 'g8g4', 'g8d5', 'g8g5', 'g8e6', 'g8f6', 'g8g6',
269
+ 'g8h6', 'g8e7', 'g8f7', 'g8g7', 'g8h7', 'g8a8', 'g8b8',
270
+ 'g8c8', 'g8d8', 'g8e8', 'g8f8', 'g8h8', 'h8a1', 'h8h1',
271
+ 'h8b2', 'h8h2', 'h8c3', 'h8h3', 'h8d4', 'h8h4', 'h8e5',
272
+ 'h8h5', 'h8f6', 'h8g6', 'h8h6', 'h8f7', 'h8g7', 'h8h7',
273
+ 'h8a8', 'h8b8', 'h8c8', 'h8d8', 'h8e8', 'h8f8', 'h8g8',
274
+ 'a7a8q', 'a7a8r', 'a7a8b', 'a7b8q', 'a7b8r', 'a7b8b', 'b7a8q',
275
+ 'b7a8r', 'b7a8b', 'b7b8q', 'b7b8r', 'b7b8b', 'b7c8q', 'b7c8r',
276
+ 'b7c8b', 'c7b8q', 'c7b8r', 'c7b8b', 'c7c8q', 'c7c8r', 'c7c8b',
277
+ 'c7d8q', 'c7d8r', 'c7d8b', 'd7c8q', 'd7c8r', 'd7c8b', 'd7d8q',
278
+ 'd7d8r', 'd7d8b', 'd7e8q', 'd7e8r', 'd7e8b', 'e7d8q', 'e7d8r',
279
+ 'e7d8b', 'e7e8q', 'e7e8r', 'e7e8b', 'e7f8q', 'e7f8r', 'e7f8b',
280
+ 'f7e8q', 'f7e8r', 'f7e8b', 'f7f8q', 'f7f8r', 'f7f8b', 'f7g8q',
281
+ 'f7g8r', 'f7g8b', 'g7f8q', 'g7f8r', 'g7f8b', 'g7g8q', 'g7g8r',
282
+ 'g7g8b', 'g7h8q', 'g7h8r', 'g7h8b', 'h7g8q', 'h7g8r', 'h7g8b',
283
+ 'h7h8q', 'h7h8r', 'h7h8b'
284
+ ]
285
+
286
+ # White, no castling
287
+ _uci_to_idx_wn = dict((uci, idx) for idx, uci in enumerate(_idx_to_move_wn))
288
+
289
+ # White, castling
290
+ _uci_to_idx_wc = dict((uci, idx) for idx, uci in enumerate(_idx_to_move_wn))
291
+ _uci_to_idx_wc['e1g1'], _uci_to_idx_wc['e1h1'] = _uci_to_idx_wc['e1h1'], _uci_to_idx_wc['e1g1']
292
+ _uci_to_idx_wc['e1c1'], _uci_to_idx_wc['e1a1'] = _uci_to_idx_wc['e1a1'], _uci_to_idx_wc['e1c1']
293
+
294
+
295
+ # Black, no castling
296
+ _idx_to_move_bn = []
297
+ for move in _idx_to_move_wn:
298
+ c0,r0,c1,r1,p = move[0],int(move[1]),move[2],int(move[3]),move[4:]
299
+ r0 = 9 - r0
300
+ r1 = 9 - r1
301
+ _idx_to_move_bn.append('{}{}{}{}{}'.format(c0,r0,c1,r1,p))
302
+ _uci_to_idx_bn = dict((uci, idx) for idx, uci in enumerate(_idx_to_move_bn))
303
+
304
+ # Black, castling
305
+ _uci_to_idx_bc = dict((uci, idx) for idx, uci in enumerate(_idx_to_move_bn))
306
+ _uci_to_idx_bc['e8g8'], _uci_to_idx_bc['e8h8'] = _uci_to_idx_bc['e8h8'], _uci_to_idx_bc['e8g8']
307
+ _uci_to_idx_bc['e8c8'], _uci_to_idx_bc['e8a8'] = _uci_to_idx_bc['e8a8'], _uci_to_idx_bc['e8c8']
308
+
309
+ uci_to_idx = [_uci_to_idx_wn, _uci_to_idx_wc, _uci_to_idx_bn, _uci_to_idx_bc]
310
+
311
+
312
+ import collections
313
+ import struct
314
+ import zlib
315
+
316
+ import chess
317
+ import numpy as np
318
+ from chess import Move
319
+
320
+ flat_planes = []
321
+ for i in range(256):
322
+ flat_planes.append(np.ones((8,8), dtype=np.uint8)*i)
323
+
324
+ LeelaBoardData = collections.namedtuple('LeelaBoardData',
325
+ 'plane_bytes repetition '
326
+ 'transposition_key us_ooo us_oo them_ooo them_oo '
327
+ 'side_to_move rule50_count')
328
+
329
+ def pc_board_property(propertyname):
330
+ '''Create a property based on self.pc_board'''
331
+ def prop(self):
332
+ return getattr(self.pc_board, propertyname)
333
+ return property(prop)
334
+
335
+ class LeelaBoard:
336
+ turn = pc_board_property('turn')
337
+ move_stack = pc_board_property('move_stack')
338
+ _plane_bytes_struct = struct.Struct('>Q')
339
+
340
+ def __init__(self, leela_board = None, *args, **kwargs):
341
+ '''If leela_board is passed as an argument, return a copy'''
342
+ self.pc_board = chess.Board(*args, **kwargs)
343
+ self.lcz_stack = []
344
+ self._lcz_transposition_counter = collections.Counter()
345
+ self._lcz_push()
346
+ self.is_game_over = self.pc_method('is_game_over')
347
+ self.can_claim_draw = self.pc_method('can_claim_draw')
348
+ self.generate_legal_moves = self.pc_method('generate_legal_moves')
349
+
350
+ def copy(self, history=7):
351
+ """Note! Currently the copy constructor uses pc_board.copy(stack=False), which makes pops impossible"""
352
+ cls = type(self)
353
+ copied = cls.__new__(cls)
354
+ copied.pc_board = self.pc_board.copy(stack=False)
355
+ copied.pc_board.stack[:] = self.pc_board.stack[-history:]
356
+ copied.pc_board.move_stack[:] = self.pc_board.move_stack[-history:]
357
+ copied.lcz_stack = self.lcz_stack[-history:]
358
+ copied._lcz_transposition_counter = self._lcz_transposition_counter.copy()
359
+ copied.is_game_over = copied.pc_method('is_game_over')
360
+ copied.can_claim_draw = copied.pc_method('can_claim_draw')
361
+ copied.generate_legal_moves = copied.pc_method('generate_legal_moves')
362
+ return copied
363
+
364
+ def pc_method(self, methodname):
365
+ '''Return attribute of self.pc_board, useful for copying method bindings'''
366
+ return getattr(self.pc_board, methodname)
367
+
368
+ def is_threefold(self):
369
+ transposition_key = self.pc_board._transposition_key()
370
+ return self._lcz_transposition_counter[transposition_key] >= 3
371
+
372
+ def is_fifty_moves(self):
373
+ return self.pc_board.halfmove_clock >= 100
374
+
375
+ def is_draw(self):
376
+ return self.is_threefold() or self.is_fifty_moves()
377
+
378
+ def push(self, move):
379
+ self.pc_board.push(move)
380
+ self._lcz_push()
381
+
382
+ def push_uci(self, uci):
383
+ # don't check for legality - it takes much longer to run...
384
+ # self.pc_board.push_uci(uci)
385
+ self.pc_board.push(Move.from_uci(uci))
386
+ self._lcz_push()
387
+
388
+ def push_san(self, san):
389
+ self.pc_board.push_san(san)
390
+ self._lcz_push()
391
+
392
+ def pop(self):
393
+ result = self.pc_board.pop()
394
+ _lcz_data = self.lcz_stack.pop()
395
+ self._lcz_transposition_counter.subtract((_lcz_data.transposition_key,))
396
+ return result
397
+
398
+ def _plane_bytes_iter(self):
399
+ """Get plane bytes... used for _lcz_push"""
400
+ pack = self._plane_bytes_struct.pack
401
+ pieces_mask = self.pc_board.pieces_mask
402
+ for color in (True, False):
403
+ for piece_type in range(1,7):
404
+ byts = pack(pieces_mask(piece_type, color))
405
+ yield byts
406
+
407
+ def _lcz_push(self):
408
+ """Push data onto the lcz data stack after pushing board moves"""
409
+ transposition_key = self.pc_board._transposition_key()
410
+ self._lcz_transposition_counter.update((transposition_key,))
411
+ repetitions = self._lcz_transposition_counter[transposition_key] - 1
412
+ # side_to_move = 0 if we're white, 1 if we're black
413
+ side_to_move = 0 if self.pc_board.turn else 1
414
+ rule50_count = self.pc_board.halfmove_clock
415
+ # Figure out castling rights
416
+ if not side_to_move:
417
+ # we're white
418
+ _c = self.pc_board.castling_rights
419
+ us_ooo, us_oo = (_c>>chess.A1) & 1, (_c>>chess.H1) & 1
420
+ them_ooo, them_oo = (_c>>chess.A8) & 1, (_c>>chess.H8) & 1
421
+ else:
422
+ # We're black
423
+ _c = self.pc_board.castling_rights
424
+ us_ooo, us_oo = (_c>>chess.A8) & 1, (_c>>chess.H8) & 1
425
+ them_ooo, them_oo = (_c>>chess.A1) & 1, (_c>>chess.H1) & 1
426
+ # Create 13 planes... 6 us, 6 them, repetitions>=1
427
+ plane_bytes = b''.join(self._plane_bytes_iter())
428
+ repetition = (repetitions>=1)
429
+ lcz_data = LeelaBoardData(
430
+ plane_bytes, repetition=repetition,
431
+ transposition_key=transposition_key,
432
+ us_ooo=us_ooo, us_oo=us_oo, them_ooo=them_ooo, them_oo=them_oo,
433
+ side_to_move=side_to_move, rule50_count=rule50_count
434
+ )
435
+ self.lcz_stack.append(lcz_data)
436
+
437
+ def serialize_features(self):
438
+ '''Get compacted bytes representation of input planes'''
439
+ planes = []
440
+ curdata = self.lcz_stack[-1]
441
+ bytes_false_true = bytes([False]), bytes([True])
442
+ bytes_per_history = 97
443
+ total_plane_bytes = bytes_per_history * 8
444
+ def bytes_iter():
445
+ plane_bytes_yielded = 0
446
+ for data in self.lcz_stack[-1:-9:-1]:
447
+ yield data.plane_bytes
448
+ yield bytes_false_true[data.repetition]
449
+ plane_bytes_yielded += bytes_per_history
450
+ # 104 total piece planes... fill in missing with 0s
451
+ yield bytes(total_plane_bytes - plane_bytes_yielded)
452
+ # Yield the rest of the constant planes
453
+ yield np.packbits((curdata.us_ooo,
454
+ curdata.us_oo,
455
+ curdata.them_ooo,
456
+ curdata.them_oo,
457
+ curdata.side_to_move)).tobytes()
458
+ yield chr(curdata.rule50_count).encode()
459
+ return b''.join(bytes_iter())
460
+
461
+ @classmethod
462
+ def deserialize_features(cls, serialized):
463
+ planes_stack = []
464
+ rule50_count = serialized[-1] # last byte is rule 50
465
+ board_attrs = np.unpackbits(memoryview(serialized[-2:-1])) # second to last byte
466
+ us_ooo, us_oo, them_ooo, them_oo, side_to_move = board_attrs[:5]
467
+ bytes_per_history = 97
468
+ for history_idx in range(0, bytes_per_history*8, bytes_per_history):
469
+ plane_bytes = serialized[history_idx:history_idx+96]
470
+ repetition = serialized[history_idx+96]
471
+ if not side_to_move:
472
+ # we're white
473
+ planes = (np.unpackbits(memoryview(plane_bytes))[::-1]
474
+ .reshape(12, 8, 8)[::-1])
475
+ else:
476
+ # We're black
477
+ planes = (np.unpackbits(memoryview(plane_bytes))[::-1]
478
+ .reshape(12, 8, 8)[::-1]
479
+ .reshape(2,6,8,8)[::-1,:,::-1]
480
+ .reshape(12, 8,8))
481
+ planes_stack.append(planes)
482
+ planes_stack.append([flat_planes[repetition]])
483
+ planes_stack.append([flat_planes[us_ooo],
484
+ flat_planes[us_oo],
485
+ flat_planes[them_ooo],
486
+ flat_planes[them_oo],
487
+ flat_planes[side_to_move],
488
+ flat_planes[rule50_count],
489
+ flat_planes[0],
490
+ flat_planes[1]])
491
+ planes = np.concatenate(planes_stack)
492
+ return planes
493
+
494
+ def lcz_features(self):
495
+ '''Get neural network input planes as uint8'''
496
+ # print(list(self._planes_iter()))
497
+ planes_stack = []
498
+ curdata = self.lcz_stack[-1]
499
+ planes_yielded = 0
500
+ for data in self.lcz_stack[-1:-9:-1]:
501
+ plane_bytes = data.plane_bytes
502
+ if not curdata.side_to_move:
503
+ # we're white
504
+ planes = (np.unpackbits(memoryview(plane_bytes))[::-1]
505
+ .reshape(12, 8, 8)[::-1])
506
+ else:
507
+ # We're black
508
+ planes = (np.unpackbits(memoryview(plane_bytes))[::-1]
509
+ .reshape(12, 8, 8)[::-1]
510
+ .reshape(2,6,8,8)[::-1,:,::-1]
511
+ .reshape(12, 8,8))
512
+ planes_stack.append(planes)
513
+ planes_stack.append([flat_planes[data.repetition]])
514
+ planes_yielded += 13
515
+ empty_planes = [flat_planes[0] for _ in range(104-planes_yielded)]
516
+ if empty_planes:
517
+ planes_stack.append(empty_planes)
518
+ # Yield the rest of the constant planes
519
+ planes_stack.append([flat_planes[curdata.us_ooo],
520
+ flat_planes[curdata.us_oo],
521
+ flat_planes[curdata.them_ooo],
522
+ flat_planes[curdata.them_oo],
523
+ flat_planes[curdata.side_to_move],
524
+ flat_planes[curdata.rule50_count],
525
+ flat_planes[0],
526
+ flat_planes[1]])
527
+ planes = np.concatenate(planes_stack)
528
+ return planes
529
+
530
+ def lcz_uci_to_idx(self, uci_list):
531
+ # Return list of NN policy output indexes for this board position, given uci_list
532
+
533
+ # TODO: Perhaps it's possible to just add the uci knight promotion move to the index dict
534
+ # currently knight promotions are not in the dict
535
+ uci_list = [uci.rstrip('n') for uci in uci_list]
536
+
537
+ data = self.lcz_stack[-1]
538
+ # uci_to_idx_index =
539
+ # White, no-castling => 0
540
+ # White, castling => 1
541
+ # Black, no-castling => 2
542
+ # Black, castling => 3
543
+ uci_to_idx_index = (data.us_ooo | data.us_oo) + 2*data.side_to_move
544
+ uci_idx_dct = uci_to_idx[uci_to_idx_index]
545
+ return [uci_idx_dct[m] for m in uci_list]
546
+
547
+ @classmethod
548
+ def compress_features(cls, features):
549
+ """Compress a features array as returned from lcz_features method"""
550
+ features_8 = features.astype(np.uint8)
551
+ # Simple compression would do this...
552
+ # return zlib.compress(features_8)
553
+ piece_plane_bytes = np.packbits(features_8[:-8]).tobytes()
554
+ scalar_bytes = features_8[-8:][:,0,0].tobytes()
555
+ compressed = zlib.compress(piece_plane_bytes + scalar_bytes)
556
+ return compressed
557
+
558
+ @classmethod
559
+ def decompress_features(cls, compressed_features):
560
+ """Decompress a compressed features array from compress_features"""
561
+ decompressed = zlib.decompress(compressed_features)
562
+ # Simple decompression would do this
563
+ # return np.frombuffer(decompressed, dtype=np.uint8).astype(np.float32).reshape(-1,8,8)
564
+ piece_plane_bytes = decompressed[:-8]
565
+ scalar_bytes = decompressed[-8:]
566
+ piece_plane_arr = np.unpackbits(memoryview(piece_plane_bytes))
567
+ scalar_arr = np.frombuffer(scalar_bytes, dtype=np.uint8).repeat(64)
568
+ result = np.concatenate((piece_plane_arr, scalar_arr)).astype(np.float32).reshape(-1,8,8)
569
+ return result
570
+
571
+ def unicode(self):
572
+ if self.pc_board.is_game_over() or self.is_draw():
573
+ result = self.pc_board.result(claim_draw=True)
574
+ turnstring = 'Result: {}'.format(result)
575
+ else:
576
+ turnstring = 'Turn: {}'.format('White' if self.pc_board.turn else 'Black')
577
+ boardstr = self.pc_board.unicode() + "\n" + turnstring
578
+ return boardstr
579
+
580
+ def __repr__(self):
581
+ return "LeelaBoard('{}')".format(self.pc_board.fen())
582
+
583
+ def _repr_svg_(self):
584
+ return self.pc_board._repr_svg_()
585
+
586
+ def __str__(self):
587
+ if self.pc_board.is_game_over() or self.is_draw():
588
+ result = self.pc_board.result(claim_draw=True)
589
+ turnstring = 'Result: {}'.format(result)
590
+ else:
591
+ turnstring = 'Turn: {}'.format('White' if self.pc_board.turn else 'Black')
592
+ boardstr = self.pc_board.__str__() + "\n" + turnstring
593
+ return boardstr
594
+
595
+ def __eq__(self, other):
596
+ return self.get_hash_key() == other.get_hash_key()
597
+
598
+ def __hash__(self):
599
+ return hash(self.get_hash_key())
600
+
601
+ def get_hash_key(self):
602
+ transposition_key = self.pc_board._transposition_key()
603
+ return (transposition_key +
604
+ (self._lcz_transposition_counter[transposition_key], self.pc_board.halfmove_clock) +
605
+ tuple(self.pc_board.move_stack[-7:])
606
+ )
607
+
608
+ # lb = LeelaBoard()
609
+ # lb.push_uci('c2c4')
610
+ #lb.push_uci('c7c5')
611
+ #lb.push_uci('d2d3')
612
+ #lb.push_uci('c2c4')
613
+ #lb.push_uci('b8c6')
614
+ # saved_planes = planes
615
+ # planes = lb.features()
616
+ # output = leela_net(torch.from_numpy(planes).unsqueeze(0))
617
+ # output
leela_utils.py ADDED
The diff for this file is too large to render. See raw diff
 
models/.DS_Store ADDED
Binary file (6.15 kB). View file
 
models/leela-large-official.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a439603c5668cd8b05cc9273266a905ac4db689990ca9fe72d0e247102a7c9c3
3
+ size 741143739
python_chess_customized_svg.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file has been copied and slightly modified from python-chess library,
2
+ # Copyright (C) 2016-2020 Niklas Fiekas <niklas.fiekas@backscattering.de>.
3
+
4
+ # This program is free software: you can redistribute it and/or modify
5
+ # it under the terms of the GNU General Public License as published by
6
+ # the Free Software Foundation, either version 3 of the License, or
7
+ # (at your option) any later version.
8
+ #
9
+ # This program is distributed in the hope that it will be useful,
10
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
11
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12
+ # GNU General Public License for more details.
13
+ #
14
+ # You should have received a copy of the GNU General Public License
15
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
16
+
17
+ # Piece vector graphics are copyright (C) Colin M.L. Burnett
18
+ # <https://en.wikipedia.org/wiki/User:Cburnett> and also licensed under the
19
+ # GNU General Public License.
20
+
21
+ import chess
22
+ import math
23
+
24
+ import xml.etree.ElementTree as ET
25
+
26
+ from typing import Iterable, Optional, Tuple, Union
27
+
28
+ SQUARE_SIZE = 45
29
+ MARGIN = 20
30
+
31
+ PIECES = {
32
+ "b": """<g id="black-bishop" class="black bishop" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M9 36c3.39-.97 10.11.43 13.5-2 3.39 2.43 10.11 1.03 13.5 2 0 0 1.65.54 3 2-.68.97-1.65.99-3 .5-3.39-.97-10.11.46-13.5-1-3.39 1.46-10.11.03-13.5 1-1.354.49-2.323.47-3-.5 1.354-1.94 3-2 3-2zm6-4c2.5 2.5 12.5 2.5 15 0 .5-1.5 0-2 0-2 0-2.5-2.5-4-2.5-4 5.5-1.5 6-11.5-5-15.5-11 4-10.5 14-5 15.5 0 0-2.5 1.5-2.5 4 0 0-.5.5 0 2zM25 8a2.5 2.5 0 1 1-5 0 2.5 2.5 0 1 1 5 0z" fill="#000" stroke-linecap="butt"/><path d="M17.5 26h10M15 30h15m-7.5-14.5v5M20 18h5" stroke="#fff" stroke-linejoin="miter"/></g>""",
33
+ # noqa: E501
34
+ "k": """<g id="black-king" class="black king" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M22.5 11.63V6" stroke-linejoin="miter"/><path d="M22.5 25s4.5-7.5 3-10.5c0 0-1-2.5-3-2.5s-3 2.5-3 2.5c-1.5 3 3 10.5 3 10.5" fill="#000" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M11.5 37c5.5 3.5 15.5 3.5 21 0v-7s9-4.5 6-10.5c-4-6.5-13.5-3.5-16 4V27v-3.5c-3.5-7.5-13-10.5-16-4-3 6 5 10 5 10V37z" fill="#000"/><path d="M20 8h5" stroke-linejoin="miter"/><path d="M32 29.5s8.5-4 6.03-9.65C34.15 14 25 18 22.5 24.5l.01 2.1-.01-2.1C20 18 9.906 14 6.997 19.85c-2.497 5.65 4.853 9 4.853 9M11.5 30c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0" stroke="#fff"/></g>""",
35
+ # noqa: E501
36
+ "n": """<g id="black-knight" class="black knight" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M 22,10 C 32.5,11 38.5,18 38,39 L 15,39 C 15,30 25,32.5 23,18" style="fill:#000000; stroke:#000000;"/><path d="M 24,18 C 24.38,20.91 18.45,25.37 16,27 C 13,29 13.18,31.34 11,31 C 9.958,30.06 12.41,27.96 11,28 C 10,28 11.19,29.23 10,30 C 9,30 5.997,31 6,26 C 6,24 12,14 12,14 C 12,14 13.89,12.1 14,10.5 C 13.27,9.506 13.5,8.5 13.5,7.5 C 14.5,6.5 16.5,10 16.5,10 L 18.5,10 C 18.5,10 19.28,8.008 21,7 C 22,7 22,10 22,10" style="fill:#000000; stroke:#000000;"/><path d="M 9.5 25.5 A 0.5 0.5 0 1 1 8.5,25.5 A 0.5 0.5 0 1 1 9.5 25.5 z" style="fill:#ececec; stroke:#ececec;"/><path d="M 15 15.5 A 0.5 1.5 0 1 1 14,15.5 A 0.5 1.5 0 1 1 15 15.5 z" transform="matrix(0.866,0.5,-0.5,0.866,9.693,-5.173)" style="fill:#ececec; stroke:#ececec;"/><path d="M 24.55,10.4 L 24.1,11.85 L 24.6,12 C 27.75,13 30.25,14.49 32.5,18.75 C 34.75,23.01 35.75,29.06 35.25,39 L 35.2,39.5 L 37.45,39.5 L 37.5,39 C 38,28.94 36.62,22.15 34.25,17.66 C 31.88,13.17 28.46,11.02 25.06,10.5 L 24.55,10.4 z " style="fill:#ececec; stroke:none;"/></g>""",
37
+ # noqa: E501
38
+ "p": """<g id="black-pawn" class="black pawn"><path d="M22 9c-2.21 0-4 1.79-4 4 0 .89.29 1.71.78 2.38-1.95 1.12-3.28 3.21-3.28 5.62 0 2.03.94 3.84 2.41 5.03-3 1.06-7.41 5.55-7.41 13.47h23c0-7.92-4.41-12.41-7.41-13.47 1.47-1.19 2.41-3 2.41-5.03 0-2.41-1.33-4.5-3.28-5.62.49-.67.78-1.49.78-2.38 0-2.21-1.79-4-4-4z" stroke="#000" stroke-width="1.5" stroke-linecap="round"/></g>""",
39
+ # noqa: E501
40
+ "q": """<g id="black-queen" class="black queen" fill="#000" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><g fill="#000" stroke="none"><circle cx="6" cy="12" r="2.75"/><circle cx="14" cy="9" r="2.75"/><circle cx="22.5" cy="8" r="2.75"/><circle cx="31" cy="9" r="2.75"/><circle cx="39" cy="12" r="2.75"/></g><path d="M9 26c8.5-1.5 21-1.5 27 0l2.5-12.5L31 25l-.3-14.1-5.2 13.6-3-14.5-3 14.5-5.2-13.6L14 25 6.5 13.5 9 26zM9 26c0 2 1.5 2 2.5 4 1 1.5 1 1 .5 3.5-1.5 1-1.5 2.5-1.5 2.5-1.5 1.5.5 2.5.5 2.5 6.5 1 16.5 1 23 0 0 0 1.5-1 0-2.5 0 0 .5-1.5-1-2.5-.5-2.5-.5-2 .5-3.5 1-2 2.5-2 2.5-4-8.5-1.5-18.5-1.5-27 0z" stroke-linecap="butt"/><path d="M11 38.5a35 35 1 0 0 23 0" fill="none" stroke-linecap="butt"/><path d="M11 29a35 35 1 0 1 23 0M12.5 31.5h20M11.5 34.5a35 35 1 0 0 22 0M10.5 37.5a35 35 1 0 0 24 0" fill="none" stroke="#fff"/></g>""",
41
+ # noqa: E501
42
+ "r": """<g id="black-rook" class="black rook" fill="#000" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M9 39h27v-3H9v3zM12.5 32l1.5-2.5h17l1.5 2.5h-20zM12 36v-4h21v4H12z" stroke-linecap="butt"/><path d="M14 29.5v-13h17v13H14z" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M14 16.5L11 14h23l-3 2.5H14zM11 14V9h4v2h5V9h5v2h5V9h4v5H11z" stroke-linecap="butt"/><path d="M12 35.5h21M13 31.5h19M14 29.5h17M14 16.5h17M11 14h23" fill="none" stroke="#fff" stroke-width="1" stroke-linejoin="miter"/></g>""",
43
+ # noqa: E501
44
+ "B": """<g id="white-bishop" class="white bishop" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><g fill="#fff" stroke-linecap="butt"><path d="M9 36c3.39-.97 10.11.43 13.5-2 3.39 2.43 10.11 1.03 13.5 2 0 0 1.65.54 3 2-.68.97-1.65.99-3 .5-3.39-.97-10.11.46-13.5-1-3.39 1.46-10.11.03-13.5 1-1.354.49-2.323.47-3-.5 1.354-1.94 3-2 3-2zM15 32c2.5 2.5 12.5 2.5 15 0 .5-1.5 0-2 0-2 0-2.5-2.5-4-2.5-4 5.5-1.5 6-11.5-5-15.5-11 4-10.5 14-5 15.5 0 0-2.5 1.5-2.5 4 0 0-.5.5 0 2zM25 8a2.5 2.5 0 1 1-5 0 2.5 2.5 0 1 1 5 0z"/></g><path d="M17.5 26h10M15 30h15m-7.5-14.5v5M20 18h5" stroke-linejoin="miter"/></g>""",
45
+ # noqa: E501
46
+ "K": """<g id="white-king" class="white king" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M22.5 11.63V6M20 8h5" stroke-linejoin="miter"/><path d="M22.5 25s4.5-7.5 3-10.5c0 0-1-2.5-3-2.5s-3 2.5-3 2.5c-1.5 3 3 10.5 3 10.5" fill="#fff" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M11.5 37c5.5 3.5 15.5 3.5 21 0v-7s9-4.5 6-10.5c-4-6.5-13.5-3.5-16 4V27v-3.5c-3.5-7.5-13-10.5-16-4-3 6 5 10 5 10V37z" fill="#fff"/><path d="M11.5 30c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0"/></g>""",
47
+ # noqa: E501
48
+ "N": """<g id="white-knight" class="white knight" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M 22,10 C 32.5,11 38.5,18 38,39 L 15,39 C 15,30 25,32.5 23,18" style="fill:#ffffff; stroke:#000000;"/><path d="M 24,18 C 24.38,20.91 18.45,25.37 16,27 C 13,29 13.18,31.34 11,31 C 9.958,30.06 12.41,27.96 11,28 C 10,28 11.19,29.23 10,30 C 9,30 5.997,31 6,26 C 6,24 12,14 12,14 C 12,14 13.89,12.1 14,10.5 C 13.27,9.506 13.5,8.5 13.5,7.5 C 14.5,6.5 16.5,10 16.5,10 L 18.5,10 C 18.5,10 19.28,8.008 21,7 C 22,7 22,10 22,10" style="fill:#ffffff; stroke:#000000;"/><path d="M 9.5 25.5 A 0.5 0.5 0 1 1 8.5,25.5 A 0.5 0.5 0 1 1 9.5 25.5 z" style="fill:#000000; stroke:#000000;"/><path d="M 15 15.5 A 0.5 1.5 0 1 1 14,15.5 A 0.5 1.5 0 1 1 15 15.5 z" transform="matrix(0.866,0.5,-0.5,0.866,9.693,-5.173)" style="fill:#000000; stroke:#000000;"/></g>""",
49
+ # noqa: E501
50
+ "P": """<g id="white-pawn" class="white pawn"><path d="M22 9c-2.21 0-4 1.79-4 4 0 .89.29 1.71.78 2.38-1.95 1.12-3.28 3.21-3.28 5.62 0 2.03.94 3.84 2.41 5.03-3 1.06-7.41 5.55-7.41 13.47h23c0-7.92-4.41-12.41-7.41-13.47 1.47-1.19 2.41-3 2.41-5.03 0-2.41-1.33-4.5-3.28-5.62.49-.67.78-1.49.78-2.38 0-2.21-1.79-4-4-4z" fill="#fff" stroke="#000" stroke-width="1.5" stroke-linecap="round"/></g>""",
51
+ # noqa: E501
52
+ "Q": """<g id="white-queen" class="white queen" fill="#fff" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M8 12a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM24.5 7.5a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM41 12a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM16 8.5a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM33 9a2 2 0 1 1-4 0 2 2 0 1 1 4 0z"/><path d="M9 26c8.5-1.5 21-1.5 27 0l2-12-7 11V11l-5.5 13.5-3-15-3 15-5.5-14V25L7 14l2 12zM9 26c0 2 1.5 2 2.5 4 1 1.5 1 1 .5 3.5-1.5 1-1.5 2.5-1.5 2.5-1.5 1.5.5 2.5.5 2.5 6.5 1 16.5 1 23 0 0 0 1.5-1 0-2.5 0 0 .5-1.5-1-2.5-.5-2.5-.5-2 .5-3.5 1-2 2.5-2 2.5-4-8.5-1.5-18.5-1.5-27 0z" stroke-linecap="butt"/><path d="M11.5 30c3.5-1 18.5-1 22 0M12 33.5c6-1 15-1 21 0" fill="none"/></g>""",
53
+ # noqa: E501
54
+ "R": """<g id="white-rook" class="white rook" fill="#fff" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M9 39h27v-3H9v3zM12 36v-4h21v4H12zM11 14V9h4v2h5V9h5v2h5V9h4v5" stroke-linecap="butt"/><path d="M34 14l-3 3H14l-3-3"/><path d="M31 17v12.5H14V17" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M31 29.5l1.5 2.5h-20l1.5-2.5"/><path d="M11 14h23" fill="none" stroke-linejoin="miter"/></g>""",
55
+ # noqa: E501
56
+ }
57
+
58
+ PIECES = {
59
+ "b": """<g id="black-bishop" class="black bishop" fill="none" fill-rule="evenodd" stroke="#fff" stroke-width="1.0" stroke-linecap="round" stroke-linejoin="round"><path d="M9 36c3.39-.97 10.11.43 13.5-2 3.39 2.43 10.11 1.03 13.5 2 0 0 1.65.54 3 2-.68.97-1.65.99-3 .5-3.39-.97-10.11.46-13.5-1-3.39 1.46-10.11.03-13.5 1-1.354.49-2.323.47-3-.5 1.354-1.94 3-2 3-2zm6-4c2.5 2.5 12.5 2.5 15 0 .5-1.5 0-2 0-2 0-2.5-2.5-4-2.5-4 5.5-1.5 6-11.5-5-15.5-11 4-10.5 14-5 15.5 0 0-2.5 1.5-2.5 4 0 0-.5.5 0 2zM25 8a2.5 2.5 0 1 1-5 0 2.5 2.5 0 1 1 5 0z" fill="#000" stroke-linecap="butt"/><path d="M17.5 26h10M15 30h15m-7.5-14.5v5M20 18h5" stroke="#fff" stroke-linejoin="miter"/></g>""",
60
+ # noqa: E501
61
+ "k": """<g id="black-king" class="black king" fill="none" fill-rule="evenodd" stroke="#fff" stroke-width="1.0" stroke-linecap="round" stroke-linejoin="round"><path d="M22.5 11.63V6" stroke-linejoin="miter"/><path d="M22.5 25s4.5-7.5 3-10.5c0 0-1-2.5-3-2.5s-3 2.5-3 2.5c-1.5 3 3 10.5 3 10.5" fill="#000" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M11.5 37c5.5 3.5 15.5 3.5 21 0v-7s9-4.5 6-10.5c-4-6.5-13.5-3.5-16 4V27v-3.5c-3.5-7.5-13-10.5-16-4-3 6 5 10 5 10V37z" fill="#000"/><path d="M20 8h5" stroke-linejoin="miter"/><path d="M32 29.5s8.5-4 6.03-9.65C34.15 14 25 18 22.5 24.5l.01 2.1-.01-2.1C20 18 9.906 14 6.997 19.85c-2.497 5.65 4.853 9 4.853 9M11.5 30c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0" stroke="#fff"/></g>""",
62
+ # noqa: E501
63
+ "n": """<g id="black-knight" class="black knight" fill="none" fill-rule="evenodd" stroke="#fff" stroke-width="1.0" stroke-linecap="round" stroke-linejoin="round"><path d="M 22,10 C 32.5,11 38.5,18 38,39 L 15,39 C 15,30 25,32.5 23,18" style="fill:#000000; stroke:#000000;"/><path d="M 24,18 C 24.38,20.91 18.45,25.37 16,27 C 13,29 13.18,31.34 11,31 C 9.958,30.06 12.41,27.96 11,28 C 10,28 11.19,29.23 10,30 C 9,30 5.997,31 6,26 C 6,24 12,14 12,14 C 12,14 13.89,12.1 14,10.5 C 13.27,9.506 13.5,8.5 13.5,7.5 C 14.5,6.5 16.5,10 16.5,10 L 18.5,10 C 18.5,10 19.28,8.008 21,7 C 22,7 22,10 22,10" style="fill:#000000; stroke:#000000;"/><path d="M 9.5 25.5 A 0.5 0.5 0 1 1 8.5,25.5 A 0.5 0.5 0 1 1 9.5 25.5 z" style="fill:#ececec; stroke:#ececec;"/><path d="M 15 15.5 A 0.5 1.5 0 1 1 14,15.5 A 0.5 1.5 0 1 1 15 15.5 z" transform="matrix(0.866,0.5,-0.5,0.866,9.693,-5.173)" style="fill:#ececec; stroke:#ececec;"/><path d="M 24.55,10.4 L 24.1,11.85 L 24.6,12 C 27.75,13 30.25,14.49 32.5,18.75 C 34.75,23.01 35.75,29.06 35.25,39 L 35.2,39.5 L 37.45,39.5 L 37.5,39 C 38,28.94 36.62,22.15 34.25,17.66 C 31.88,13.17 28.46,11.02 25.06,10.5 L 24.55,10.4 z " style="fill:#ececec; stroke:none;"/></g>""",
64
+ # noqa: E501
65
+ "p": """<g id="black-pawn" class="black pawn"><path d="M22 9c-2.21 0-4 1.79-4 4 0 .89.29 1.71.78 2.38-1.95 1.12-3.28 3.21-3.28 5.62 0 2.03.94 3.84 2.41 5.03-3 1.06-7.41 5.55-7.41 13.47h23c0-7.92-4.41-12.41-7.41-13.47 1.47-1.19 2.41-3 2.41-5.03 0-2.41-1.33-4.5-3.28-5.62.49-.67.78-1.49.78-2.38 0-2.21-1.79-4-4-4z" stroke="#fff" stroke-width="1.0" stroke-linecap="round"/></g>""",
66
+ # noqa: E501
67
+ "q": """<g id="black-queen" class="black queen" fill="#000" fill-rule="evenodd" stroke="#000" stroke-width="1.0" stroke-linecap="round" stroke-linejoin="round"><g fill="#000" stroke="none"><circle cx="6" cy="12" r="2.75"/><circle cx="14" cy="9" r="2.75"/><circle cx="22.5" cy="8" r="2.75"/><circle cx="31" cy="9" r="2.75"/><circle cx="39" cy="12" r="2.75"/></g><path d="M9 26c8.5-1.5 21-1.5 27 0l2.5-12.5L31 25l-.3-14.1-5.2 13.6-3-14.5-3 14.5-5.2-13.6L14 25 6.5 13.5 9 26zM9 26c0 2 1.5 2 2.5 4 1 1.5 1 1 .5 3.5-1.5 1-1.5 2.5-1.5 2.5-1.5 1.5.5 2.5.5 2.5 6.5 1 16.5 1 23 0 0 0 1.5-1 0-2.5 0 0 .5-1.5-1-2.5-.5-2.5-.5-2 .5-3.5 1-2 2.5-2 2.5-4-8.5-1.5-18.5-1.5-27 0z" stroke-linecap="butt"/><path d="M11 38.5a35 35 1 0 0 23 0" fill="none" stroke-linecap="butt"/><path d="M11 29a35 35 1 0 1 23 0M12.5 31.5h20M11.5 34.5a35 35 1 0 0 22 0M10.5 37.5a35 35 1 0 0 24 0" fill="none" stroke="#fff"/></g>""",
68
+ # noqa: E501
69
+ "r": """<g id="black-rook" class="black rook" fill="#000" fill-rule="evenodd" stroke="#fff" stroke-width="0.5" stroke-linecap="round" stroke-linejoin="round"><path d="M9 39h27v-3H9v3zM12.5 32l1.5-2.5h17l1.5 2.5h-20zM12 36v-4h21v4H12z" stroke-linecap="butt"/><path d="M14 29.5v-13h17v13H14z" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M14 16.5L11 14h23l-3 2.5H14zM11 14V9h4v2h5V9h5v2h5V9h4v5H11z" stroke-linecap="butt"/><path d="M12 35.5h21M13 31.5h19M14 29.5h17M14 16.5h17M11 14h23" fill="none" stroke="#fff" stroke-width="1" stroke-linejoin="miter"/></g>""",
70
+ # noqa: E501
71
+ "B": """<g id="white-bishop" class="white bishop" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><g fill="#fff" stroke-linecap="butt"><path d="M9 36c3.39-.97 10.11.43 13.5-2 3.39 2.43 10.11 1.03 13.5 2 0 0 1.65.54 3 2-.68.97-1.65.99-3 .5-3.39-.97-10.11.46-13.5-1-3.39 1.46-10.11.03-13.5 1-1.354.49-2.323.47-3-.5 1.354-1.94 3-2 3-2zM15 32c2.5 2.5 12.5 2.5 15 0 .5-1.5 0-2 0-2 0-2.5-2.5-4-2.5-4 5.5-1.5 6-11.5-5-15.5-11 4-10.5 14-5 15.5 0 0-2.5 1.5-2.5 4 0 0-.5.5 0 2zM25 8a2.5 2.5 0 1 1-5 0 2.5 2.5 0 1 1 5 0z"/></g><path d="M17.5 26h10M15 30h15m-7.5-14.5v5M20 18h5" stroke-linejoin="miter"/></g>""",
72
+ # noqa: E501
73
+ "K": """<g id="white-king" class="white king" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M22.5 11.63V6M20 8h5" stroke-linejoin="miter"/><path d="M22.5 25s4.5-7.5 3-10.5c0 0-1-2.5-3-2.5s-3 2.5-3 2.5c-1.5 3 3 10.5 3 10.5" fill="#fff" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M11.5 37c5.5 3.5 15.5 3.5 21 0v-7s9-4.5 6-10.5c-4-6.5-13.5-3.5-16 4V27v-3.5c-3.5-7.5-13-10.5-16-4-3 6 5 10 5 10V37z" fill="#fff"/><path d="M11.5 30c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0"/></g>""",
74
+ # noqa: E501
75
+ "N": """<g id="white-knight" class="white knight" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M 22,10 C 32.5,11 38.5,18 38,39 L 15,39 C 15,30 25,32.5 23,18" style="fill:#ffffff; stroke:#000000;"/><path d="M 24,18 C 24.38,20.91 18.45,25.37 16,27 C 13,29 13.18,31.34 11,31 C 9.958,30.06 12.41,27.96 11,28 C 10,28 11.19,29.23 10,30 C 9,30 5.997,31 6,26 C 6,24 12,14 12,14 C 12,14 13.89,12.1 14,10.5 C 13.27,9.506 13.5,8.5 13.5,7.5 C 14.5,6.5 16.5,10 16.5,10 L 18.5,10 C 18.5,10 19.28,8.008 21,7 C 22,7 22,10 22,10" style="fill:#ffffff; stroke:#000000;"/><path d="M 9.5 25.5 A 0.5 0.5 0 1 1 8.5,25.5 A 0.5 0.5 0 1 1 9.5 25.5 z" style="fill:#000000; stroke:#000000;"/><path d="M 15 15.5 A 0.5 1.5 0 1 1 14,15.5 A 0.5 1.5 0 1 1 15 15.5 z" transform="matrix(0.866,0.5,-0.5,0.866,9.693,-5.173)" style="fill:#000000; stroke:#000000;"/></g>""",
76
+ # noqa: E501
77
+ "P": """<g id="white-pawn" class="white pawn"><path d="M22 9c-2.21 0-4 1.79-4 4 0 .89.29 1.71.78 2.38-1.95 1.12-3.28 3.21-3.28 5.62 0 2.03.94 3.84 2.41 5.03-3 1.06-7.41 5.55-7.41 13.47h23c0-7.92-4.41-12.41-7.41-13.47 1.47-1.19 2.41-3 2.41-5.03 0-2.41-1.33-4.5-3.28-5.62.49-.67.78-1.49.78-2.38 0-2.21-1.79-4-4-4z" fill="#fff" stroke="#000" stroke-width="1.5" stroke-linecap="round"/></g>""",
78
+ # noqa: E501
79
+ "Q": """<g id="white-queen" class="white queen" fill="#fff" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M8 12a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM24.5 7.5a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM41 12a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM16 8.5a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM33 9a2 2 0 1 1-4 0 2 2 0 1 1 4 0z"/><path d="M9 26c8.5-1.5 21-1.5 27 0l2-12-7 11V11l-5.5 13.5-3-15-3 15-5.5-14V25L7 14l2 12zM9 26c0 2 1.5 2 2.5 4 1 1.5 1 1 .5 3.5-1.5 1-1.5 2.5-1.5 2.5-1.5 1.5.5 2.5.5 2.5 6.5 1 16.5 1 23 0 0 0 1.5-1 0-2.5 0 0 .5-1.5-1-2.5-.5-2.5-.5-2 .5-3.5 1-2 2.5-2 2.5-4-8.5-1.5-18.5-1.5-27 0z" stroke-linecap="butt"/><path d="M11.5 30c3.5-1 18.5-1 22 0M12 33.5c6-1 15-1 21 0" fill="none"/></g>""",
80
+ # noqa: E501
81
+ "R": """<g id="white-rook" class="white rook" fill="#fff" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M9 39h27v-3H9v3zM12 36v-4h21v4H12zM11 14V9h4v2h5V9h5v2h5V9h4v5" stroke-linecap="butt"/><path d="M34 14l-3 3H14l-3-3"/><path d="M31 17v12.5H14V17" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M31 29.5l1.5 2.5h-20l1.5-2.5"/><path d="M11 14h23" fill="none" stroke-linejoin="miter"/></g>""",
82
+ # noqa: E501
83
+ }
84
+
85
+ XX = """<g id="xx"><path d="M35.865 9.135a1.89 1.89 0 0 1 0 2.673L25.173 22.5l10.692 10.692a1.89 1.89 0 0 1 0 2.673 1.89 1.89 0 0 1-2.673 0L22.5 25.173 11.808 35.865a1.89 1.89 0 0 1-2.673 0 1.89 1.89 0 0 1 0-2.673L19.827 22.5 9.135 11.808a1.89 1.89 0 0 1 0-2.673 1.89 1.89 0 0 1 2.673 0L22.5 19.827 33.192 9.135a1.89 1.89 0 0 1 2.673 0z" fill="#000" stroke="#fff" stroke-width="1.688"/></g>""" # noqa: E501
86
+
87
+ CHECK_GRADIENT = """<radialGradient id="check_gradient"><stop offset="0%" stop-color="#ff0000" stop-opacity="1.0" /><stop offset="50%" stop-color="#e70000" stop-opacity="1.0" /><stop offset="100%" stop-color="#9e0000" stop-opacity="0.0" /></radialGradient>""" # noqa: E501
88
+
89
+ DEFAULT_COLORS = {
90
+ "square light": "#ffce9e",
91
+ "square dark": "#d18b47",
92
+ "square dark lastmove": "#aaa23b",
93
+ "square light lastmove": "#cdd16a",
94
+ }
95
+
96
+ class Arrow:
97
+ """Details of an arrow to be drawn."""
98
+
99
+ def __init__(self, tail: chess.Square, head: chess.Square, *, color: str = "#888", annotation: str = '') -> None:
100
+ self.tail = tail
101
+ self.head = head
102
+ self.color = color
103
+ self.annotation = annotation
104
+
105
+
106
+ class SvgWrapper(str):
107
+ def _repr_svg_(self) -> "SvgWrapper":
108
+ return self
109
+
110
+
111
+ def _svg(viewbox: int, size: Optional[int]) -> ET.Element:
112
+ svg = ET.Element("svg", {
113
+ "xmlns": "http://www.w3.org/2000/svg",
114
+ "version": "1.1",
115
+ "xmlns:xlink": "http://www.w3.org/1999/xlink",
116
+ "viewBox": f"0 0 {viewbox:d} {viewbox:d}",
117
+ })
118
+
119
+ if size is not None:
120
+ svg.set("width", str(size))
121
+ svg.set("height", str(size))
122
+
123
+ return svg
124
+
125
+
126
+ def _text(content: str, x: int, y: int, width: int, height: int) -> ET.Element:
127
+ t = ET.Element("text", {
128
+ "x": str(x + width // 2),
129
+ "y": str(y + height // 2),
130
+ "font-size": str(max(1, int(min(width, height) * 0.7))),
131
+ "text-anchor": "middle",
132
+ "alignment-baseline": "middle",
133
+ })
134
+ t.text = content
135
+ return t
136
+
137
+
138
+ def piece(piece: chess.Piece, size: Optional[int] = None) -> str:
139
+ """
140
+ Renders the given :class:`chess.Piece` as an SVG image.
141
+ >>> import chess
142
+ >>> import chess.svg
143
+ >>>
144
+ >>> chess.svg.piece(chess.Piece.from_symbol("R")) # doctest: +SKIP
145
+ .. image:: ../docs/wR.svg
146
+ """
147
+ svg = _svg(SQUARE_SIZE, size)
148
+ svg.append(ET.fromstring(PIECES[piece.symbol()]))
149
+ return SvgWrapper(ET.tostring(svg).decode("utf-8"))
150
+
151
+
152
+ def board(board: Optional[chess.BaseBoard] = None, *,
153
+ squares: Optional[chess.IntoSquareSet] = None,
154
+ flipped: bool = False,
155
+ coordinates: bool = True,
156
+ lastmove: Optional[chess.Move] = None,
157
+ check: Optional[chess.Square] = None,
158
+ arrows: Iterable[Union[Arrow, Tuple[chess.Square, chess.Square]]] = (),
159
+ size: Optional[int] = None,
160
+ style: Optional[str] = None,
161
+ square_colors: Iterable[str] = (), #TODO: remove as it is not needed anymore
162
+ only_pieces: bool = False) -> str:
163
+ """
164
+ Renders a board with pieces and/or selected squares as an SVG image.
165
+ :param board: A :class:`chess.BaseBoard` for a chessboard with pieces or
166
+ ``None`` (the default) for a chessboard without pieces.
167
+ :param squares: A :class:`chess.SquareSet` with selected squares.
168
+ :param flipped: Pass ``True`` to flip the board.
169
+ :param coordinates: Pass ``False`` to disable coordinates in the margin.
170
+ :param lastmove: A :class:`chess.Move` to be highlighted.
171
+ :param check: A square to be marked as check.
172
+ :param arrows: A list of :class:`~chess.svg.Arrow` objects like
173
+ ``[chess.svg.Arrow(chess.E2, chess.E4)]`` or a list of tuples like
174
+ ``[(chess.E2, chess.E4)]``. An arrow from a square pointing to the same
175
+ square is drawn as a circle, like ``[(chess.E2, chess.E2)]``.
176
+ :param size: The size of the image in pixels (e.g., ``400`` for a 400 by
177
+ 400 board) or ``None`` (the default) for no size limit.
178
+ :param style: A CSS stylesheet to include in the SVG image.
179
+ >>> import chess
180
+ >>> import chess.svg
181
+ >>>
182
+ >>> board = chess.Board("8/8/8/8/4N3/8/8/8 w - - 0 1")
183
+ >>> squares = board.attacks(chess.E4)
184
+ >>> chess.svg.board(board=board, squares=squares) # doctest: +SKIP
185
+ .. image:: ../docs/Ne4.svg
186
+ """
187
+ margin = MARGIN if coordinates else 0
188
+ svg = _svg(8 * SQUARE_SIZE + 2 * margin, size)
189
+
190
+ if style:
191
+ ET.SubElement(svg, "style").text = style
192
+
193
+ defs = ET.SubElement(svg, "defs")
194
+ if board:
195
+ for piece_color in chess.COLORS:
196
+ for piece_type in chess.PIECE_TYPES:
197
+ if board.pieces_mask(piece_type, piece_color):
198
+ defs.append(ET.fromstring(PIECES[chess.Piece(piece_type, piece_color).symbol()]))
199
+
200
+ squares = chess.SquareSet(squares) if squares else chess.SquareSet()
201
+ if squares:
202
+ defs.append(ET.fromstring(XX))
203
+
204
+ if check is not None and not only_pieces:
205
+ defs.append(ET.fromstring(CHECK_GRADIENT))
206
+
207
+ for square, bb in enumerate(chess.BB_SQUARES):
208
+ file_index = chess.square_file(square)
209
+ rank_index = chess.square_rank(square)
210
+
211
+ x = (file_index if not flipped else 7 - file_index) * SQUARE_SIZE + margin
212
+ y = (7 - rank_index if not flipped else rank_index) * SQUARE_SIZE + margin
213
+
214
+ cls = ["square", "light" if chess.BB_LIGHT_SQUARES & bb else "dark"]
215
+ if lastmove and square in [lastmove.from_square, lastmove.to_square]:
216
+ cls.append("lastmove")
217
+ if square_colors == ():
218
+ fill_color = DEFAULT_COLORS[" ".join(cls)]
219
+ else:
220
+ fill_color = square_colors[square]
221
+
222
+ cls.append(chess.SQUARE_NAMES[square])
223
+ if not only_pieces:
224
+ ET.SubElement(svg, "rect", {
225
+ "x": str(x),
226
+ "y": str(y),
227
+ "width": str(SQUARE_SIZE),
228
+ "height": str(SQUARE_SIZE),
229
+ "class": " ".join(cls),
230
+ "stroke": "none",
231
+ "fill": fill_color,
232
+ })
233
+
234
+ if square == check:
235
+ ET.SubElement(svg, "rect", {
236
+ "x": str(x),
237
+ "y": str(y),
238
+ "width": str(SQUARE_SIZE),
239
+ "height": str(SQUARE_SIZE),
240
+ "class": "check",
241
+ "fill": "url(#check_gradient)",
242
+ })
243
+
244
+ # Render pieces.
245
+ if board is not None:
246
+ piece = board.piece_at(square)
247
+ if piece:
248
+ ET.SubElement(svg, "use", {
249
+ "xlink:href": f"#{chess.COLOR_NAMES[piece.color]}-{chess.PIECE_NAMES[piece.piece_type]}",
250
+ "transform": f"translate({x:d}, {y:d})",
251
+ })
252
+
253
+ # Render selected squares.
254
+ if squares is not None and square in squares:
255
+ #ET.SubElement(svg, "use", {
256
+ # "xlink:href": "#xx",
257
+ # "x": str(x),
258
+ # "y": str(y),
259
+ #})
260
+ ET.SubElement(svg, "rect", {
261
+ "x": str(x),
262
+ "y": str(y),
263
+ "width": str(SQUARE_SIZE),
264
+ "height": str(SQUARE_SIZE),
265
+ "class": "check",
266
+ "fill": "none",
267
+ "stroke": "#FF0000",
268
+ "stroke-width": "5.0",
269
+ "rx": "2.5",
270
+ "opacity": "0.60"
271
+ })
272
+
273
+ if coordinates:
274
+ for file_index, file_name in enumerate(chess.FILE_NAMES):
275
+ x = (file_index if not flipped else 7 - file_index) * SQUARE_SIZE + margin
276
+ svg.append(_text(file_name, x, 0, SQUARE_SIZE, margin))
277
+ svg.append(_text(file_name, x, margin + 8 * SQUARE_SIZE, SQUARE_SIZE, margin))
278
+ for rank_index, rank_name in enumerate(chess.RANK_NAMES):
279
+ y = (7 - rank_index if not flipped else rank_index) * SQUARE_SIZE + margin
280
+ svg.append(_text(rank_name, 0, y, margin, SQUARE_SIZE))
281
+ svg.append(_text(rank_name, margin + 8 * SQUARE_SIZE, y, margin, SQUARE_SIZE))
282
+
283
+ for arrow in arrows:
284
+ try:
285
+ tail, head, color, annotation = arrow.tail, arrow.head, arrow.color, arrow.annotation # type: ignore
286
+ except AttributeError:
287
+ tail, head = arrow # type: ignore
288
+ color = "#888"
289
+ annotation = ''
290
+
291
+ tail_file = chess.square_file(tail)
292
+ tail_rank = chess.square_rank(tail)
293
+ head_file = chess.square_file(head)
294
+ head_rank = chess.square_rank(head)
295
+
296
+ xtail = margin + (tail_file + 0.5 if not flipped else 7.5 - tail_file) * SQUARE_SIZE
297
+ ytail = margin + (7.5 - tail_rank if not flipped else tail_rank + 0.5) * SQUARE_SIZE
298
+ xhead = margin + (head_file + 0.5 if not flipped else 7.5 - head_file) * SQUARE_SIZE
299
+ yhead = margin + (7.5 - head_rank if not flipped else head_rank + 0.5) * SQUARE_SIZE
300
+
301
+ if (head_file, head_rank) == (tail_file, tail_rank):
302
+ ET.SubElement(svg, "circle", {
303
+ "cx": str(xhead),
304
+ "cy": str(yhead),
305
+ "r": str(SQUARE_SIZE * 0.9 / 2),
306
+ "stroke-width": str(SQUARE_SIZE * 0.1),
307
+ "stroke": color,
308
+ "fill": "none",
309
+ "opacity": "0.5",
310
+ "class": "circle",
311
+ })
312
+ else:
313
+ # marker_size = 0.75 * SQUARE_SIZE
314
+ # marker_margin = 0.1 * SQUARE_SIZE
315
+ marker_size = 0.5 * SQUARE_SIZE
316
+ marker_margin = 0.05 * SQUARE_SIZE
317
+
318
+ dx, dy = xhead - xtail, yhead - ytail
319
+ hypot = math.hypot(dx, dy)
320
+
321
+ shaft_x = xhead - dx * (marker_size + marker_margin) / hypot
322
+ shaft_y = yhead - dy * (marker_size + marker_margin) / hypot
323
+
324
+ xtip = xhead - dx * marker_margin / hypot
325
+ ytip = yhead - dy * marker_margin / hypot
326
+
327
+ x_annot = xtail + (shaft_x - xtail) / 2
328
+ y_annot = ytail + (shaft_y - ytail) / 2
329
+
330
+ x_annot = xhead - dx * 0.74 * SQUARE_SIZE / hypot # - (xtip - xtail)*(SQUARE_SIZE/2)
331
+ y_annot = yhead - dy * 0.74 * SQUARE_SIZE / hypot # - (ytip - ytail)*(SQUARE_SIZE/2)
332
+
333
+ ET.SubElement(svg, "line", {
334
+ "x1": str(xtail),
335
+ "y1": str(ytail),
336
+ "x2": str(shaft_x),
337
+ "y2": str(shaft_y),
338
+ "stroke": color,
339
+ "stroke-width": str(SQUARE_SIZE * 0.15),
340
+ "opacity": "0.5",
341
+ "stroke-linecap": "butt",
342
+ "class": "arrow",
343
+ })
344
+
345
+ marker = [(xtip, ytip),
346
+ (shaft_x + dy * 0.5 * marker_size / hypot,
347
+ shaft_y - dx * 0.5 * marker_size / hypot),
348
+ (shaft_x - dy * 0.5 * marker_size / hypot,
349
+ shaft_y + dx * 0.5 * marker_size / hypot)]
350
+
351
+ ET.SubElement(svg, "polygon", {
352
+ "points": " ".join(str(x) + "," + str(y) for x, y in marker),
353
+ "fill": color,
354
+ "opacity": "0.5",
355
+ "class": "arrow",
356
+ })
357
+
358
+ for arrow in arrows:
359
+ try:
360
+ tail, head, color, annotation = arrow.tail, arrow.head, arrow.color, arrow.annotation # type: ignore
361
+ except AttributeError:
362
+ tail, head = arrow # type: ignore
363
+ color = "#888"
364
+ annotation = ''
365
+
366
+ tail_file = chess.square_file(tail)
367
+ tail_rank = chess.square_rank(tail)
368
+ head_file = chess.square_file(head)
369
+ head_rank = chess.square_rank(head)
370
+
371
+ xtail = margin + (tail_file + 0.5 if not flipped else 7.5 - tail_file) * SQUARE_SIZE
372
+ ytail = margin + (7.5 - tail_rank if not flipped else tail_rank + 0.5) * SQUARE_SIZE
373
+ xhead = margin + (head_file + 0.5 if not flipped else 7.5 - head_file) * SQUARE_SIZE
374
+ yhead = margin + (7.5 - head_rank if not flipped else head_rank + 0.5) * SQUARE_SIZE
375
+
376
+ marker_size = 0.5 * SQUARE_SIZE
377
+ marker_margin = 0.05 * SQUARE_SIZE
378
+
379
+ dx, dy = xhead - xtail, yhead - ytail
380
+ hypot = math.hypot(dx, dy)
381
+
382
+ shaft_x = xhead - dx * (marker_size + marker_margin) / hypot
383
+ shaft_y = yhead - dy * (marker_size + marker_margin) / hypot
384
+
385
+ xtip = xhead - dx * marker_margin / hypot
386
+ ytip = yhead - dy * marker_margin / hypot
387
+
388
+ x_annot = xhead - dx * 0.74 * SQUARE_SIZE / hypot
389
+ y_annot = yhead - dy * 0.74 * SQUARE_SIZE / hypot
390
+
391
+ if annotation != '':
392
+ ET.SubElement(svg, "circle", {
393
+ "cx": str(x_annot),
394
+ "cy": str(y_annot),
395
+ "r": str(SQUARE_SIZE * 0.175),
396
+ #"r": str(SQUARE_SIZE * 0.2),
397
+ "stroke-width": str(SQUARE_SIZE * 0.01),
398
+ "stroke": '#000000',
399
+ "fill": color,
400
+ "opacity": "1.0",
401
+ "class": "circle",
402
+ })
403
+ #style = get_style("'BundledDejavuSans'", str(SQUARE_SIZE * 0.1))
404
+ annot = ET.SubElement(svg, "text", {
405
+ "x": str(x_annot),
406
+ "y": str(y_annot),
407
+ "font-size": str(SQUARE_SIZE * 0.2), # max(1, int(min(SQUARE_SIZE, SQUARE_SIZE) * 0.3))),
408
+ "text-anchor": "middle",
409
+ "dominant-baseline": "middle"
410
+ #"alignment-baseline": "middle"
411
+ })
412
+ annot.text = annotation
413
+
414
+ return SvgWrapper(ET.tostring(svg).decode("utf-8"))
svg_pieces.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from python_chess_customized_svg import piece
2
+ import python_chess_customized_svg as svg
3
+ import chess
4
+ import base64
5
+
6
+ #def get_svg_piece(symbol):
7
+ # img = piece(chess.Piece.from_symbol(symbol))
8
+ # svg_str = str(img)
9
+ # svg_byte = svg_str.encode()
10
+ # encoded = base64.b64encode(svg_byte)
11
+ # svg_piece = 'data:image/svg+xml;base64,{}'.format(encoded.decode())
12
+ # return svg_piece
13
+
14
+ def get_svg_board(board, focused_square_ind, only_pieces):
15
+ if focused_square_ind is not None:
16
+ squares = [focused_square_ind]
17
+ else:
18
+ squares = []
19
+ if board.move_stack:
20
+ print('board stack YES')
21
+ lastmove = board.peek()
22
+ else:
23
+ print('board stack NO')
24
+ lastmove = None
25
+ svg_str = str(svg.board(board, squares=squares, arrows=[], lastmove=lastmove, coordinates=False, only_pieces=only_pieces))
26
+ svg_byte = svg_str.encode()
27
+ encoded = base64.b64encode(svg_byte)
28
+ svg_board = 'data:image/svg+xml;base64,{}'.format(encoded.decode())
29
+ return svg_board
30
+
31
+ #SVG_PIECES = {piece: get_svg_piece(piece) for piece in ('b', 'k', 'n', 'p', 'q', 'r', 'B', 'K', 'N', 'P', 'Q', 'R')}
utils.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from leela_board import LeelaBoard
2
+ import chess
3
+ import torch
4
+
5
+
6
+ def flip_move(move):
7
+ from_square = chess.square_mirror(chess.parse_square(move[:2]))
8
+ to_square = chess.square_mirror(chess.parse_square(move[2:4]))
9
+ promotion = move[4:] if len(move) > 4 else ""
10
+ return chess.square_name(from_square) + chess.square_name(to_square) + promotion
11
+
12
+
13
+ def flip_board(fen, moves):
14
+ temp_board = chess.Board(fen=fen)
15
+ return temp_board.mirror().fen(), [flip_move(move) for move in moves]
16
+
17
+
18
+ # Helper functions
19
+ class ChessBoard:
20
+ def __init__(self, fen): # Create new board from fen
21
+ self.board = LeelaBoard(fen=fen)
22
+ self.t = self.__t()
23
+
24
+ def move(self, move): # Move piece on board ("e2e3")
25
+ self.board.push_uci(move)
26
+ self.t = self.__t()
27
+
28
+ def __t(self): # Set board tensor (private method)
29
+ return torch.from_numpy(self.board.lcz_features()).float()
30
+
31
+ def __str__(self): # Prints board state
32
+ return str(self.board)
visualization_demo.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import marimo
2
+
3
+ __generated_with = "0.8.22"
4
+ app = marimo.App(width="medium")
5
+
6
+
7
+ @app.cell
8
+ def __():
9
+ import marimo as mo
10
+ return (mo,)
11
+
12
+
13
+ @app.cell
14
+ def __():
15
+ import pandas as pd
16
+ df = pd.read_csv("our_visualization/datasets/test_set.csv")
17
+ df.head()
18
+ return df, pd
19
+
20
+
21
+ @app.cell
22
+ def __():
23
+ import pickle
24
+ from utils import ChessBoard
25
+ import onnxruntime as ort
26
+ from leela_board import _idx_to_move_bn, _idx_to_move_wn
27
+ import numpy as np
28
+ from onnx2torch import convert
29
+ import onnx
30
+ import torch
31
+ import os
32
+
33
+ def get_models(root="/Users/sereda/Documents/chessXAI/our_visualization/models"):
34
+ paths = os.listdir(root)
35
+ model_paths = []
36
+ for path in paths:
37
+ if ".onnx" in path: model_paths.append(os.path.join(root, path))
38
+ return model_paths
39
+
40
+ def get_activations_from_model(model_path, pattern, fen):
41
+ # Write hooks for selected model path
42
+ def register_hooks_for_capture(model, pattern):
43
+ activations = {}
44
+ def get_activation(name):
45
+ def hook(module, input, output):
46
+ activations[name] = output.detach().numpy()
47
+ return hook
48
+
49
+ handles = []
50
+ for n, m in model.named_modules():
51
+ if pattern in n:
52
+ handle = m.register_forward_hook(get_activation(n))
53
+ handles.append(handle)
54
+ return activations, handles
55
+
56
+ # Load model and register hooks for it
57
+ model = convert(onnx.load(model_path))
58
+ act, handles = register_hooks_for_capture(model, pattern)
59
+
60
+ # Get fen and pass it through model to generate activations
61
+ board = ChessBoard(fen)
62
+ inputs = board.t
63
+ _, _, _ = model(inputs.unsqueeze(dim=0))
64
+
65
+ # Remove handles
66
+ [h.remove() for h in handles]
67
+ return act
68
+ return (
69
+ ChessBoard,
70
+ convert,
71
+ get_activations_from_model,
72
+ get_models,
73
+ np,
74
+ onnx,
75
+ ort,
76
+ os,
77
+ pickle,
78
+ torch,
79
+ )
80
+
81
+
82
+ @app.cell
83
+ def __(df, mo):
84
+ min_elo, max_elo = df["Rating"].min() // 100 * 100, df["Rating"].max() // 100 * 100
85
+ elo_list = [f"{elo}" for elo in range(min_elo, max_elo + 100, 100)]
86
+ dropdown_elo = mo.ui.dropdown(value = "1000", options=elo_list, label=f"Select rating in range of {min_elo} - {max_elo}")
87
+ dropdown_elo
88
+ return dropdown_elo, elo_list, max_elo, min_elo
89
+
90
+
91
+ @app.cell
92
+ def __(df, dropdown_elo, mo):
93
+ unique_themes = set()
94
+ df_rated = df[(df["Rating"] >= int(dropdown_elo.value)) & (df["Rating"] <= int(dropdown_elo.value) + 100)]
95
+ for i in range(len(df_rated)):
96
+ themes = df_rated.iloc[i]["Themes"].split(" ")
97
+ for theme in themes: unique_themes.add(theme)
98
+ unique_themes_list = list(unique_themes)
99
+ unique_themes_list.sort()
100
+
101
+ dropdown_themes = mo.ui.dropdown(value=unique_themes_list[0], options=unique_themes_list, label=f"Select puzzle theme")
102
+ dropdown_themes
103
+ return (
104
+ df_rated,
105
+ dropdown_themes,
106
+ i,
107
+ theme,
108
+ themes,
109
+ unique_themes,
110
+ unique_themes_list,
111
+ )
112
+
113
+
114
+ @app.cell
115
+ def __(df_rated, dropdown_themes):
116
+ themes_mask = []
117
+ def _(themes_mask):
118
+ for i in range(len(df_rated)):
119
+ themes_new = df_rated.iloc[i]["Themes"].split(" ")
120
+ if dropdown_themes.value in themes_new: themes_mask.append(i)
121
+ _(themes_mask)
122
+ fens = list(df_rated.iloc[themes_mask]["FEN"])
123
+ df_rated.iloc[themes_mask][["FEN", "Moves", "Themes", "Rating"]]
124
+ return fens, themes_mask
125
+
126
+
127
+ @app.cell
128
+ def __(fens, mo):
129
+ dropdown_fen = mo.ui.dropdown(value = fens[0], options=fens, label="Select FEN")
130
+ dropdown_fen
131
+ return (dropdown_fen,)
132
+
133
+
134
+ @app.cell
135
+ def __(df_rated, dropdown_fen, mo):
136
+ moves = df_rated[df_rated["FEN"] == dropdown_fen.value]["Moves"].iloc[0].split(" ")
137
+ player_moves = moves[1::2]
138
+ board_moves = []
139
+ def _(board_moves):
140
+ for i in range(len(player_moves)):
141
+ board_moves.append(moves[:2 * i + 1])
142
+ _(board_moves)
143
+ moves_dict = {pm: om for pm, om in zip(player_moves, board_moves)}
144
+ dropdown_moves = mo.ui.dropdown(options=moves_dict, value=player_moves[0], label="Select which player move to look at")
145
+ # print(moves)
146
+ dropdown_moves
147
+ return board_moves, dropdown_moves, moves, moves_dict, player_moves
148
+
149
+
150
+ @app.cell
151
+ def __(dropdown_moves, mo):
152
+ dropdown_layer = mo.ui.dropdown(value="0", options=[f"{i}" for i in range(15)], label="Select layer (smaller - closer to input)")
153
+ focus_square = mo.ui.text_area(value=dropdown_moves.selected_key[:2], placeholder="Input square to look at (e.g. a1, b8, ...")
154
+ mo.vstack([dropdown_layer, focus_square])
155
+ return dropdown_layer, focus_square
156
+
157
+
158
+ @app.cell
159
+ def __(ChessBoard, dropdown_fen, dropdown_moves):
160
+ def _():
161
+ board = ChessBoard(dropdown_fen.value)
162
+ for move in dropdown_moves.value:
163
+ print(move)
164
+ # board.move(move)
165
+ return board.board.pc_board.fen()
166
+ FEN = _()
167
+ return (FEN,)
168
+
169
+
170
+ @app.cell
171
+ def __(focus_square):
172
+ import chess
173
+ from global_data import global_data
174
+
175
+ focus_square_ind = 8 * (int(focus_square.value[1]) - 1) + ord(focus_square.value[0]) - ord("a")
176
+
177
+ def set_plotting_parameters(act, layer_number, fen):
178
+ layer_key = [k for k in act.keys() if "0" in k][0].replace("0", f"{layer_number}")
179
+ print(act.keys())
180
+ global_data.model = 'test'
181
+ global_data.activations = act[layer_key][0, :, ::-1 , :]
182
+ print(global_data.activations.shape)
183
+ global_data.subplot_rows = 8
184
+ global_data.subplot_cols = 4
185
+ global_data.board = chess.Board(fen)
186
+ global_data.show_all_heads = True
187
+ # global_data.selected_head = 1
188
+ global_data.visualization_mode = 'ROW'
189
+ global_data.focused_square_ind = focus_square_ind
190
+ # global_data.heatmap_horizontal_gap = 0.001
191
+
192
+ global_data.visualization_mode_is_64x64 = False
193
+ global_data.colorscale_mode = "mode1"
194
+ global_data.show_colorscale = False
195
+ return chess, focus_square_ind, global_data, set_plotting_parameters
196
+
197
+
198
+ @app.cell
199
+ def __(
200
+ FEN,
201
+ dropdown_layer,
202
+ get_activations_from_model,
203
+ get_models,
204
+ set_plotting_parameters,
205
+ ):
206
+ # FEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
207
+ # board = ChessBoard("r1b2rk1/pp2pp1p/6p1/3Qb2q/1P4n1/2P1BN2/P2N1PPP/R4RK1 w - - 0 14")
208
+ # board.move("f3e5")
209
+ # FEN = board.board.pc_board.fen()
210
+ PATTERN = "mha/QK/softmax"
211
+ # PATTERN = "smolgen_weights"
212
+ MODEL = get_models()[-1]
213
+ ACTIVATIONS = get_activations_from_model(MODEL, PATTERN, FEN)
214
+ set_plotting_parameters(ACTIVATIONS, int(dropdown_layer.value), FEN)
215
+ from activation_heatmap import heatmap_figure
216
+ fig = heatmap_figure()
217
+ fig.update_layout(height=1500, width=1200)
218
+ fig
219
+ return ACTIVATIONS, MODEL, PATTERN, fig, heatmap_figure
220
+
221
+
222
+ @app.cell
223
+ def __():
224
+ # Add fens after opponents moves
225
+ # Default squares of interest
226
+ return
227
+
228
+
229
+ if __name__ == "__main__":
230
+ app.run()