llama3_SAE (custom_code, Safetensors)
from typing import List, Optional

from transformers import LlamaConfig


class LLama3_SAE_Config(LlamaConfig):
    """Config for a Llama-3 model with a sparse autoencoder (SAE) hooked into one block."""

    model_type = "llama3_SAE"

    def __init__(
        self,
        base_model_name: str = "",
        hook_block_num: int = 25,
        n_latents: int = 12288,
        n_inputs: int = 4096,
        activation: str = "relu",
        activation_k: int = 64,
        site: str = "mlp",
        tied: bool = False,
        normalize: bool = False,
        mod_features: Optional[List[int]] = None,
        mod_threshold: Optional[List[int]] = None,
        mod_replacement: Optional[List[int]] = None,
        mod_scaling: Optional[List[int]] = None,
        **kwargs,
    ):
        # Base Llama-3 checkpoint the SAE is attached to.
        self.base_model_name = base_model_name
        # Index of the transformer block whose activations are hooked.
        self.hook_block_num = hook_block_num
        # SAE width (number of latent features) and input size of the hooked activations.
        self.n_latents = n_latents
        self.n_inputs = n_inputs
        # SAE activation function and the k used by top-k style activations.
        self.activation = activation
        self.activation_k = activation_k
        # Which activations are hooked (e.g. the block's MLP output).
        self.site = site
        # Whether encoder/decoder weights are tied and whether inputs are normalized.
        self.tied = tied
        self.normalize = normalize
        # Optional inference-time feature modification: indices of latents to modify,
        # plus per-feature thresholds, replacement values, and scaling factors.
        self.mod_features = mod_features
        self.mod_threshold = mod_threshold
        self.mod_replacement = mod_replacement
        self.mod_scaling = mod_scaling

        super().__init__(**kwargs)
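

# Usage sketch (added for illustration, not part of the original file): register the
# custom config with the transformers Auto classes and build an instance with the
# SAE-specific fields defined above. The base checkpoint name below is an assumption,
# not taken from this repository. When loading from the Hub, repositories tagged
# `custom_code` are instead resolved with
# AutoConfig.from_pretrained(repo_id, trust_remote_code=True).
if __name__ == "__main__":
    from transformers import AutoConfig

    # Make the custom model_type resolvable by the Auto classes.
    AutoConfig.register("llama3_SAE", LLama3_SAE_Config)

    # SAE-specific fields sit alongside the usual LlamaConfig fields.
    config = LLama3_SAE_Config(
        base_model_name="meta-llama/Meta-Llama-3-8B",  # assumption: any Llama-3 checkpoint id
        hook_block_num=25,
        n_latents=12288,
        n_inputs=4096,
        activation="relu",
        activation_k=64,
        site="mlp",
    )
    print(config.model_type, config.n_latents)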