surokpro2 committed on
Commit
46611c9
1 Parent(s): 32122c4

Update SAE/sae.py

Browse files
Files changed (1) hide show
  1. SAE/sae.py +5 -9
SAE/sae.py CHANGED
@@ -41,13 +41,6 @@ class SparseAutoencoder(nn.Module):
41
  self.stats_last_nonzero: torch.Tensor
42
  self.register_buffer("stats_last_nonzero", torch.zeros(n_dirs_local, dtype=torch.long))
43
 
44
- def auxk_mask_fn(x):
45
- dead_mask = self.stats_last_nonzero > dead_steps_threshold
46
- x.data *= dead_mask # inplace to save memory
47
- return x
48
-
49
- self.auxk_mask_fn = auxk_mask_fn
50
-
51
  ## initialization
52
 
53
  # "tied" init
@@ -58,6 +51,11 @@ class SparseAutoencoder(nn.Module):
58
 
59
  unit_norm_decoder_(self)
60
 
 
 
 
 
 
61
  def save_to_disk(self, path: str):
62
  PATH_TO_CFG = 'config.json'
63
  PATH_TO_WEIGHTS = 'state_dict.pth'
@@ -122,7 +120,6 @@ class SparseAutoencoder(nn.Module):
122
 
123
  return latents
124
 
125
- @spaces.GPU
126
  def forward(self, x):
127
  x = x - self.pre_bias
128
  latents_pre_act = self.encoder(x) + self.latent_bias
@@ -182,7 +179,6 @@ class SparseAutoencoder(nn.Module):
182
  "auxk_vals": auxk_vals,
183
  }
184
 
185
- @spaces.GPU
186
  def decode_sparse(self, inds, vals):
187
  rows, cols = inds.shape[0], self.n_dirs
188
 
 
41
  self.stats_last_nonzero: torch.Tensor
42
  self.register_buffer("stats_last_nonzero", torch.zeros(n_dirs_local, dtype=torch.long))
43
 
 
 
 
 
 
 
 
44
  ## initialization
45
 
46
  # "tied" init
 
51
 
52
  unit_norm_decoder_(self)
53
 
54
+ def auxk_mask_fn(self, x):
55
+ dead_mask = self.stats_last_nonzero > dead_steps_threshold
56
+ x.data *= dead_mask # inplace to save memory
57
+ return x
58
+
59
  def save_to_disk(self, path: str):
60
  PATH_TO_CFG = 'config.json'
61
  PATH_TO_WEIGHTS = 'state_dict.pth'
 
120
 
121
  return latents
122
 
 
123
  def forward(self, x):
124
  x = x - self.pre_bias
125
  latents_pre_act = self.encoder(x) + self.latent_bias
 
179
  "auxk_vals": auxk_vals,
180
  }
181
 
 
182
  def decode_sparse(self, inds, vals):
183
  rows, cols = inds.shape[0], self.n_dirs
184