Xmaster6y commited on
Commit
340463d
1 Parent(s): bd23dfe

working interface

Browse files
src/helpers/generator.py CHANGED
@@ -20,6 +20,7 @@ class OutputGenerator:
20
  self.wrapper = wrapper
21
  self.lens = ActivationLens(module_exp=module_exp)
22
 
 
23
  def generate(
24
  self,
25
  root_fen: Optional[str] = None,
@@ -35,17 +36,16 @@ class OutputGenerator:
35
  input_encoding = InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED
36
  else:
37
  raise ValueError
38
- iter_boards = iter([[root_board, traj_board]])
39
- act_dict, (model_output,) = self.lens.analyse_batched_boards(
40
  iter_boards,
41
  self.wrapper,
42
- {
43
- "return_output": True,
44
- "wrapper_kwargs": {
45
- "input_encoding": input_encoding,
46
- }
47
  }
48
  )
 
49
  if len(act_dict) == 0:
50
  raise ValueError("No module matced the given expression.")
51
  elif len(act_dict) > 1:
 
20
  self.wrapper = wrapper
21
  self.lens = ActivationLens(module_exp=module_exp)
22
 
23
+ @torch.no_grad
24
  def generate(
25
  self,
26
  root_fen: Optional[str] = None,
 
36
  input_encoding = InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED
37
  else:
38
  raise ValueError
39
+ iter_boards = iter([([root_board, traj_board],)])
40
+ result_iter = self.lens.analyse_batched_boards(
41
  iter_boards,
42
  self.wrapper,
43
+ return_output=True,
44
+ wrapper_kwargs={
45
+ "input_encoding": input_encoding,
 
 
46
  }
47
  )
48
+ act_dict, (model_output,) = next(result_iter)
49
  if len(act_dict) == 0:
50
  raise ValueError("No module matced the given expression.")
51
  elif len(act_dict) > 1:
src/interfaces/feature_interface.py CHANGED
@@ -5,6 +5,7 @@ Gradio interface for plotting policy.
5
  import chess
6
  import gradio as gr
7
  import uuid
 
8
 
9
  from lczerolens.encodings import encode_move
10
 
@@ -19,28 +20,57 @@ def compute_features_fn(
19
  traj_fen,
20
  feature_index
21
  ):
22
- model_output, _, sae_output = global_variables.generator.generate(
23
  root_fen=root_fen,
24
  traj_fen=traj_fen
25
  )
26
- features = sae_output["f"]
 
27
  first_output = render_feature_index(
28
  features,
29
  model_output,
30
  file_id,
31
- feature_index,
32
- traj_fen,
33
  )
34
- game_info = f"WDL: {model_output.get('wdl')}"
35
- return *first_output, game_info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
 
38
  def render_feature_index(
39
  features,
40
  model_output,
41
  file_id,
42
- feature_index,
43
- traj_fen,
44
  ):
45
  if file_id is None:
46
  file_id = str(uuid.uuid4())
@@ -98,14 +128,14 @@ with gr.Blocks() as interface:
98
  feature_index = gr.Slider(
99
  label="Feature index",
100
  minimum=0,
101
- maximum=constants.N_FEATURES,
102
  step=1,
103
  value=0,
104
  )
105
 
106
  with gr.Group():
107
  with gr.Row():
108
- game_info = gr.Textbox(label="Game info", lines=1, max_lines=1, value="")
109
  with gr.Row():
110
  colorbar = gr.Plot(label="Colorbar")
111
  with gr.Column():
@@ -114,8 +144,14 @@ with gr.Blocks() as interface:
114
  features = gr.State(None)
115
  model_output = gr.State(None)
116
  file_id = gr.State(None)
 
117
  compute_features.click(
118
  compute_features_fn,
119
  inputs=[features, model_output, file_id, root_fen, traj_fen, feature_index],
120
- outputs=[features, model_output, file_id, board_image, colorbar, game_info],
121
- )
 
 
 
 
 
 
5
  import chess
6
  import gradio as gr
7
  import uuid
8
+ import torch
9
 
10
  from lczerolens.encodings import encode_move
11
 
 
20
  traj_fen,
21
  feature_index
22
  ):
23
+ model_output, pixel_acts, sae_output = global_variables.generator.generate(
24
  root_fen=root_fen,
25
  traj_fen=traj_fen
26
  )
27
+ features = sae_output["features"]
28
+ x_hat = sae_output["x_hat"]
29
  first_output = render_feature_index(
30
  features,
31
  model_output,
32
  file_id,
33
+ traj_fen,
34
+ feature_index
35
  )
36
+
37
+ half_a_dim = constants.ACTIVATION_DIM // 2
38
+ half_f_dim = constants.DICTIONARY_SIZE // 2
39
+ pixel_f_avg = features.mean(dim=0)
40
+ pixel_f_active = (features > 0).float().mean(dim=0)
41
+ pixel_p_avg = features.mean(dim=1)
42
+ pixel_p_active = (features > 0).float().mean(dim=1)
43
+
44
+ board = chess.Board(traj_fen)
45
+ if board.turn:
46
+ most_avg_pixels = pixel_p_avg.topk(5).indices.tolist()
47
+ most_active_pixels = pixel_p_active.topk(5).indices.tolist()
48
+ else:
49
+ most_avg_pixels = pixel_p_avg.view(8,8).flip(0).view(64).topk(5).indices.tolist()
50
+ most_active_pixels = pixel_p_active.view(8,8).flip(0).view(64).topk(5).indices.tolist()
51
+
52
+ info = f"Root WDL: {model_output['wdl'][0]}\n"
53
+ info += f"Traj WDL: {model_output['wdl'][1]}\n"
54
+ info += f"MSE loss: {torch.nn.functional.mse_loss(x_hat, pixel_acts, reduction='none').sum(dim=1).mean()}\n"
55
+ info += f"MSE loss (root): {torch.nn.functional.mse_loss(x_hat[:,:half_a_dim], pixel_acts[:,:half_a_dim], reduction='none').sum(dim=1).mean()}\n"
56
+ info += f"MSE loss (traj): {torch.nn.functional.mse_loss(x_hat[:,half_a_dim:], pixel_acts[:,half_a_dim:], reduction='none').sum(dim=1).mean()}\n"
57
+ info += f"L0 loss: {(features>0).sum(dim=1).float().mean()}\n"
58
+ info += f"L0 loss (c): {(features[:,:half_f_dim]>0).sum(dim=1).float().mean()}\n"
59
+ info += f"L0 loss (d): {(features[:,half_f_dim:]>0).sum(dim=1).float().mean()}\n"
60
+ info += f"Most active features (avg): {pixel_f_avg.topk(5).indices.tolist()}\n"
61
+ info += f"Most active features (active): {pixel_f_active.topk(5).indices.tolist()}\n"
62
+ info += f"Most active pixels (avg): {[chess.SQUARE_NAMES[p] for p in most_avg_pixels]}\n"
63
+ info += f"Most active pixels (active): {[chess.SQUARE_NAMES[p] for p in most_active_pixels]}"
64
+
65
+ return *first_output, info
66
 
67
 
68
  def render_feature_index(
69
  features,
70
  model_output,
71
  file_id,
72
+ traj_fen,
73
+ feature_index
74
  ):
75
  if file_id is None:
76
  file_id = str(uuid.uuid4())
 
128
  feature_index = gr.Slider(
129
  label="Feature index",
130
  minimum=0,
131
+ maximum=constants.DICTIONARY_SIZE-1,
132
  step=1,
133
  value=0,
134
  )
135
 
136
  with gr.Group():
137
  with gr.Row():
138
+ info = gr.Textbox(label="Info", lines=1, max_lines=20, value="")
139
  with gr.Row():
140
  colorbar = gr.Plot(label="Colorbar")
141
  with gr.Column():
 
144
  features = gr.State(None)
145
  model_output = gr.State(None)
146
  file_id = gr.State(None)
147
+
148
  compute_features.click(
149
  compute_features_fn,
150
  inputs=[features, model_output, file_id, root_fen, traj_fen, feature_index],
151
+ outputs=[features, model_output, file_id, board_image, colorbar, info],
152
+ )
153
+ feature_index.change(
154
+ render_feature_index,
155
+ inputs=[features, model_output, file_id, traj_fen, feature_index],
156
+ outputs=[features, model_output, file_id, board_image, colorbar],
157
+ )