constants
Files changed:
- src/constants.py +4 -1
- src/global_variables.py +14 -10
src/constants.py
CHANGED
@@ -15,4 +15,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 MODEL_NAME = "lc0-10-4238.onnx"
 SAE_CONFIG = "debug"
 LAYER = 9
-
+ACTIVATION_DIM = 256
+DICTIONARY_SIZE = 7680
+PRE_BIAS = False
+INIT_NORMALISE_DICT = None
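Applied to the file, the hunk leaves lines 1-14 (including the DEVICE definition shown in the hunk header) untouched, and the tail of src/constants.py becomes:

MODEL_NAME = "lc0-10-4238.onnx"
SAE_CONFIG = "debug"
LAYER = 9
ACTIVATION_DIM = 256
DICTIONARY_SIZE = 7680
PRE_BIAS = False
INIT_NORMALISE_DICT = None

The four new values are the SAE hyperparameters consumed by the SparseAutoEncoder construction in src/global_variables.py below.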
src/global_variables.py
CHANGED
@@ -7,7 +7,7 @@ import gradio as gr
 from lczerolens import ModelWrapper
 import torch
 
-from src
+from src import constants
 from src.helpers import SparseAutoEncoder, OutputGenerator
 
 hf_api: HfApi
@@ -22,32 +22,36 @@ def setup():
     global sae
     global generator
 
-    hf_api = HfApi(token=HF_TOKEN)
+    hf_api = HfApi(token=constants.HF_TOKEN)
     hf_api.snapshot_download(
-        local_dir=f"{ASSETS_FOLDER}/models",
+        local_dir=f"{constants.ASSETS_FOLDER}/models",
         repo_id="Xmaster6y/lczero-planning-models",
         repo_type="model",
     )
     hf_api.snapshot_download(
-        local_dir=f"{ASSETS_FOLDER}/saes",
+        local_dir=f"{constants.ASSETS_FOLDER}/saes",
         repo_id="Xmaster6y/lczero-planning-saes",
         repo_type="model",
     )
 
-    wrapper = ModelWrapper.from_onnx_path(f"{ASSETS_FOLDER}/models/{MODEL_NAME}").to(DEVICE)
+    wrapper = ModelWrapper.from_onnx_path(f"{constants.ASSETS_FOLDER}/models/{constants.MODEL_NAME}").to(constants.DEVICE)
     sae_dict = torch.load(
-        f"{ASSETS_FOLDER}/saes/{SAE_CONFIG}/model.pt",
-        map_location=DEVICE,
-
+        f"{constants.ASSETS_FOLDER}/saes/{constants.SAE_CONFIG}/model.pt",
+        map_location=constants.DEVICE,
+    )
+    sae = SparseAutoEncoder(
+        constants.ACTIVATION_DIM,
+        constants.DICTIONARY_SIZE,
+        pre_bias=constants.PRE_BIAS,
+        init_normalise_dict=constants.INIT_NORMALISE_DICT,
     )
-    sae = SparseAutoEncoder()
     sae.load_state_dict(
         sae_dict
     )
     generator = OutputGenerator(
         sae=sae,
         wrapper=wrapper,
-        module_exp=rf".*block{LAYER}/conv2/relu"
+        module_exp=rf".*block{constants.LAYER}/conv2/relu"
     )
 
 if gr.NO_RELOAD:
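The new SparseAutoEncoder(...) call passes the four constants through to src/helpers.py, which this commit does not touch, so only the call site constrains the constructor. A minimal sketch of a constructor compatible with that call, assuming a standard single-layer SAE and treating an "l2" value for init_normalise_dict as unit-norm dictionary initialisation (all names and behaviour below are assumptions, not code from the repo):

# Hypothetical sketch inferred from the call site above; src/helpers.py is not
# part of this commit, so treat every detail here as an assumption.
import torch
from torch import nn

class SparseAutoEncoder(nn.Module):
    def __init__(self, activation_dim, dict_size, pre_bias=False, init_normalise_dict=None):
        super().__init__()
        self.pre_bias = pre_bias
        self.b_dec = nn.Parameter(torch.zeros(activation_dim))   # decoder bias
        self.encoder = nn.Linear(activation_dim, dict_size)
        self.decoder = nn.Linear(dict_size, activation_dim, bias=False)
        if init_normalise_dict == "l2":  # assumed meaning: unit-norm dictionary atoms at init
            with torch.no_grad():
                self.decoder.weight /= self.decoder.weight.norm(dim=0, keepdim=True)

    def forward(self, x):
        if self.pre_bias:                # optionally centre inputs on the decoder bias
            x = x - self.b_dec
        features = torch.relu(self.encoder(x))   # sparse codes
        return self.decoder(features) + self.b_dec

Whatever the real class looks like, the subsequent sae.load_state_dict(sae_dict) call requires its parameter names and shapes to match the checkpoint loaded from {constants.ASSETS_FOLDER}/saes/{constants.SAE_CONFIG}/model.pt, so ACTIVATION_DIM and DICTIONARY_SIZE must agree with the dimensions the SAE was trained with.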