""" | |
Defines the dictionary classes | |
""" | |
import torch | |
import torch.nn as nn | |
from tensordict import TensorDict | |


class SparseAutoEncoder(nn.Module):
    """
    A 2-layer sparse autoencoder: a linear encoder with ReLU and a linear decoder.
    """

    def __init__(
        self,
        activation_dim,
        dict_size,
        pre_bias=False,
        init_normalise_dict=None,
    ):
        super().__init__()
        self.activation_dim = activation_dim
        self.dict_size = dict_size
        self.pre_bias = pre_bias
        self.init_normalise_dict = init_normalise_dict

        self.b_enc = nn.Parameter(torch.zeros(self.dict_size))
        self.relu = nn.ReLU()
        # Decoder weight: one row per dictionary feature.
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.dict_size,
                    self.activation_dim,
                )
            )
        )
        if init_normalise_dict == "l2":
            self.normalize_dict_(less_than_1=False)
            # Rescale via .data so this is not an in-place op on a leaf
            # tensor that requires grad (which would raise a RuntimeError).
            self.W_dec.data *= 0.1
        elif init_normalise_dict == "less_than_1":
            self.normalize_dict_(less_than_1=True)

        # Initialise the encoder as the transpose of the decoder; the clone
        # keeps W_enc and W_dec independent rather than views of shared storage.
        self.W_enc = nn.Parameter(self.W_dec.t().clone())
        self.b_dec = nn.Parameter(
            torch.zeros(
                self.activation_dim,
            )
        )

    @torch.no_grad()
    def normalize_dict_(
        self,
        less_than_1=False,
    ):
        """
        Normalise the decoder rows in place. If ``less_than_1``, only rows
        with norm > 1 are rescaled (projected onto the unit ball); otherwise
        every nonzero row is rescaled to unit norm. Runs under ``no_grad``
        since these are in-place updates of a leaf parameter.
        """
        norm = self.W_dec.norm(dim=1)
        positive_mask = norm != 0
        if less_than_1:
            greater_than_1_mask = (norm > 1) & positive_mask
            self.W_dec[greater_than_1_mask] /= norm[greater_than_1_mask].unsqueeze(1)
        else:
            self.W_dec[positive_mask] /= norm[positive_mask].unsqueeze(1)
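
    # Usage note (an assumption, not from the original code): SAE training
    # loops commonly call this after each optimizer step to keep decoder
    # rows at unit norm, e.g.
    #   loss.backward(); optimizer.step(); sae.normalize_dict_()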

    def encode(self, x):
        # Returns pre-activation features; the ReLU is applied in forward().
        return x @ self.W_enc + self.b_enc

    def decode(self, f):
        return f @ self.W_dec + self.b_dec

    def forward(self, x, output_features=False, ghost_mask=None):
        """
        Forward pass of the autoencoder.

        x : activations to be autoencoded
        output_features : if True, also return the encoded features
            alongside the decoded x
        ghost_mask : if not None, run the autoencoder in "ghost mode",
            additionally decoding exp() of the masked pre-activations
        """
        if self.pre_bias:
            x = x - self.b_dec
        f_pre = self.encode(x)

        out = TensorDict({}, batch_size=x.shape[0])
        if ghost_mask is not None:
            # Ghost mode: exponentiate the masked pre-activations and decode
            # them without the decoder bias.
            f_ghost = torch.exp(f_pre) * ghost_mask.to(f_pre)
            x_ghost = f_ghost @ self.W_dec
            out["x_ghost"] = x_ghost

        f = self.relu(f_pre)
        if output_features:
            out["features"] = f
        x_hat = self.decode(f)
        out["x_hat"] = x_hat
        return out
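

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption for illustration, not part of the
# original module): the dimensions and the random ghost mask below are
# arbitrary, chosen only to exercise the API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sae = SparseAutoEncoder(activation_dim=512, dict_size=2048, init_normalise_dict="l2")
    x = torch.randn(8, 512)  # a batch of 8 activation vectors

    # Plain reconstruction, also returning the sparse features.
    out = sae(x, output_features=True)
    print(out["x_hat"].shape)     # torch.Size([8, 512])
    print(out["features"].shape)  # torch.Size([8, 2048])

    # Ghost mode on a random ~10% subset of features.
    ghost_mask = torch.rand(2048) < 0.1
    out = sae(x, ghost_mask=ghost_mask)
    print(out["x_ghost"].shape)   # torch.Size([8, 512])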