gytdau commited on
Commit
9a3a9be
1 Parent(s): 7cb9d40

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +59 -0
model.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SparseAutoencoder(nn.Module):
7
+ def __init__(
8
+ self,
9
+ input_dim,
10
+ hidden_dim,
11
+ sparsity_alpha=0.00004,
12
+ decoder_norm_range=(0.05, 1.0),
13
+ ):
14
+ super(SparseAutoencoder, self).__init__()
15
+ self.input_dim = input_dim
16
+ self.hidden_dim = hidden_dim
17
+ self.sparsity_alpha = sparsity_alpha
18
+
19
+ self.enc_bias = nn.Parameter(torch.zeros(hidden_dim))
20
+ self.encoder = nn.Linear(input_dim, hidden_dim, bias=False)
21
+
22
+ self.dec_bias = nn.Parameter(torch.zeros(input_dim))
23
+ self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)
24
+
25
+ self._initialize_weights(decoder_norm_range)
26
+
27
+ def forward(self, x):
28
+ encoded = self.encode(x)
29
+ decoded = self.decode(encoded)
30
+ return decoded, encoded
31
+
32
+ def encode(self, x):
33
+ return F.relu(self.encoder(x) + self.enc_bias)
34
+
35
+ def decode(self, x):
36
+ return self.decoder(x) + self.dec_bias
37
+
38
+ def loss(self, x, decoded, encoded):
39
+ reconstruction_loss = F.mse_loss(decoded, x)
40
+ sparsity_loss = self.sparsity_alpha * torch.sum(
41
+ encoded.abs() * self.decoder.weight.norm(p=2, dim=0)
42
+ )
43
+ total_loss = reconstruction_loss + sparsity_loss
44
+ return total_loss
45
+
46
+ def _initialize_weights(self, decoder_norm_range):
47
+ # Initialize encoder weights to the transpose of decoder weights
48
+ self.encoder.weight.data = self.decoder.weight.data.t()
49
+
50
+ # Initialize decoder weights with random directions and fixed L2 norm
51
+ norm_min, norm_max = decoder_norm_range
52
+ norm_range = norm_max - norm_min
53
+ self.decoder.weight.data.normal_(0, 1)
54
+ self.decoder.weight.data /= self.decoder.weight.data.norm(
55
+ p=2, dim=1, keepdim=True
56
+ )
57
+ self.decoder.weight.data *= (
58
+ norm_min + norm_range * torch.rand(1, self.hidden_dim)
59
+ ).expand_as(self.decoder.weight.data)