"""Manage global variables for the app.
"""

import gradio as gr
import torch
from datasets import Dataset, load_dataset
from huggingface_hub import HfApi
from lczerolens import ModelWrapper

from src import constants
from src.helpers import SparseAutoEncoder, OutputGenerator

hf_api: HfApi
wrapper: ModelWrapper
sae: SparseAutoEncoder
generator: OutputGenerator
f_ds: Dataset


def setup():
    global hf_api
    global wrapper
    global sae
    global generator
    global f_ds

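    # Download the model and SAE checkpoints from the Hugging Face Hub.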
    hf_api = HfApi(token=constants.HF_TOKEN)
    hf_api.snapshot_download(
        local_dir=f"{constants.ASSETS_FOLDER}/models",
        repo_id="lczero-planning/models",
        repo_type="model",
    )
    hf_api.snapshot_download(
        local_dir=f"{constants.ASSETS_FOLDER}/saes",
        repo_id="lczero-planning/saes",
        repo_type="model",
    )

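    # Load the lc0 network from its ONNX export and the trained SAE weights.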
    wrapper = ModelWrapper.from_onnx_path(
        f"{constants.ASSETS_FOLDER}/models/{constants.MODEL_NAME}"
    ).to(constants.DEVICE)
    sae_dict = torch.load(
        f"{constants.ASSETS_FOLDER}/saes/{constants.SAE_CONFIG}/model.pt",
        map_location=constants.DEVICE,
    )
    sae = SparseAutoEncoder(
        constants.ACTIVATION_DIM,
        constants.DICTIONARY_SIZE,
        pre_bias=constants.PRE_BIAS,
        init_normalise_dict=constants.INIT_NORMALISE_DICT,
    )
    sae.load_state_dict(sae_dict)
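    # Attach the SAE to the configured block's conv2/relu activation for output generation.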
    generator = OutputGenerator(
        sae=sae,
        wrapper=wrapper,
        module_exp=rf".*block{constants.LAYER}/conv2/relu"
    )
    # Feature dataset for the demo, served as torch tensors.
    f_ds = load_dataset(
        constants.FEATURE_DATASET,
        constants.SAE_CONFIG,
        split="test",
    ).with_format("torch")


# gr.NO_RELOAD keeps this heavy setup from re-running on Gradio hot reloads.
if gr.NO_RELOAD:
    setup()
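

# Illustrative usage sketch (assumption: this module is importable as
# src.global_variables; the name is inferred from the docstring, not stated here).
# Other modules import it and read the globals that setup() populates:
#
#     from src import global_variables as gv
#
#     gv.setup()         # only needed if the gr.NO_RELOAD guard has not already run it
#     print(gv.wrapper)  # lczerolens ModelWrapper loaded on constants.DEVICE
#     print(gv.f_ds[0])  # first test sample of the feature dataset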