demo / src /helpers /sae.py
Xmaster6y's picture
new repo structure
0d998a6 unverified
"""
Defines the dictionary (sparse autoencoder) classes.
"""
import torch
import torch.nn as nn
from tensordict import TensorDict
class SparseAutoEncoder(nn.Module):
    """
    A 2-layer sparse autoencoder.

    The encoder computes ``relu(x @ W_enc + b_enc)`` and the decoder computes
    ``f @ W_dec + b_dec``. ``W_dec`` has shape ``(dict_size, activation_dim)``
    and ``W_enc`` has shape ``(activation_dim, dict_size)``.

    Args:
        activation_dim: size of the last dimension of the activations.
        dict_size: number of dictionary features (rows of ``W_dec``).
        pre_bias: if True, ``b_dec`` is subtracted from the input before
            encoding.
        init_normalise_dict: optional normalisation applied to ``W_dec`` at
            construction time. ``"l2"`` rescales every non-zero row to unit
            L2 norm and then multiplies the matrix by 0.1; ``"less_than_1"``
            only rescales rows whose norm exceeds 1. Any other value leaves
            the Kaiming-uniform initialisation untouched.
    """

    def __init__(
        self,
        activation_dim,
        dict_size,
        pre_bias=False,
        init_normalise_dict=None,
    ):
        super().__init__()
        self.activation_dim = activation_dim
        self.dict_size = dict_size
        self.pre_bias = pre_bias
        self.init_normalise_dict = init_normalise_dict

        self.b_enc = nn.Parameter(torch.zeros(self.dict_size))
        self.relu = nn.ReLU()

        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.dict_size,
                    self.activation_dim,
                )
            )
        )
        if init_normalise_dict == "l2":
            self.normalize_dict_(less_than_1=False)
            # In-place ops on a leaf parameter that requires grad are rejected
            # by autograd, so the rescaling has to run with grad disabled.
            with torch.no_grad():
                self.W_dec.mul_(0.1)
        elif init_normalise_dict == "less_than_1":
            self.normalize_dict_(less_than_1=True)

        # Initialise the encoder as the transpose of the decoder. clone() gives
        # W_enc its own storage; wrapping the bare transposed view would make
        # both parameters alias the same memory, so an in-place write to one
        # would silently overwrite the other.
        self.W_enc = nn.Parameter(self.W_dec.detach().t().clone())
        self.b_dec = nn.Parameter(
            torch.zeros(
                self.activation_dim,
            )
        )

    @torch.no_grad()
    def normalize_dict_(
        self,
        less_than_1=False,
    ):
        """
        Rescale the rows of ``W_dec`` in place to unit L2 norm.

        Rows with zero norm are left untouched to avoid dividing by zero.

        Args:
            less_than_1: if True, only rows whose norm is strictly greater
                than 1 are rescaled; otherwise every non-zero row is.
        """
        norm = self.W_dec.norm(dim=1)
        if less_than_1:
            # norm > 1 already excludes zero-norm rows.
            rows = norm > 1
        else:
            rows = norm != 0
        self.W_dec[rows] /= norm[rows].unsqueeze(1)

    def encode(self, x):
        """Return the pre-activation features ``x @ W_enc + b_enc``."""
        return x @ self.W_enc + self.b_enc

    def decode(self, f):
        """Return the reconstruction ``f @ W_dec + b_dec`` for features ``f``."""
        return f @ self.W_dec + self.b_dec

    def forward(self, x, output_features=False, ghost_mask=None):
        """
        Forward pass of an autoencoder.

        x : activations to be autoencoded; ``x.shape[0]`` is used as the
            batch size of the returned TensorDict
        output_features : if True, return the encoded features as well
            as the decoded x
        ghost_mask : if not None, run this autoencoder in "ghost mode"
            where features are masked

        Returns a TensorDict that always holds ``"x_hat"`` (the
        reconstruction), plus ``"features"`` when ``output_features`` is True
        and ``"x_ghost"`` when ``ghost_mask`` is given.
        """
        if self.pre_bias:
            x = x - self.b_dec
        f_pre = self.encode(x)
        out = TensorDict({}, batch_size=x.shape[0])
        if ghost_mask is not None:
            # Ghost features use exp() of the pre-activations instead of ReLU
            # and are decoded without the decoder bias.
            f_ghost = torch.exp(f_pre) * ghost_mask.to(f_pre)
            out["x_ghost"] = f_ghost @ self.W_dec
        f = self.relu(f_pre)
        if output_features:
            out["features"] = f
        out["x_hat"] = self.decode(f)
        return out