Spaces:
Sleeping
Sleeping
Nu Appleblossom
committed on
Commit
·
9eef69d
1
Parent(s):
79e1a94
app refactored222
Browse files
app.py
CHANGED
@@ -114,10 +114,18 @@ def create_feature_vector(w_enc, w_dec, feature_number, weight_type, token_centr
|
|
114 |
return feature_vector
|
115 |
|
116 |
def perform_pca(_embeddings):
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
@torch.no_grad()
|
123 |
def create_ghost_token(_feature_vector, _token_centroid, _pca_direction, target_distance, pca_weight):
|
@@ -174,38 +182,50 @@ def initialize_resources():
|
|
174 |
def process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp):
|
175 |
global w_enc_dict, w_dec_dict
|
176 |
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
|
206 |
def gradio_interface():
|
207 |
with gr.Blocks() as demo:
|
208 |
-
gr.Markdown("# Gemma-2B SAE Feature
|
209 |
|
210 |
with gr.Row():
|
211 |
with gr.Column(scale=2):
|
|
|
114 |
return feature_vector
|
115 |
|
116 |
def perform_pca(_embeddings):
    """Fit a one-component PCA on the token embeddings and return the
    principal direction as a unit-normalized tensor.

    Args:
        _embeddings: tensor of token embeddings; moved to CPU for sklearn.
            (Assumed shape is (num_tokens, dim) — TODO confirm at call site.)

    Returns:
        torch.Tensor: the first principal component, L2-normalized, with
        dtype ``config.DTYPE`` on ``config.DEVICE``.

    Raises:
        RuntimeError: if PCA fitting or conversion fails; the underlying
            exception is preserved as ``__cause__``.
    """
    try:
        pca = PCA(n_components=1)
        # sklearn works on CPU numpy arrays, so leave the device first.
        embeddings_cpu = _embeddings.cpu().numpy()
        pca.fit(embeddings_cpu)
        pca_direction = torch.tensor(pca.components_[0], dtype=config.DTYPE, device=config.DEVICE)
        return F.normalize(pca_direction, p=2, dim=0)
    except Exception as e:
        # logger.exception logs the message plus the full traceback in one
        # call (replaces the former logger.error + traceback.format_exc pair).
        logger.exception(f"Error in perform_pca: {str(e)}")
        # Chain the original exception so callers inspecting the
        # RuntimeError can still see the root cause.
        raise RuntimeError(f"PCA calculation failed: {str(e)}") from e
|
127 |
+
|
128 |
+
|
129 |
|
130 |
@torch.no_grad()
|
131 |
def create_ghost_token(_feature_vector, _token_centroid, _pca_direction, target_distance, pca_weight):
|
|
|
182 |
def process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp):
    """Build a feature vector for the selected SAE feature and report the
    closest tokens.

    Returns a human-readable result string on success, or an error-message
    string on failure (the UI displays whichever string comes back).
    """
    global w_enc_dict, w_dec_dict

    try:
        # Serve encoder/decoder weights from the module-level caches when
        # both entries are present; otherwise load and populate them.
        cached = selected_sae in w_enc_dict and selected_sae in w_dec_dict
        if cached:
            w_enc = w_enc_dict[selected_sae]
            w_dec = w_dec_dict[selected_sae]
        else:
            w_enc, w_dec = load_sae_weights(selected_sae)
            if w_enc is None or w_dec is None:
                return f"Failed to load SAE weights for {selected_sae}. Please try a different SAE or check your connection."
            w_enc_dict[selected_sae] = w_enc
            w_dec_dict[selected_sae] = w_dec

        token_centroid = torch.mean(token_embeddings, dim=0)
        feature_vector = create_feature_vector(w_enc, w_dec, int(feature_number), weight_type, token_centroid, use_token_centroid, scaling_factor)

        if use_pca:
            logger.info("Performing PCA...")
            try:
                pca_direction = perform_pca(token_embeddings)
                feature_vector = create_ghost_token(feature_vector, token_centroid, pca_direction, scaling_factor, pca_weight)
                logger.info("PCA completed successfully.")
            except Exception as pca_error:
                # PCA failure is reported to the user rather than crashing.
                logger.error(f"Error during PCA: {str(pca_error)}")
                return f"Error during PCA: {str(pca_error)}"

        closest_tokens_with_values = find_closest_tokens(
            feature_vector, token_embeddings, tokenizer,
            top_k=500, num_exp=num_exp, denom_exp=denom_exp
        )

        # Assemble the report: a bracketed preview of the first 100 tokens,
        # followed by the full top-500 token/value listing.
        tokens_only = [tok for tok, _ in closest_tokens_with_values]
        preview = ", ".join(repr(tok) for tok in tokens_only[:100])
        listing = "\n".join(f"{tok!r}: {val:.4f}" for tok, val in closest_tokens_with_values)
        return (
            "100 tokens whose embeddings produce the smallest ratio:\n\n"
            f"[{preview}]\n\n"
            "Top 500 list:\n"
            + listing
        )
    except Exception as e:
        # Top-level boundary: log with traceback, surface the message.
        logger.error(f"Error in process_input: {str(e)}")
        logger.error(traceback.format_exc())
        return f"Error: {str(e)}"
|
224 |
+
|
225 |
|
226 |
def gradio_interface():
|
227 |
with gr.Blocks() as demo:
|
228 |
+
gr.Markdown("# Gemma-2B SAE Feature Explorer")
|
229 |
|
230 |
with gr.Row():
|
231 |
with gr.Column(scale=2):
|