Xmaster6y committed on
Commit 0d998a6
1 Parent(s): 41a0620

new repo structure

.gitignore ADDED
@@ -0,0 +1,139 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # Pickle files
+ *.pkl
+
+ # Various files
+ ignored
+ debug
+ *.zip
+ lc0
+ !bin/lc0
+ wandb
+
+ *secret*
app.py CHANGED
@@ -4,11 +4,15 @@ Main Gradio module.
  
  import gradio as gr
  
+ from src.interfaces import feature_interface
+
  
  demo = gr.TabbedInterface(
      [
+         feature_interface,
      ],
      [
+         "Feature Activation",
      ],
      title="Lczero Planning Demo",
      analytics_enabled=False,
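
For orientation, a minimal sketch of how this tabbed demo would typically be served; the closing launch call is not part of this hunk and is assumed here.

# Assumed typical Gradio boilerplate; the actual end of app.py is not shown in this diff.
if __name__ == "__main__":
    demo.launch()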
assets/.gitignore ADDED
@@ -0,0 +1,2 @@
+ *
+ !.gitignore
figures/.gitignore ADDED
@@ -0,0 +1,2 @@
+ *
+ !.gitignore
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ git+https://github.com/Xmaster6y/lczerolens
+ chess
+ matplotlib
+ numpy
+ torch
+ tensordict
+ einops
src/constants.py ADDED
@@ -0,0 +1,18 @@
+ """Manage constants for the app.
+ """
+
+ import os
+ import pathlib
+ import torch
+
+
+ ASSETS_FOLDER = pathlib.Path(__file__).parent.parent / "assets"
+ FIGURES_FOLDER = pathlib.Path(__file__).parent.parent / "figures"
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ MODEL_NAME = "lc0-10-4238.onnx"
+ SAE_CONFIG = "debug"
+ LAYER = 9
+ N_FEATURES = 7680
src/global_variables.py ADDED
@@ -0,0 +1,54 @@
+ """Manage global variables for the app.
+ """
+
+ from huggingface_hub import HfApi
+
+ import gradio as gr
+ from lczerolens import ModelWrapper
+ import torch
+
+ from src.constants import HF_TOKEN, ASSETS_FOLDER, DEVICE, MODEL_NAME, SAE_CONFIG, LAYER
+ from src.helpers import SparseAutoEncoder, OutputGenerator
+
+ hf_api: HfApi
+ wrapper: ModelWrapper
+ sae: SparseAutoEncoder
+ generator: OutputGenerator
+
+
+ def setup():
+     global hf_api
+     global wrapper
+     global sae
+     global generator
+
+     hf_api = HfApi(token=HF_TOKEN)
+     hf_api.snapshot_download(
+         local_dir=f"{ASSETS_FOLDER}/models",
+         repo_id="Xmaster6y/lczero-planning-models",
+         repo_type="model",
+     )
+     hf_api.snapshot_download(
+         local_dir=f"{ASSETS_FOLDER}/saes",
+         repo_id="Xmaster6y/lczero-planning-saes",
+         repo_type="model",
+     )
+
+     wrapper = ModelWrapper.from_onnx_path(f"{ASSETS_FOLDER}/models/{MODEL_NAME}").to(DEVICE)
+     sae_dict = torch.load(
+         f"{ASSETS_FOLDER}/saes/{SAE_CONFIG}/model.pt",
+         map_location=DEVICE,
+         weights_only=True
+     )
+     sae = SparseAutoEncoder()
+     sae.load_state_dict(
+         sae_dict
+     )
+     generator = OutputGenerator(
+         sae=sae,
+         wrapper=wrapper,
+         module_exp=rf".*block{LAYER}/conv2/relu"
+     )
+
+ if gr.NO_RELOAD:
+     setup()
src/helpers/__init__.py ADDED
@@ -0,0 +1,4 @@
+
+
+ from .generator import OutputGenerator
+ from .sae import SparseAutoEncoder
src/helpers/generator.py ADDED
@@ -0,0 +1,59 @@
+ """Script to generate features for a given board state.
+ """
+
+ from typing import Optional
+
+ from lczerolens import ModelWrapper
+ from lczerolens.xai import ActivationLens
+ from lczerolens.encodings import InputEncoding
+ import chess
+ import einops
+ import torch
+
+ from .sae import SparseAutoEncoder
+
+
+ class OutputGenerator:
+
+     def __init__(self, sae: SparseAutoEncoder, wrapper: ModelWrapper, module_exp: Optional[str] = None):
+         self.sae = sae
+         self.wrapper = wrapper
+         self.lens = ActivationLens(module_exp=module_exp)
+
+     def generate(
+         self,
+         root_fen: Optional[str] = None,
+         traj_fen: Optional[str] = None,
+         root_board: Optional[chess.Board] = None,
+         traj_board: Optional[chess.Board] = None,
+     ):
+         if root_board is not None and traj_board is not None:
+             input_encoding = InputEncoding.INPUT_CLASSICAL_112_PLANE
+         elif root_fen is not None and traj_fen is not None:
+             root_board = chess.Board(root_fen)
+             traj_board = chess.Board(traj_fen)
+             input_encoding = InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED
+         else:
+             raise ValueError("Provide either both boards or both FENs.")
+         iter_boards = iter([[root_board, traj_board]])
+         act_dict, (model_output,) = self.lens.analyse_batched_boards(
+             iter_boards,
+             self.wrapper,
+             {
+                 "return_output": True,
+                 "wrapper_kwargs": {
+                     "input_encoding": input_encoding,
+                 }
+             }
+         )
+         if len(act_dict) == 0:
+             raise ValueError("No module matched the given expression.")
+         elif len(act_dict) > 1:
+             raise ValueError("Multiple modules matched the given expression.")
+         acts = next(iter(act_dict.values()))
+         root_acts = einops.rearrange(acts[0], "c h w -> (h w) c")
+         traj_acts = einops.rearrange(acts[1], "c h w -> (h w) c")
+         pixel_acts = torch.cat([root_acts, traj_acts], dim=1)
+         sae_output = self.sae(pixel_acts, output_features=True)
+         return model_output, pixel_acts, sae_output
+
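
For reference, a minimal usage sketch of OutputGenerator as wired up in src/global_variables.py; the trajectory FEN is the interface default, and the exact output shapes are assumptions rather than guarantees of this commit.

import chess

from src import global_variables  # setup() runs on import (outside Gradio reload mode) and loads the assets

model_output, pixel_acts, sae_output = global_variables.generator.generate(
    root_fen=chess.STARTING_FEN,
    traj_fen="rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1",
)
features = sae_output["features"]  # per-square SAE feature activations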
src/helpers/sae.py ADDED
@@ -0,0 +1,93 @@
+ """
+ Defines the dictionary classes
+ """
+
+ import torch
+ import torch.nn as nn
+ from tensordict import TensorDict
+
+
+ class SparseAutoEncoder(nn.Module):
+     """
+     A 2-layer sparse autoencoder.
+     """
+
+     def __init__(
+         self,
+         activation_dim,
+         dict_size,
+         pre_bias=False,
+         init_normalise_dict=None,
+     ):
+         super().__init__()
+         self.activation_dim = activation_dim
+         self.dict_size = dict_size
+         self.pre_bias = pre_bias
+         self.init_normalise_dict = init_normalise_dict
+
+         self.b_enc = nn.Parameter(torch.zeros(self.dict_size))
+         self.relu = nn.ReLU()
+
+         self.W_dec = nn.Parameter(
+             torch.nn.init.kaiming_uniform_(
+                 torch.empty(
+                     self.dict_size,
+                     self.activation_dim,
+                 )
+             )
+         )
+         if init_normalise_dict == "l2":
+             self.normalize_dict_(less_than_1=False)
+             self.W_dec *= 0.1
+         elif init_normalise_dict == "less_than_1":
+             self.normalize_dict_(less_than_1=True)
+
+         self.W_enc = nn.Parameter(self.W_dec.t())
+         self.b_dec = nn.Parameter(
+             torch.zeros(
+                 self.activation_dim,
+             )
+         )
+
+     @torch.no_grad()
+     def normalize_dict_(
+         self,
+         less_than_1=False,
+     ):
+         norm = self.W_dec.norm(dim=1)
+         positive_mask = norm != 0
+         if less_than_1:
+             greater_than_1_mask = (norm > 1) & (positive_mask)
+             self.W_dec[greater_than_1_mask] /= norm[greater_than_1_mask].unsqueeze(1)
+         else:
+             self.W_dec[positive_mask] /= norm[positive_mask].unsqueeze(1)
+
+     def encode(self, x):
+         return x @ self.W_enc + self.b_enc
+
+     def decode(self, f):
+         return f @ self.W_dec + self.b_dec
+
+     def forward(self, x, output_features=False, ghost_mask=None):
+         """
+         Forward pass of an autoencoder.
+         x : activations to be autoencoded
+         output_features : if True, return the encoded features as well
+             as the decoded x
+         ghost_mask : if not None, run this autoencoder in "ghost mode"
+             where features are masked
+         """
+         if self.pre_bias:
+             x = x - self.b_dec
+         f_pre = self.encode(x)
+         out = TensorDict({}, batch_size=x.shape[0])
+         if ghost_mask is not None:
+             f_ghost = torch.exp(f_pre) * ghost_mask.to(f_pre)
+             x_ghost = f_ghost @ self.W_dec
+             out["x_ghost"] = x_ghost
+         f = self.relu(f_pre)
+         if output_features:
+             out["features"] = f
+         x_hat = self.decode(f)
+         out["x_hat"] = x_hat
+         return out
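
A minimal sketch of the autoencoder's round trip on per-square activations; the dimensions below are illustrative assumptions and must match the saved checkpoint in practice.

import torch

from src.helpers import SparseAutoEncoder

# Illustrative sizes only: activation_dim is the per-square activation width fed by
# OutputGenerator, dict_size the number of learned features (N_FEATURES in src/constants.py).
sae = SparseAutoEncoder(activation_dim=256, dict_size=7680)

x = torch.randn(64, 256)  # one activation vector per board square
out = sae(x, output_features=True)
print(out["features"].shape)  # torch.Size([64, 7680]) sparse feature codes
print(out["x_hat"].shape)     # torch.Size([64, 256]) reconstruction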
src/interfaces/__init__.py ADDED
@@ -0,0 +1,2 @@
+
+ from .feature_interface import interface as feature_interface
src/interfaces/feature_interface.py ADDED
@@ -0,0 +1,121 @@
+ """
+ Gradio interface for plotting feature activations.
+ """
+
+ import chess
+ import gradio as gr
+ import uuid
+
+ from lczerolens.encodings import encode_move
+
+ from src import constants, global_variables, visualisation
+
+
+ def compute_features_fn(
+     features,
+     model_output,
+     file_id,
+     root_fen,
+     traj_fen,
+     feature_index
+ ):
+     model_output, _, sae_output = global_variables.generator.generate(
+         root_fen=root_fen,
+         traj_fen=traj_fen
+     )
+     features = sae_output["features"]
+     first_output = render_feature_index(
+         features,
+         model_output,
+         file_id,
+         feature_index,
+         traj_fen,
+     )
+     game_info = f"WDL: {model_output.get('wdl')}"
+     return *first_output, game_info
+
+
+ def render_feature_index(
+     features,
+     model_output,
+     file_id,
+     feature_index,
+     traj_fen,
+ ):
+     if file_id is None:
+         file_id = str(uuid.uuid4())
+     board = chess.Board(traj_fen)
+     pixel_features = features[:, feature_index]
+     if board.turn:
+         heatmap = pixel_features.view(64)
+     else:
+         heatmap = pixel_features.view(8, 8).flip(0).view(64)
+
+     best_legal_logit = None
+     best_legal_move = None
+     for move in board.legal_moves:
+         move_index = encode_move(move, (board.turn, not board.turn))
+         logit = model_output["policy"][1, move_index].item()
+         if best_legal_logit is None or logit > best_legal_logit:
+             best_legal_logit = logit
+             best_legal_move = move
+
+     svg_board, fig = visualisation.render_heatmap(
+         board,
+         heatmap,
+         arrows=[(best_legal_move.from_square, best_legal_move.to_square)],
+     )
+     with open(f"{constants.FIGURES_FOLDER}/{file_id}.svg", "w") as f:
+         f.write(svg_board)
+     return (
+         features,
+         model_output,
+         file_id,
+         f"{constants.FIGURES_FOLDER}/{file_id}.svg",
+         fig
+     )
+
+
+ with gr.Blocks() as interface:
+     with gr.Row():
+         with gr.Column():
+             root_fen = gr.Textbox(
+                 label="Root FEN",
+                 lines=1,
+                 max_lines=1,
+                 value=chess.STARTING_FEN,
+             )
+             traj_fen = gr.Textbox(
+                 label="Trajectory FEN",
+                 lines=1,
+                 max_lines=1,
+                 value="rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1",
+             )
+             compute_features = gr.Button("Compute features")
+
+             with gr.Group():
+                 with gr.Row():
+                     feature_index = gr.Slider(
+                         label="Feature index",
+                         minimum=0,
+                         maximum=constants.N_FEATURES - 1,
+                         step=1,
+                         value=0,
+                     )
+
+             with gr.Group():
+                 with gr.Row():
+                     game_info = gr.Textbox(label="Game info", lines=1, max_lines=1, value="")
+                 with gr.Row():
+                     colorbar = gr.Plot(label="Colorbar")
+         with gr.Column():
+             board_image = gr.Image(label="Board")
+
+     features = gr.State(None)
+     model_output = gr.State(None)
+     file_id = gr.State(None)
+     compute_features.click(
+         compute_features_fn,
+         inputs=[features, model_output, file_id, root_fen, traj_fen, feature_index],
+         outputs=[features, model_output, file_id, board_image, colorbar, game_info],
+     )
src/interfaces/stats_interface.py ADDED
File without changes
src/visualisation.py ADDED
@@ -0,0 +1,108 @@
+ """
+ Visualisation utils.
+ """
+
+ import chess
+ import chess.svg
+ import matplotlib
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+
+
+ COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000)
+ ALPHA = 1.0
+ NORM = matplotlib.colors.Normalize(vmin=0, vmax=1, clip=False)
+
+
+ def render_heatmap(
+     board,
+     heatmap,
+     square=None,
+     vmin=None,
+     vmax=None,
+     arrows=None,
+     normalise="none",
+ ):
+     """
+     Render a heatmap on the board.
+     """
+     if normalise == "abs":
+         a_max = heatmap.abs().max()
+         if a_max != 0:
+             heatmap = heatmap / a_max
+         vmin = -1
+         vmax = 1
+     if vmin is None:
+         vmin = heatmap.min()
+     if vmax is None:
+         vmax = heatmap.max()
+     norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)
+
+     color_dict = {}
+     for square_index in range(64):
+         color = COLOR_MAP(norm(heatmap[square_index]))
+         color = (*color[:3], ALPHA)
+         color_dict[square_index] = matplotlib.colors.to_hex(color, keep_alpha=True)
+     fig = plt.figure(figsize=(6, 0.6))
+     ax = plt.gca()
+     ax.axis("off")
+     fig.colorbar(
+         matplotlib.cm.ScalarMappable(norm=norm, cmap=COLOR_MAP),
+         ax=ax,
+         orientation="horizontal",
+         fraction=1.0,
+     )
+     if square is not None:
+         try:
+             check = chess.parse_square(square)
+         except ValueError:
+             check = None
+     else:
+         check = None
+     if arrows is None:
+         arrows = []
+     plt.close()
+     return (
+         chess.svg.board(
+             board,
+             check=check,
+             fill=color_dict,
+             size=350,
+             arrows=arrows,
+         ),
+         fig,
+     )
+
+
+ def render_policy_distribution(
+     policy,
+     legal_moves,
+     n_bins=20,
+ ):
+     """
+     Render the policy distribution histogram.
+     """
+     legal_mask = torch.Tensor([move in legal_moves for move in range(1858)]).bool()
+     fig = plt.figure(figsize=(6, 6))
+     ax = plt.gca()
+     _, bins = np.histogram(policy, bins=n_bins)
+     ax.hist(
+         policy[~legal_mask],
+         bins=bins,
+         alpha=0.5,
+         density=True,
+         label="Illegal moves",
+     )
+     ax.hist(
+         policy[legal_mask],
+         bins=bins,
+         alpha=0.5,
+         density=True,
+         label="Legal moves",
+     )
+     plt.xlabel("Policy")
+     plt.ylabel("Density")
+     plt.legend()
+     plt.yscale("log")
+     return fig
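
A minimal sketch of calling render_heatmap directly, with a hypothetical per-square score tensor standing in for an SAE feature column.

import chess
import torch

from src import visualisation

board = chess.Board()
heatmap = torch.rand(64)  # hypothetical per-square scores

svg_board, colorbar_fig = visualisation.render_heatmap(board, heatmap)
with open("board.svg", "w") as f:
    f.write(svg_board)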