import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseAutoencoder(nn.Module):
    """Sparse autoencoder whose L1 penalty is weighted by decoder column norms."""

    def __init__(
        self,
        input_dim,
        hidden_dim,
        sparsity_alpha=0.00004,
        decoder_norm_range=(0.05, 1.0),
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.sparsity_alpha = sparsity_alpha
        self.enc_bias = nn.Parameter(torch.zeros(hidden_dim))
        self.encoder = nn.Linear(input_dim, hidden_dim, bias=False)
        self.dec_bias = nn.Parameter(torch.zeros(input_dim))
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)
        self._initialize_weights(decoder_norm_range)

    def forward(self, x):
        encoded = self.encode(x)
        decoded = self.decode(encoded)
        return decoded, encoded

    def encode(self, x):
        return F.relu(self.encoder(x) + self.enc_bias)

    def decode(self, x):
        return self.decoder(x) + self.dec_bias

    def loss(self, x, decoded, encoded):
        reconstruction_loss = F.mse_loss(decoded, x)
        # Weight each feature activation by the L2 norm of its decoder column
        # (dim=0 of the (input_dim, hidden_dim) weight matrix), sum over
        # features, then average over the batch so the penalty does not grow
        # with batch size while the reconstruction term stays a mean.
        sparsity_loss = self.sparsity_alpha * torch.mean(
            torch.sum(
                encoded.abs() * self.decoder.weight.norm(p=2, dim=0), dim=-1
            )
        )
        return reconstruction_loss + sparsity_loss

    def _initialize_weights(self, decoder_norm_range):
        norm_min, norm_max = decoder_norm_range
        with torch.no_grad():
            # Initialize decoder columns (the dictionary directions) as random
            # unit vectors, then rescale each column to a random L2 norm drawn
            # uniformly from [norm_min, norm_max]. Normalization is over dim=0
            # so each column has unit norm, matching the per-column norms used
            # in the sparsity penalty.
            self.decoder.weight.normal_(0, 1)
            self.decoder.weight /= self.decoder.weight.norm(
                p=2, dim=0, keepdim=True
            )
            self.decoder.weight *= norm_min + (
                norm_max - norm_min
            ) * torch.rand(1, self.hidden_dim)
            # Tie the encoder to the decoder at initialization: the encoder
            # weights start as the transpose of the decoder weights. This must
            # happen after the decoder is initialized, otherwise the encoder
            # would copy the default (uninitialized) decoder weights.
            self.encoder.weight.copy_(self.decoder.weight.t())
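

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): fits the
# autoencoder to one batch of activations with a single Adam step. The
# dimensions, learning rate, and the random `activations` tensor standing in
# for real model activations are all assumptions made for this example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    input_dim, hidden_dim, batch_size = 512, 4096, 64
    model = SparseAutoencoder(input_dim, hidden_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Stand-in for a batch of activations to be dictionary-learned.
    activations = torch.randn(batch_size, input_dim)

    decoded, encoded = model(activations)
    loss = model.loss(activations, decoded, encoded)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"loss: {loss.item():.4f}")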