from transformers import PretrainedConfig, LlamaConfig from typing import List, Callable import torch # class LLama3_SAE_Config(PretrainedConfig): class LLama3_SAE_Config(LlamaConfig): model_type = "llama3_SAE" def __init__( self, # hf_token: str = "", # base_model_config: LlamaConfig = None, base_model_name: str = "", hook_block_num: int = 25, n_latents: int = 12288, n_inputs: int = 4096, activation: str = "relu", activation_k: int = 64, site: str = "mlp", tied: bool = False, normalize: bool = False, mod_features: List[int] = None, mod_threshold: List[int] = None, mod_replacement: List[int] = None, mod_scaling: List[int] = None, **kwargs, ): # self.hf_token = hf_token # self.base_model_config = base_model_config self.base_model_name = base_model_name self.hook_block_num = hook_block_num self.n_latents = n_latents self.n_inputs = n_inputs self.activation = activation self.activation_k = activation_k self.site = site self.tied = tied self.normalize = normalize self.mod_features = mod_features self.mod_threshold = mod_threshold self.mod_replacement = mod_replacement self.mod_scaling = mod_scaling super().__init__(**kwargs)