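"""Configuration for a Stable Diffusion SAE (sparse autoencoder) training run.

Bundles the diffusion-model hook point to read activations from, the SAE
dimensions and top-k sparsity settings, training hyperparameters, Weights &
Biases logging, and checkpointing options into a single dataclass.
"""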
import datetime
from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class SDSAERunnerConfig:
    # Model / activation extraction parameters
    image_size: int = 512
    num_sampling_steps: int = 25
    vae: str = "mse"
    model_name: Optional[str] = None
    model_name_proc: Optional[str] = None
    timestep: int = 0
    module_name: str = "mid_block"
    paths_to_latents: Optional[str] = None
    layer_name: Optional[str] = None
    block_layer: int = 10
    # Hook point to read activations from; a text-encoder layer such as
    # "text_encoder.text_model.encoder.layers.10.28" can also be used here.
    block_name: str = "mid_block"
    use_cached_activations: bool = False
    image_key: str = "image"
    # SAE Parameters
    d_in: int = 768  # input activation dimension
    k: int = 32  # number of latents kept active per example (top-k)
    auxk_coef: float = 1 / 32  # weight of the auxiliary (dead-feature) loss
    auxk: int = 32  # number of auxiliary latents used in the aux loss
    # Activation Store Parameters
    epoch: int = 1000
    total_training_tokens: int = 2_000_000
    eps: float = 6.25e-10
    # SAE initialization parameters
    b_dec_init_method: str = "mean"
    expansion_factor: int = 4  # d_sae = d_in * expansion_factor
    from_pretrained_path: Optional[str] = None
    # Training Parameters
    lr: float = 3e-4
    lr_scheduler_name: str = "constant"
    lr_warm_up_steps: int = 500
    batch_size: int = 4096
    sae_batch_size: int = 1024
    dead_feature_threshold: float = 1e-8
    dead_toks_threshold: int = 10_000_000
    # WANDB
    log_to_wandb: bool = True
    wandb_project: str = "steerers"
    wandb_entity: Optional[str] = None
    wandb_log_frequency: int = 10
    # Misc
    device: str = "cpu"
    seed: int = 42
    dtype: torch.dtype = torch.float32
    save_path_base: str = "checkpoints"
    max_batch_size: int = 32
    # Creation timestamp, used to disambiguate checkpoint directories.
    ct: str = field(default_factory=lambda: datetime.datetime.now().isoformat())
    save_interval: int = 5000

    def __post_init__(self):
        self.d_sae = self.d_in * self.expansion_factor
        self.run_name = f"{self.block_name}_k{self.k}_hidden{self.d_sae}_auxk{self.auxk}_bs{self.batch_size}_lr{self.lr}"
        self.checkpoint_path = f"{self.save_path_base}/{self.run_name}_{self.ct}"
        if self.b_dec_init_method not in ["mean"]:
            raise ValueError(
                f"b_dec_init_method must be 'mean'. Got {self.b_dec_init_method}"
            )
        self.device = torch.device(self.device)
        print(
            f"Run name: {self.d_sae}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}"
        )
        # Print out some useful info:
        total_training_steps = self.total_training_tokens // self.batch_size
        print(f"Total training steps: {total_training_steps}")
        total_wandb_updates = total_training_steps // self.wandb_log_frequency
        print(f"Total wandb updates: {total_wandb_updates}")

    @property
    def sae_name(self) -> str:
        """Returns the name of the SAE model based on key parameters."""
        return f"{self.block_name}_k{self.k}_hidden{self.d_sae}_auxk{self.auxk}_bs{self.batch_size}_lr{self.lr}"

    @property
    def save_path(self) -> str:
        """Returns the path where the SAE model will be saved."""
        return self.checkpoint_path

    def __getitem__(self, key):
        """Allows subscripting the config object like a dictionary."""
        if hasattr(self, key):
            return getattr(self, key)
        raise KeyError(f"Key {key} does not exist in SDSAERunnerConfig.")
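

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# exercises only the class defined above: the derived fields computed in
# __post_init__ and the dict-style access provided by __getitem__.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = SDSAERunnerConfig(d_in=768, expansion_factor=4, k=32)
    assert cfg.d_sae == cfg.d_in * cfg.expansion_factor  # 768 * 4 = 3072
    print(cfg.sae_name)   # e.g. mid_block_k32_hidden3072_auxk32_bs4096_lr0.0003
    print(cfg.save_path)  # checkpoints/<run_name>_<ISO timestamp>
    print(cfg["lr"])      # 0.0003, via __getitem__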