Spaces: Running on Zero
Update SAE/sae.py
Browse files - SAE/sae.py +5 -9
SAE/sae.py
CHANGED
@@ -41,13 +41,6 @@ class SparseAutoencoder(nn.Module):
|
|
41 |
self.stats_last_nonzero: torch.Tensor
|
42 |
self.register_buffer("stats_last_nonzero", torch.zeros(n_dirs_local, dtype=torch.long))
|
43 |
|
44 |
-
def auxk_mask_fn(x):
|
45 |
-
dead_mask = self.stats_last_nonzero > dead_steps_threshold
|
46 |
-
x.data *= dead_mask # inplace to save memory
|
47 |
-
return x
|
48 |
-
|
49 |
-
self.auxk_mask_fn = auxk_mask_fn
|
50 |
-
|
51 |
## initialization
|
52 |
|
53 |
# "tied" init
|
@@ -58,6 +51,11 @@ class SparseAutoencoder(nn.Module):
|
|
58 |
|
59 |
unit_norm_decoder_(self)
|
60 |
|
|
|
|
|
|
|
|
|
|
|
61 |
def save_to_disk(self, path: str):
|
62 |
PATH_TO_CFG = 'config.json'
|
63 |
PATH_TO_WEIGHTS = 'state_dict.pth'
|
@@ -122,7 +120,6 @@ class SparseAutoencoder(nn.Module):
|
|
122 |
|
123 |
return latents
|
124 |
|
125 |
-
@spaces.GPU
|
126 |
def forward(self, x):
|
127 |
x = x - self.pre_bias
|
128 |
latents_pre_act = self.encoder(x) + self.latent_bias
|
@@ -182,7 +179,6 @@ class SparseAutoencoder(nn.Module):
|
|
182 |
"auxk_vals": auxk_vals,
|
183 |
}
|
184 |
|
185 |
-
@spaces.GPU
|
186 |
def decode_sparse(self, inds, vals):
|
187 |
rows, cols = inds.shape[0], self.n_dirs
|
188 |
|
|
|
41 |
self.stats_last_nonzero: torch.Tensor
|
42 |
self.register_buffer("stats_last_nonzero", torch.zeros(n_dirs_local, dtype=torch.long))
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
## initialization
|
45 |
|
46 |
# "tied" init
|
|
|
51 |
|
52 |
unit_norm_decoder_(self)
|
53 |
|
54 |
+
def auxk_mask_fn(self, x):
|
55 |
+
dead_mask = self.stats_last_nonzero > dead_steps_threshold
|
56 |
+
x.data *= dead_mask # inplace to save memory
|
57 |
+
return x
|
58 |
+
|
59 |
def save_to_disk(self, path: str):
|
60 |
PATH_TO_CFG = 'config.json'
|
61 |
PATH_TO_WEIGHTS = 'state_dict.pth'
|
|
|
120 |
|
121 |
return latents
|
122 |
|
|
|
123 |
def forward(self, x):
|
124 |
x = x - self.pre_bias
|
125 |
latents_pre_act = self.encoder(x) + self.latent_bias
|
|
|
179 |
"auxk_vals": auxk_vals,
|
180 |
}
|
181 |
|
|
|
182 |
def decode_sparse(self, inds, vals):
|
183 |
rows, cols = inds.shape[0], self.n_dirs
|
184 |
|