|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class SparseAutoencoder(nn.Module):
    """Sparse autoencoder with tied-transpose initialization and an
    L1 sparsity penalty weighted by decoder column norms.

    The decoder weight has shape (input_dim, hidden_dim); column j is the
    dictionary direction for hidden feature j. The sparsity penalty scales
    each feature's activation by its column norm, so feature "importance"
    cannot be gamed by shrinking decoder columns.

    Args:
        input_dim: dimensionality of the inputs to reconstruct.
        hidden_dim: number of dictionary features (typically >> input_dim).
        sparsity_alpha: coefficient on the sparsity penalty in ``loss``.
        decoder_norm_range: (min, max) range for the initial L2 norm of
            each decoder column.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        sparsity_alpha=0.00004,
        decoder_norm_range=(0.05, 1.0),
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.sparsity_alpha = sparsity_alpha

        # Biases are kept as separate Parameters (rather than Linear's own
        # bias) so encode/decode can apply them explicitly.
        self.enc_bias = nn.Parameter(torch.zeros(hidden_dim))
        self.encoder = nn.Linear(input_dim, hidden_dim, bias=False)

        self.dec_bias = nn.Parameter(torch.zeros(input_dim))
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)

        self._initialize_weights(decoder_norm_range)

    def forward(self, x):
        """Run a full reconstruction pass.

        Returns:
            (decoded, encoded): the reconstruction of ``x`` and the hidden
            feature activations used to produce it.
        """
        encoded = self.encode(x)
        decoded = self.decode(encoded)
        return decoded, encoded

    def encode(self, x):
        """Map inputs to non-negative feature activations: relu(Wx + b)."""
        return F.relu(self.encoder(x) + self.enc_bias)

    def decode(self, x):
        """Reconstruct inputs from feature activations: Dx + b."""
        return self.decoder(x) + self.dec_bias

    def loss(self, x, decoded, encoded):
        """Return reconstruction MSE plus the weighted sparsity penalty.

        The penalty is sum_j |f_j| * ||d_j||_2 where d_j is decoder
        column j, summed over the batch and scaled by ``sparsity_alpha``.
        """
        reconstruction_loss = F.mse_loss(decoded, x)
        sparsity_loss = self.sparsity_alpha * torch.sum(
            encoded.abs() * self.decoder.weight.norm(p=2, dim=0)
        )
        return reconstruction_loss + sparsity_loss

    def _initialize_weights(self, decoder_norm_range):
        """Initialize decoder columns to random directions with L2 norms
        drawn uniformly from ``decoder_norm_range``, then tie the encoder
        to the decoder transpose.

        Fixes relative to the previous version:
        - the decoder must be initialized BEFORE copying into the encoder
          (previously the encoder was tied to the default Linear init);
        - normalization runs over dim=0 (per dictionary column), matching
          the per-column norms used in ``loss`` and the per-column target
          norms applied below (previously dim=1 normalized rows).
        """
        norm_min, norm_max = decoder_norm_range
        with torch.no_grad():
            self.decoder.weight.normal_(0, 1)
            # Unit-normalize each dictionary element (column).
            self.decoder.weight /= self.decoder.weight.norm(
                p=2, dim=0, keepdim=True
            )
            # Rescale each column to a random norm in [norm_min, norm_max];
            # (1, hidden_dim) broadcasts over the (input_dim, hidden_dim) weight.
            target_norms = norm_min + (norm_max - norm_min) * torch.rand(
                1, self.hidden_dim
            )
            self.decoder.weight *= target_norms
            # Tie encoder weights to the now-initialized decoder transpose.
            self.encoder.weight.copy_(self.decoder.weight.t())
|
|