Nu Appleblossom committed on
Commit
9eef69d
·
1 Parent(s): 79e1a94

app refactored222

Browse files
Files changed (1) hide show
  1. app.py +53 -33
app.py CHANGED
@@ -114,10 +114,18 @@ def create_feature_vector(w_enc, w_dec, feature_number, weight_type, token_centr
114
  return feature_vector
115
 
116
  def perform_pca(_embeddings):
117
- pca = PCA(n_components=1)
118
- pca.fit(_embeddings.cpu().numpy())
119
- pca_direction = torch.tensor(pca.components_[0], dtype=config.DTYPE, device=config.DEVICE)
120
- return F.normalize(pca_direction, p=2, dim=0)
 
 
 
 
 
 
 
 
121
 
122
  @torch.no_grad()
123
  def create_ghost_token(_feature_vector, _token_centroid, _pca_direction, target_distance, pca_weight):
@@ -174,38 +182,50 @@ def initialize_resources():
174
  def process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp):
175
  global w_enc_dict, w_dec_dict
176
 
177
- if selected_sae not in w_enc_dict or selected_sae not in w_dec_dict:
178
- w_enc, w_dec = load_sae_weights(selected_sae)
179
- if w_enc is None or w_dec is None:
180
- return f"Failed to load SAE weights for {selected_sae}. Please try a different SAE or check your connection."
181
- w_enc_dict[selected_sae] = w_enc
182
- w_dec_dict[selected_sae] = w_dec
183
- else:
184
- w_enc, w_dec = w_enc_dict[selected_sae], w_dec_dict[selected_sae]
185
-
186
- token_centroid = torch.mean(token_embeddings, dim=0)
187
- feature_vector = create_feature_vector(w_enc, w_dec, int(feature_number), weight_type, token_centroid, use_token_centroid, scaling_factor)
188
-
189
- if use_pca:
190
- pca_direction = perform_pca(token_embeddings)
191
- feature_vector = create_ghost_token(feature_vector, token_centroid, pca_direction, scaling_factor, pca_weight)
192
-
193
- closest_tokens_with_values = find_closest_tokens(
194
- feature_vector, token_embeddings, tokenizer,
195
- top_k=500, num_exp=num_exp, denom_exp=denom_exp
196
- )
197
-
198
- token_list = [token for token, _ in closest_tokens_with_values]
199
- result = f"100 tokens whose embeddings produce the smallest ratio:\n\n"
200
- result += f"[{', '.join(repr(token) for token in token_list[:100])}]\n\n"
201
- result += "Top 500 list:\n"
202
- result += "\n".join([f"{token!r}: {value:.4f}" for token, value in closest_tokens_with_values])
203
-
204
- return result
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  def gradio_interface():
207
  with gr.Blocks() as demo:
208
- gr.Markdown("# Gemma-2B SAE Feature ExplorerX")
209
 
210
  with gr.Row():
211
  with gr.Column(scale=2):
 
114
  return feature_vector
115
 
116
def perform_pca(_embeddings):
    """Compute the first principal component of the token embeddings.

    Args:
        _embeddings: 2-D tensor of embeddings (one row per token); may live
            on any device — it is moved to CPU for sklearn. (Assumed 2-D so
            ``PCA.fit`` accepts it — TODO confirm at the call site.)

    Returns:
        The L2-normalized first PCA direction as a 1-D tensor with dtype
        ``config.DTYPE`` on ``config.DEVICE``.

    Raises:
        RuntimeError: if the PCA computation fails; the original exception
            is chained as the cause and the traceback is logged.
    """
    try:
        # sklearn operates on CPU numpy arrays; only the top component is needed.
        pca = PCA(n_components=1)
        embeddings_cpu = _embeddings.cpu().numpy()
        pca.fit(embeddings_cpu)
        pca_direction = torch.tensor(pca.components_[0], dtype=config.DTYPE, device=config.DEVICE)
        # Normalize so callers can treat the component as a pure direction.
        return F.normalize(pca_direction, p=2, dim=0)
    except Exception as e:
        logger.error(f"Error in perform_pca: {str(e)}")
        logger.error(traceback.format_exc())
        # Chain the original exception ("from e") so the root cause survives
        # instead of being replaced by a bare RuntimeError.
        raise RuntimeError(f"PCA calculation failed: {str(e)}") from e
127
+
128
+
129
 
130
  @torch.no_grad()
131
  def create_ghost_token(_feature_vector, _token_centroid, _pca_direction, target_distance, pca_weight):
 
182
def process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp):
    """Build a feature vector for one SAE feature and report its closest tokens.

    Loads (and caches in the module-level ``w_enc_dict``/``w_dec_dict``) the
    encoder/decoder weights for ``selected_sae``, constructs the feature
    vector, optionally blends in the first PCA direction of the embeddings,
    and formats the 500 closest tokens.

    Args:
        selected_sae: key identifying which SAE's weights to use.
        feature_number: index of the feature; coerced with ``int()``.
        weight_type: passed through to ``create_feature_vector``.
        use_token_centroid: passed through to ``create_feature_vector``.
        scaling_factor: used both for the feature vector and, when PCA is
            enabled, as the target distance for ``create_ghost_token``.
        use_pca: when truthy, mixes the PCA direction into the vector.
        pca_weight: blend weight for the PCA direction.
        num_exp, denom_exp: exponents forwarded to ``find_closest_tokens``.

    Returns:
        A human-readable result string on success, or an error-message
        string on failure — the UI displays either directly, so this
        function never raises to its caller.
    """
    global w_enc_dict, w_dec_dict

    try:
        # Cache weights per SAE so repeated queries skip the (slow) load.
        if selected_sae not in w_enc_dict or selected_sae not in w_dec_dict:
            w_enc, w_dec = load_sae_weights(selected_sae)
            if w_enc is None or w_dec is None:
                return f"Failed to load SAE weights for {selected_sae}. Please try a different SAE or check your connection."
            w_enc_dict[selected_sae] = w_enc
            w_dec_dict[selected_sae] = w_dec
        else:
            w_enc, w_dec = w_enc_dict[selected_sae], w_dec_dict[selected_sae]

        token_centroid = torch.mean(token_embeddings, dim=0)
        feature_vector = create_feature_vector(w_enc, w_dec, int(feature_number), weight_type, token_centroid, use_token_centroid, scaling_factor)

        if use_pca:
            logger.info("Performing PCA...")
            try:
                pca_direction = perform_pca(token_embeddings)
                feature_vector = create_ghost_token(feature_vector, token_centroid, pca_direction, scaling_factor, pca_weight)
                logger.info("PCA completed successfully.")
            except Exception as pca_error:
                # A PCA failure is reported to the UI rather than aborting.
                logger.error(f"Error during PCA: {str(pca_error)}")
                return f"Error during PCA: {str(pca_error)}"

        closest_tokens_with_values = find_closest_tokens(
            feature_vector, token_embeddings, tokenizer,
            top_k=500, num_exp=num_exp, denom_exp=denom_exp
        )

        token_list = [token for token, _ in closest_tokens_with_values]
        # Plain literal (no placeholders) — the original had a spurious f-prefix.
        result = "100 tokens whose embeddings produce the smallest ratio:\n\n"
        result += f"[{', '.join(repr(token) for token in token_list[:100])}]\n\n"
        result += "Top 500 list:\n"
        result += "\n".join([f"{token!r}: {value:.4f}" for token, value in closest_tokens_with_values])

        return result
    except Exception as e:
        # Top-level boundary: log with traceback and surface a message string.
        logger.error(f"Error in process_input: {str(e)}")
        logger.error(traceback.format_exc())
        return f"Error: {str(e)}"
224
+
225
 
226
  def gradio_interface():
227
  with gr.Blocks() as demo:
228
+ gr.Markdown("# Gemma-2B SAE Feature Explorer")
229
 
230
  with gr.Row():
231
  with gr.Column(scale=2):