'''
Adapted from
https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/train.py
'''

import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

import json
from types import SimpleNamespace
from typing import Callable, Iterable, Iterator, List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import tqdm
import wandb
from torch.distributed import ReduceOp

from SAE.dataset_iterator import ActivationsDataloader
from SAE.sae import SparseAutoencoder, unit_norm_decoder_, unit_norm_decoder_grad_adjustment_
from SAE.sae_utils import SAETrainingConfig, Config

def weighted_average(points: torch.Tensor, weights: torch.Tensor):
    # Weighted mean of the rows of `points`; weights are normalized to sum to one.
    weights = weights / weights.sum()
    return (points * weights.view(-1, 1)).sum(dim=0)


@torch.no_grad()
def geometric_median_objective(
    median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
) -> torch.Tensor:
    # Weighted sum of Euclidean distances from every point to the current median estimate.
    norms = torch.linalg.norm(points - median.view(1, -1), dim=1)
    return (norms * weights).sum()


def compute_geometric_median(
    points: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
    maxiter: int = 100,
    ftol: float = 1e-20,
    do_log: bool = False,
):
    """
    Compute the geometric median of ``points`` with Weiszfeld's algorithm.

    :param points: ``torch.Tensor`` of shape ``(n, d)``.
    :param weights: Optional ``torch.Tensor`` of shape ``(n,)``.
    :param eps: Smallest allowed value of the denominator, to avoid divide by zero.
        Equivalently, this is a smoothing parameter. Default 1e-6.
    :param maxiter: Maximum number of Weiszfeld iterations. Default 100.
    :param ftol: Terminate if the objective value does not improve by at least this
        fraction. Default 1e-20.
    :param do_log: If True, also record the objective values encountered during the run.
    :return: ``SimpleNamespace`` object with fields

        - ``median``: estimate of the geometric median, a ``torch.Tensor`` of shape ``(d,)``.
        - ``new_weights``: the final Weiszfeld weights.
        - ``termination``: string explaining how the algorithm terminated.
        - ``logs``: list of objective values encountered during the run (None if ``do_log`` is False).
    """
    with torch.no_grad():
        if weights is None:
            weights = torch.ones((points.shape[0],), device=points.device)

        new_weights = weights
        median = weighted_average(points, weights)
        objective_value = geometric_median_objective(median, points, weights)
        if do_log:
            logs = [objective_value]
        else:
            logs = None

        # Weiszfeld iterations: re-weight each point by the inverse of its distance
        # to the current median estimate, then recompute the weighted average.
        early_termination = False
        pbar = tqdm.tqdm(range(maxiter))
        for _ in pbar:
            prev_obj_value = objective_value

            norms = torch.linalg.norm(points - median.view(1, -1), dim=1)
            new_weights = weights / torch.clamp(norms, min=eps)
            median = weighted_average(points, new_weights)
            objective_value = geometric_median_objective(median, points, weights)

            if logs is not None:
                logs.append(objective_value)
            if abs(prev_obj_value - objective_value) <= ftol * objective_value:
                early_termination = True
                break

            pbar.set_description(f"Objective value: {objective_value:.4f}")

    # Final estimate, recomputed outside the no_grad block so autograd can track it if needed.
    median = weighted_average(points, new_weights)
    return SimpleNamespace(
        median=median,
        new_weights=new_weights,
        termination=(
            "function value converged within tolerance"
            if early_termination
            else "maximum iterations reached"
        ),
        logs=logs,
    )


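# Usage sketch (illustrative only, not part of the training path; shapes are hypothetical):
#
#   pts = torch.randn(4096, 768)                      # (n, d) sample of activations
#   gm = compute_geometric_median(pts, maxiter=100)   # Weiszfeld iterations on pts.device
#   robust_center = gm.median                         # tensor of shape (d,)
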
def maybe_transpose(x):
    # Prefer the contiguous layout: return x.T when x is not contiguous but its transpose is.
    return x.T if not x.is_contiguous() and x.T.is_contiguous() else x


RANK = 0

class Logger:
    def __init__(self, sae_name, **kws):
        self.vals = {}
        # Only the rank-0 process logs, and "dummy" loggers are disabled entirely.
        self.enabled = (RANK == 0) and not kws.pop("dummy", False)
        self.sae_name = sae_name

    def logkv(self, k, v):
        # Buffer a metric under a per-SAE namespace; returns `v` so calls can be chained.
        if self.enabled:
            self.vals[f'{self.sae_name}/{k}'] = v.detach() if isinstance(v, torch.Tensor) else v
        return v

    def dumpkvs(self, step):
        # Flush all buffered metrics to wandb for this step, then clear the buffer.
        if self.enabled:
            wandb.log(self.vals, step=step)
        self.vals = {}


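# Usage sketch (illustrative; the name below is hypothetical): each SAE gets its own Logger,
# so metrics are namespaced per autoencoder within the shared wandb run.
#
#   logger = Logger(sae_name="sae_32k")   # keys are logged as "sae_32k/<metric>"
#   logger.logkv("loss", 0.123)           # buffered until dumpkvs
#   logger.dumpkvs(step=0)                # wandb.log({...}, step=0)
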
class FeaturesStats:
    def __init__(self, dim, logger):
        self.dim = dim
        self.logger = logger
        self.reinit()

    def reinit(self):
        # Per-latent activation counts accumulated since the last log.
        self.n_activated = torch.zeros(self.dim, dtype=torch.long, device="cuda")
        self.n = 0

    def update(self, inds):
        # `inds` holds the indices of the active latents for each example in the batch.
        self.n += inds.shape[0]
        inds = inds.flatten().detach()
        self.n_activated.scatter_add_(0, inds, torch.ones_like(inds))

    def log(self):
        self.logger.logkv('activated', (self.n_activated / self.n + 1e-9).log10().cpu().numpy())


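# Note: the 'activated' histogram holds log10 firing frequencies per latent over the last
# logging window, e.g. -2 means a latent fired on roughly 1% of examples; latents stuck near
# log10(1e-9) never fired in the window.
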
def training_loop_(
    aes,
    train_acts_iter,
    loss_fn,
    log_interval,
    save_interval,
    loggers,
    sae_cfgs,
):
    # Bundle each autoencoder with its config, logger, progress bar, feature stats,
    # and optimizer so all SAEs can be trained in lockstep on the same activations.
    sae_packs = []
    for ae, cfg, logger in zip(aes, sae_cfgs, loggers):
        pbar = tqdm.tqdm(unit=" steps", desc="Training Loss: ")
        fstats = FeaturesStats(ae.n_dirs, logger)
        opt = torch.optim.Adam(ae.parameters(), lr=cfg.lr, eps=cfg.eps, fused=True)
        sae_packs.append((ae, cfg, logger, pbar, fstats, opt))

    for i, flat_acts_train_batch in enumerate(train_acts_iter):
        flat_acts_train_batch = flat_acts_train_batch.cuda()

        for ae, cfg, logger, pbar, fstats, opt in sae_packs:
            recons, info = ae(flat_acts_train_batch)
            loss = loss_fn(ae, cfg, flat_acts_train_batch, recons, info, logger)

            fstats.update(info['inds'])

            # Fraction of latents that have not fired within the last ~1e4/1e6/1e7 tokens.
            bs = flat_acts_train_batch.shape[0]
            logger.logkv('not-activated 1e4', (ae.stats_last_nonzero > 1e4 / bs).mean(dtype=float).item())
            logger.logkv('not-activated 1e6', (ae.stats_last_nonzero > 1e6 / bs).mean(dtype=float).item())
            logger.logkv('not-activated 1e7', (ae.stats_last_nonzero > 1e7 / bs).mean(dtype=float).item())

            logger.logkv('explained variance', explained_variance(recons, flat_acts_train_batch))
            logger.logkv('l2_div', (torch.linalg.norm(recons, dim=1) / torch.linalg.norm(flat_acts_train_batch, dim=1)).mean())

            if (i + 1) % log_interval == 0:
                fstats.log()
                fstats.reinit()

            if (i + 1) % save_interval == 0:
                ae.save_to_disk(f"{cfg.save_path}/{i + 1}")

            loss.backward()

            # Keep decoder columns unit-norm and remove the gradient component parallel
            # to each decoder direction before the optimizer step.
            unit_norm_decoder_(ae)
            unit_norm_decoder_grad_adjustment_(ae)

            opt.step()
            opt.zero_grad()
            logger.dumpkvs(i)

            pbar.set_description(f"Training Loss {loss.item():.4f}")
            pbar.update(1)

    # Final save once the activation iterator is exhausted.
    for ae, cfg, logger, pbar, fstats, opt in sae_packs:
        pbar.close()
        ae.save_to_disk(f"{cfg.save_path}/final")


def init_from_data_(ae, stats_acts_sample):
    # Initialize the pre-encoder bias to the geometric median of a sample of activations.
    ae.pre_bias.data = (
        compute_geometric_median(stats_acts_sample[:32768].float().cpu()).median.cuda().float()
    )

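# Rationale: the geometric median is a robust estimate of the center of the activation
# distribution, so outlier activations in the sample have little effect on the initial bias;
# only the first 32768 rows are used to keep the Weiszfeld iterations cheap.
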
def mse(recons, x):
    # Mean squared reconstruction error.
    return ((recons - x) ** 2).mean()

def normalized_mse(recon: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
    # MSE of the reconstruction divided by the MSE of the best constant baseline
    # (predicting the per-dimension batch mean for every example).
    xs_mu = xs.mean(dim=0)

    loss = mse(recon, xs) / mse(
        xs_mu[None, :].broadcast_to(xs.shape), xs
    )

    return loss

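# Interpretation: a normalized MSE of 1.0 means the reconstruction is no better than
# predicting the batch mean, and values well below 1.0 indicate useful reconstruction.
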
def explained_variance(recons, x):
    # Per-dimension variance of the residual, relative to the variance of the input.
    diff = x - recons
    diff_var = torch.var(diff, dim=0, unbiased=False)

    x_var = torch.var(x, dim=0, unbiased=False)

    explained_var = 1 - diff_var / (x_var + 1e-8)

    return explained_var.mean()

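# Interpretation: explained variance is averaged over activation dimensions; 1.0 corresponds
# to a perfect reconstruction and 0.0 to doing no better than the per-dimension mean.
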
def main():
    with open('SAE/config.json') as f:
        cfg = Config(json.load(f))

    dataloader = ActivationsDataloader(cfg.paths_to_latents, cfg.block_name, cfg.bs)

    acts_iter = dataloader.iterate()
    # A fixed sample of activations used for bias initialization and loss normalization.
    stats_acts_sample = torch.cat([
        next(acts_iter).cpu() for _ in range(10)
    ], dim=0)

    aes = [
        SparseAutoencoder(
            n_dirs_local=sae.n_dirs,
            d_model=sae.d_model,
            k=sae.k,
            auxk=sae.auxk,
            dead_steps_threshold=sae.dead_toks_threshold // cfg.bs,
        ).cuda()
        for sae in cfg.saes
    ]

    for ae in aes:
        init_from_data_(ae, stats_acts_sample)

    # Scale factor that normalizes the reconstruction loss by the variance of the sample.
    mse_scale = (
        1 / ((stats_acts_sample.float().mean(dim=0) - stats_acts_sample.float()) ** 2).mean()
    )
    mse_scale = mse_scale.item()
    del stats_acts_sample

    wandb.init(
        project=cfg.wandb_project,
        name=cfg.wandb_name,
    )

    loggers = [Logger(
        sae_name=cfg_sae.sae_name,
        dummy=False,
    ) for cfg_sae in cfg.saes]

    training_loop_(
        aes,
        acts_iter,
        # Per-SAE loss: scaled reconstruction MSE plus the auxiliary (auxk) term on the
        # residual, which gives otherwise-dead latents a training signal. logkv returns its
        # value, so the two logged terms sum directly into the training loss.
        lambda ae, cfg_sae, flat_acts_train_batch, recons, info, logger: (
            logger.logkv("train_recons", mse_scale * mse(recons, flat_acts_train_batch))
            + logger.logkv(
                "train_maxk_recons",
                cfg_sae.auxk_coef
                * normalized_mse(
                    ae.decode_sparse(
                        info["auxk_inds"],
                        info["auxk_vals"],
                    ),
                    flat_acts_train_batch - recons.detach() + ae.pre_bias.detach(),
                ).nan_to_num(0),
            )
        ),
        sae_cfgs=cfg.saes,
        loggers=loggers,
        log_interval=cfg.log_interval,
        save_interval=cfg.save_interval,
    )


if __name__ == "__main__":
    main()