Xmaster6y committed on
Commit
3b6ef01
1 Parent(s): 340463d

better demo

Browse files
app.py CHANGED
@@ -4,15 +4,19 @@ Main Gradio module.
4
 
5
  import gradio as gr
6
 
7
- from src.interfaces import feature_interface
8
 
9
 
10
  demo = gr.TabbedInterface(
11
  [
12
- feature_interface,
 
 
13
  ],
14
  [
15
- "Feature Activation",
 
 
16
  ],
17
  title="Lczero Planning Demo",
18
  analytics_enabled=False,
 
4
 
5
  import gradio as gr
6
 
7
+ from src.interfaces import fen_feature_interface, game_feature_interface, act_max_interface
8
 
9
 
10
  demo = gr.TabbedInterface(
11
  [
12
+ fen_feature_interface.interface,
13
+ game_feature_interface.interface,
14
+ act_max_interface.interface,
15
  ],
16
  [
17
+ "Feature Activation (FEN)",
18
+ "Feature Activation (Game)",
19
+ "Feature Activation Maximisation",
20
  ],
21
  title="Lczero Planning Demo",
22
  analytics_enabled=False,
src/constants.py CHANGED
@@ -18,4 +18,5 @@ LAYER = 9
18
  ACTIVATION_DIM = 256
19
  DICTIONARY_SIZE = 7680
20
  PRE_BIAS = False
21
- INIT_NORMALISE_DICT = None
 
 
18
  ACTIVATION_DIM = 256
19
  DICTIONARY_SIZE = 7680
20
  PRE_BIAS = False
21
+ INIT_NORMALISE_DICT = None
22
+ FEATURE_DATASET = "Xmaster6y/lczero-planning-features"
src/global_variables.py CHANGED
@@ -6,6 +6,7 @@ from huggingface_hub import HfApi
6
  import gradio as gr
7
  from lczerolens import ModelWrapper
8
  import torch
 
9
 
10
  from src import constants
11
  from src.helpers import SparseAutoEncoder, OutputGenerator
@@ -14,6 +15,7 @@ hf_api: HfApi
14
  wrapper: ModelWrapper
15
  sae: SparseAutoEncoder
16
  generator: OutputGenerator
 
17
 
18
 
19
  def setup():
@@ -21,6 +23,7 @@ def setup():
21
  global wrapper
22
  global sae
23
  global generator
 
24
 
25
  hf_api = HfApi(token=constants.HF_TOKEN)
26
  hf_api.snapshot_download(
@@ -53,6 +56,11 @@ def setup():
53
  wrapper=wrapper,
54
  module_exp=rf".*block{constants.LAYER}/conv2/relu"
55
  )
 
 
 
 
 
56
 
57
  if gr.NO_RELOAD:
58
  setup()
 
6
  import gradio as gr
7
  from lczerolens import ModelWrapper
8
  import torch
9
+ from datasets import load_dataset, Dataset
10
 
11
  from src import constants
12
  from src.helpers import SparseAutoEncoder, OutputGenerator
 
15
  wrapper: ModelWrapper
16
  sae: SparseAutoEncoder
17
  generator: OutputGenerator
18
+ f_ds: Dataset
19
 
20
 
21
  def setup():
 
23
  global wrapper
24
  global sae
25
  global generator
26
+ global f_ds
27
 
28
  hf_api = HfApi(token=constants.HF_TOKEN)
29
  hf_api.snapshot_download(
 
56
  wrapper=wrapper,
57
  module_exp=rf".*block{constants.LAYER}/conv2/relu"
58
  )
59
+ f_ds = load_dataset(
60
+ constants.FEATURE_DATASET,
61
+ constants.SAE_CONFIG,
62
+ split="test"
63
+ ).with_format("torch")
64
 
65
  if gr.NO_RELOAD:
66
  setup()
src/interfaces/__init__.py CHANGED
@@ -1,2 +0,0 @@
1
-
2
- from .feature_interface import interface as feature_interface
 
 
 
src/interfaces/act_max_interface.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for browsing the dataset positions that most strongly activate a feature.
3
+ """
4
+
5
+ import chess
6
+ import gradio as gr
7
+ import uuid
8
+ import torch
9
+
10
+ from lczerolens.encodings import encode_move
11
+
12
+ from src import constants, global_variables, visualisation
13
+
14
+
15
def render_feature_index(
    file_id,
    feature_index
):
    """Render the 16 dataset rows on which a feature is most active.

    For every selected row, the 64 consecutive dataset rows starting at
    ``row - pixel_index`` are gathered and drawn as a board heatmap
    (presumably one row per square of a single position -- TODO confirm
    against the dataset layout).

    Args:
        file_id: Prefix used for the SVG files written to the figures
            folder; a fresh UUID is generated when it is ``None``.
        feature_index: Column of ``opt_features`` to visualise.

    Returns:
        A flat tuple ``(file_id, *svg_paths, *colorbar_figures)`` with one
        path and one figure per rendered board.
    """
    if file_id is None:
        file_id = str(uuid.uuid4())
    # Slider values may arrive as float; tensor indexing needs an int.
    feature_index = int(feature_index)

    # Column access on the dataset materialises the whole column, so read
    # each column once instead of once per rendered board.
    opt_features = global_variables.f_ds["opt_features"]
    pixel_indices = global_variables.f_ds["pixel_index"]

    top_rows = opt_features[:, feature_index].topk(16).indices
    square_offsets = torch.arange(64, device=top_rows.device)

    board_images = []
    colorbars = []
    for topi, idx in enumerate(top_rows):
        sample = global_variables.f_ds[idx.item()]
        # Gather the 64 rows of the window in one indexing operation.
        first_row = idx - pixel_indices[idx]
        features = opt_features[square_offsets + first_row, feature_index]

        board = chess.Board(sample["opt_fen"])
        # NOTE(review): the +6 offset into `moves_opt` presumably skips a
        # fixed prefix of moves -- confirm against the dataset builder.
        uci_move = sample["moves_opt"][sample["current_depth"] + 6]
        move = chess.Move.from_uci(uci_move)

        if board.turn:
            heatmap = features.view(64)
        else:
            # Mirror the ranks when it is black to move.
            heatmap = features.view(8, 8).flip(0).view(64)

        svg_board, fig = visualisation.render_heatmap(
            board,
            heatmap,
            arrows=[(move.from_square, move.to_square)],
        )
        svg_path = f"{constants.FIGURES_FOLER}/{file_id}_{topi}.svg"
        with open(svg_path, "w") as f:
            f.write(svg_board)
        board_images.append(svg_path)
        colorbars.append(fig)
    return file_id, *board_images, *colorbars
54
+
55
+
56
+
57
# Feature-maximisation tab: one slider on top and a 4x4 grid of boards below.
# NOTE(review): the nesting below was reconstructed from a flattened diff;
# confirm the layout against the original file.
with gr.Blocks() as interface:
    with gr.Row():
        feature_index = gr.Slider(
            label="Feature index",
            minimum=0,
            maximum=constants.DICTIONARY_SIZE-1,
            step=1,
            value=0,
        )

    # Output components are collected in row-major order so that they line
    # up with the flat tuple returned by `render_feature_index`.
    board_images = []
    colorbars = []
    for i in range(4):
        with gr.Row():
            for j in range(4):
                idx = 4*i + j
                with gr.Column(), gr.Group():
                    with gr.Row():
                        board_image = gr.Image(label=f"Board {idx}")
                    with gr.Row():
                        colorbar = gr.Plot(label=f"Colorbar {idx}")
                board_images.append(board_image)
                colorbars.append(colorbar)

    # Per-session prefix of the SVG files written by the callback.
    file_id = gr.State(None)
    feature_index.change(
        render_feature_index,
        inputs=[file_id, feature_index],
        outputs=[file_id] + board_images + colorbars,
    )
+ )
src/interfaces/{feature_interface.py → fen_feature_interface.py} RENAMED
File without changes
src/interfaces/game_feature_interface.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for inspecting feature activations between a root and a trajectory board of a game.
3
+ """
4
+
5
+ import chess
6
+ import gradio as gr
7
+ import uuid
8
+ import torch
9
+
10
+ from lczerolens.encodings import encode_move
11
+
12
+ from src import constants, global_variables, visualisation
13
+
14
+
15
def compute_features_fn(
    features,
    model_output,
    file_id,
    root_idx,
    traj_idx,
    start_fen,
    move_seq,
    feature_index
):
    """Replay the game, run the generator on the selected boards and render.

    The move sequence is replayed from ``start_fen``; the positions reached
    after ``root_idx`` and ``traj_idx`` half-moves (0 meaning the starting
    position) are handed to ``global_variables.generator.generate``.

    Args:
        features, model_output, file_id: Previous session state, returned
            unchanged when the inputs are invalid.
        root_idx: Half-move index of the root board.
        traj_idx: Half-move index of the trajectory board.
        start_fen: FEN of the starting position.
        move_seq: Whitespace separated moves. Sequences starting with
            ``"1."`` are parsed as SAN (tokens ending in ``"."`` are
            skipped), anything else as UCI.
        feature_index: Feature column to render.

    Returns:
        A 10-element sequence: the five state values, the board image path,
        the colorbar figure, the root FEN, the trajectory FEN and an info
        string. On invalid input the state is returned unchanged and the
        five outputs are ``None``.
    """
    error_return = [features, model_output, file_id, root_idx, traj_idx] + [None] * 5
    try:
        board = chess.Board(start_fen)
    except ValueError:
        # NOTE(review): the message announces a fallback to the starting
        # position, but the function has always bailed out here; the
        # behaviour is kept as is -- confirm which one is intended.
        gr.Warning("Invalid FEN, using starting position.")
        return error_return

    root_board = board.copy() if root_idx == 0 else None
    traj_board = board.copy() if traj_idx == 0 else None

    if move_seq:
        is_san = move_seq.startswith("1.")
        push = board.push_san if is_san else board.push_uci
        ply = 0
        try:
            for move in move_seq.split():
                if root_board is not None and traj_board is not None:
                    break
                if is_san and move.endswith("."):
                    # Move-number token such as "1." or "12.".
                    continue
                push(move)
                ply += 1
                if ply == root_idx:
                    root_board = board.copy()
                if ply == traj_idx:
                    traj_board = board.copy()
        except ValueError:
            gr.Warning(f"Invalid move {move}.")
            return error_return
    if root_board is None or traj_board is None:
        gr.Warning("Invalid move sequence.")
        return error_return

    model_output, pixel_acts, sae_output = global_variables.generator.generate(
        root_board=root_board,
        traj_board=traj_board
    )
    current_root_fen = root_board.fen()
    current_traj_fen = traj_board.fen()
    features = sae_output["features"]
    x_hat = sae_output["x_hat"]
    first_output = render_feature_index(
        features,
        model_output,
        file_id,
        root_idx,
        traj_idx,
        current_traj_fen,
        feature_index
    )

    def _row_mse(pred, target):
        # Squared error summed over channels, averaged over rows.
        return torch.nn.functional.mse_loss(pred, target, reduction='none').sum(dim=1).mean()

    half_a_dim = constants.ACTIVATION_DIM // 2
    half_f_dim = constants.DICTIONARY_SIZE // 2
    active = features > 0
    pixel_f_avg = features.mean(dim=0)
    pixel_f_active = active.float().mean(dim=0)
    pixel_p_avg = features.mean(dim=1)
    pixel_p_active = active.float().mean(dim=1)

    # Orient the per-square statistics like the rendered heatmap, which
    # follows the side to move on the trajectory board.
    if not traj_board.turn:
        pixel_p_avg = pixel_p_avg.view(8, 8).flip(0).view(64)
        pixel_p_active = pixel_p_active.view(8, 8).flip(0).view(64)
    most_avg_pixels = pixel_p_avg.topk(5).indices.tolist()
    most_active_pixels = pixel_p_active.topk(5).indices.tolist()

    info = "\n".join([
        f"Root WDL: {model_output['wdl'][0]}",
        f"Traj WDL: {model_output['wdl'][1]}",
        f"MSE loss: {_row_mse(x_hat, pixel_acts)}",
        f"MSE loss (root): {_row_mse(x_hat[:,:half_a_dim], pixel_acts[:,:half_a_dim])}",
        f"MSE loss (traj): {_row_mse(x_hat[:,half_a_dim:], pixel_acts[:,half_a_dim:])}",
        f"L0 loss: {active.sum(dim=1).float().mean()}",
        f"L0 loss (c): {active[:,:half_f_dim].sum(dim=1).float().mean()}",
        f"L0 loss (d): {active[:,half_f_dim:].sum(dim=1).float().mean()}",
        f"Most active features (avg): {pixel_f_avg.topk(5).indices.tolist()}",
        f"Most active features (active): {pixel_f_active.topk(5).indices.tolist()}",
        f"Most active pixels (avg): {[chess.SQUARE_NAMES[p] for p in most_avg_pixels]}",
        f"Most active pixels (active): {[chess.SQUARE_NAMES[p] for p in most_active_pixels]}",
    ])

    return *first_output, current_root_fen, current_traj_fen, info
116
+
117
+
118
def render_feature_index(
    features,
    model_output,
    file_id,
    root_idx,
    traj_idx,
    traj_fen,
    feature_index,
):
    """Draw one feature as a heatmap on the trajectory board.

    The legal move with the highest policy logit (row 1 of
    ``model_output["policy"]``) is drawn as an arrow.

    Args:
        features: Feature activations; column ``feature_index`` is viewed
            as 64 squares. ``None`` before anything has been computed.
        model_output: Mapping holding at least the ``"policy"`` logits.
        file_id: Name of the SVG file in the figures folder; a fresh UUID
            is generated when it is ``None``.
        root_idx, traj_idx: Passed through unchanged.
        traj_fen: FEN of the board to draw.
        feature_index: Feature column to render.

    Returns:
        ``(features, model_output, file_id, root_idx, traj_idx, svg_path,
        colorbar_figure)``; the last two are ``None`` when there is nothing
        to render yet.
    """
    if features is None or model_output is None:
        # The slider can be moved before any board has been evaluated.
        gr.Warning("No features computed yet, select a board first.")
        return features, model_output, file_id, root_idx, traj_idx, None, None

    if file_id is None:
        file_id = str(uuid.uuid4())
    board = chess.Board(traj_fen)
    # Slider values may arrive as float; tensor indexing needs an int.
    pixel_features = features[:, int(feature_index)]
    if board.turn:
        heatmap = pixel_features.view(64)
    else:
        # Mirror the ranks when it is black to move.
        heatmap = pixel_features.view(8, 8).flip(0).view(64)

    policy = model_output["policy"]
    us_them = (board.turn, not board.turn)

    def _logit(move):
        return policy[1, encode_move(move, us_them)].item()

    # Actually compare the logits: the previous loop kept the first logit
    # and overwrote the move with every later legal move.
    best_legal_move = max(board.legal_moves, key=_logit, default=None)
    if best_legal_move is None:
        # No legal move (mate or stalemate): draw the board without arrow.
        # NOTE(review): assumes `render_heatmap` accepts an empty list.
        arrows = []
    else:
        arrows = [(best_legal_move.from_square, best_legal_move.to_square)]

    svg_board, fig = visualisation.render_heatmap(
        board,
        heatmap,
        arrows=arrows,
    )
    svg_path = f"{constants.FIGURES_FOLER}/{file_id}.svg"
    with open(svg_path, "w") as f:
        f.write(svg_board)
    return (
        features,
        model_output,
        file_id,
        root_idx,
        traj_idx,
        svg_path,
        fig
    )
162
+
163
def make_features_fn(var, direction):
    """Build a click handler that moves one board index and recomputes.

    Args:
        var: ``"root"`` or ``"traj"``, the index to move; any other value
            leaves both indices untouched.
        direction: Step added to the selected index (``-1`` or ``1`` in
            this module).

    Returns:
        A callback with the same signature and return value as
        ``compute_features_fn``.
    """
    def _make_features_fn(
        features,
        model_output,
        file_id,
        root_idx,
        traj_idx,
        start_fen,
        move_seq,
        feature_index
    ):
        tokens = move_seq.split() if move_seq else []
        move_count = sum(1 for mv in tokens if not mv.endswith("."))
        # Never clamp below zero: with an empty move sequence
        # `move_count - 1` would be -1, an index no board can match.
        # NOTE(review): `compute_features_fn` numbers boards 0..move_count,
        # so this bound keeps the position after the final move
        # unreachable -- confirm whether that is intended.
        last_idx = max(move_count - 1, 0)

        if var == "root":
            root_idx += direction
            if root_idx < 0:
                gr.Warning("Already at first board.")
                root_idx = 0
            elif root_idx >= move_count:
                gr.Warning("Already at last board.")
                root_idx = last_idx
            elif root_idx > traj_idx:
                gr.Warning("Root should be before traj.")
                root_idx = traj_idx
        elif var == "traj":
            traj_idx += direction
            if traj_idx < 0:
                gr.Warning("Already at first board.")
                traj_idx = 0
            elif traj_idx >= move_count:
                gr.Warning("Already at last board.")
                traj_idx = last_idx
            elif traj_idx < root_idx:
                gr.Warning("Traj should be after root.")
                traj_idx = root_idx
        return compute_features_fn(
            features,
            model_output,
            file_id,
            root_idx,
            traj_idx,
            start_fen,
            move_seq,
            feature_index
        )
    return _make_features_fn
208
+
209
# Game tab: inputs and navigation on the left, rendered board on the right.
# NOTE(review): the nesting below was reconstructed from a flattened diff;
# confirm the layout against the original file.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            start_fen = gr.Textbox(
                label="Starting FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            move_seq = gr.Textbox(
                label="Move sequence",
                lines=1,
                max_lines=1,
                value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
            )

            # Navigation over the root and trajectory boards.
            with gr.Group():
                with gr.Row():
                    previous_root_button = gr.Button("Previous root")
                    next_root_button = gr.Button("Next root")

                with gr.Row():
                    previous_traj_button = gr.Button("Previous traj")
                    next_traj_button = gr.Button("Next traj")

            # Read-only view of the selected boards plus the feature picker.
            with gr.Group():
                with gr.Row():
                    current_root_fen = gr.Textbox(
                        label="Root FEN",
                        lines=1,
                        max_lines=1,
                        interactive=False
                    )
                with gr.Row():
                    current_traj_fen = gr.Textbox(
                        label="Traj FEN",
                        lines=1,
                        max_lines=1,
                        interactive=False
                    )
                with gr.Row():
                    feature_index = gr.Slider(
                        label="Feature index",
                        minimum=0,
                        maximum=constants.DICTIONARY_SIZE-1,
                        step=1,
                        value=0,
                    )

            with gr.Group():
                with gr.Row():
                    info = gr.Textbox(label="Info", lines=1, max_lines=20, value="")
                with gr.Row():
                    colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            board_image = gr.Image(label="Board")

    # Session state threaded through every callback, in the order the
    # callbacks take and return it.
    features = gr.State(None)
    model_output = gr.State(None)
    file_id = gr.State(None)
    root_idx = gr.State(0)
    traj_idx = gr.State(0)
    state = [features, model_output, file_id, root_idx, traj_idx]

    base_inputs = [start_fen, move_seq, feature_index]
    base_outputs = [board_image, colorbar, current_root_fen, current_traj_fen, info]

    # The four navigation buttons differ only in which index they move and
    # in which direction.
    for button, var, direction in (
        (previous_root_button, "root", -1),
        (next_root_button, "root", 1),
        (previous_traj_button, "traj", -1),
        (next_traj_button, "traj", 1),
    ):
        button.click(
            make_features_fn(var=var, direction=direction),
            inputs=state + base_inputs,
            outputs=state + base_outputs,
        )

    # Changing the feature only re-renders; the cached activations are reused.
    feature_index.change(
        render_feature_index,
        inputs=state + [current_traj_fen, feature_index],
        outputs=state + [board_image, colorbar],
    )
src/interfaces/stats_interface.py DELETED
File without changes