# clip-sae-128 / model.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseAutoencoder(nn.Module):
    """Sparse autoencoder with a ReLU bottleneck and an L1 sparsity penalty
    weighted by the decoder column norms."""

    def __init__(
        self,
        input_dim,
        hidden_dim,
        sparsity_alpha=0.00004,
        decoder_norm_range=(0.05, 1.0),
    ):
        super(SparseAutoencoder, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.sparsity_alpha = sparsity_alpha
        self.enc_bias = nn.Parameter(torch.zeros(hidden_dim))
        self.encoder = nn.Linear(input_dim, hidden_dim, bias=False)
        self.dec_bias = nn.Parameter(torch.zeros(input_dim))
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)
        self._initialize_weights(decoder_norm_range)
    def forward(self, x):
        encoded = self.encode(x)
        decoded = self.decode(encoded)
        return decoded, encoded

    def encode(self, x):
        return F.relu(self.encoder(x) + self.enc_bias)

    def decode(self, x):
        return self.decoder(x) + self.dec_bias
    def loss(self, x, decoded, encoded):
        reconstruction_loss = F.mse_loss(decoded, x)
        # L1 penalty on the activations, weighted by the L2 norm of each
        # feature's decoder column, so shrinking decoder norms cannot hide
        # large activations from the penalty.
        sparsity_loss = self.sparsity_alpha * torch.sum(
            encoded.abs() * self.decoder.weight.norm(p=2, dim=0)
        )
        total_loss = reconstruction_loss + sparsity_loss
        return total_loss
    def _initialize_weights(self, decoder_norm_range):
        # Initialize decoder weights with random directions whose L2 norms are
        # drawn uniformly from decoder_norm_range (one norm per feature column).
        norm_min, norm_max = decoder_norm_range
        norm_range = norm_max - norm_min
        self.decoder.weight.data.normal_(0, 1)
        # Normalize each column (feature direction) to unit L2 norm ...
        self.decoder.weight.data /= self.decoder.weight.data.norm(
            p=2, dim=0, keepdim=True
        )
        # ... then rescale every column by a factor in [norm_min, norm_max].
        self.decoder.weight.data *= (
            norm_min + norm_range * torch.rand(1, self.hidden_dim)
        ).expand_as(self.decoder.weight.data)
        # Initialize encoder weights to the transpose of the decoder weights
        self.encoder.weight.data = self.decoder.weight.data.t()
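

# A minimal usage sketch, not part of the original file: the batch size, the
# embedding size (768), the dictionary size (8192), and the optimizer settings
# are illustrative assumptions, not values confirmed by the clip-sae-128 repo;
# random tensors stand in for real CLIP embeddings.
if __name__ == "__main__":
    sae = SparseAutoencoder(input_dim=768, hidden_dim=8192)
    optimizer = torch.optim.Adam(sae.parameters(), lr=1e-4)  # assumed optimizer/lr

    x = torch.randn(32, 768)  # stand-in for a batch of CLIP embeddings
    decoded, encoded = sae(x)  # forward() returns (reconstruction, sparse codes)
    loss = sae.loss(x, decoded, encoded)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"loss: {loss.item():.4f}")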