Nu Appleblossom commited on
Commit
6046d47
·
1 Parent(s): 10fc5c3

app.py refactored

Browse files
Files changed (1) hide show
  1. app.py +17 -3
app.py CHANGED
@@ -420,6 +420,7 @@ def gradio_interface():
420
  url = get_neuronpedia_url(layer_number, feature_number)
421
  return f'<iframe src="{url}" width="100%" height="300px"></iframe>'
422
 
 
423
  @spaces.GPU
424
  def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
425
  global w_enc_dict, w_dec_dict, model, tokenizer, token_embeddings
@@ -479,6 +480,9 @@ def gradio_interface():
479
  logger.error(traceback.format_exc())
480
  return f"Error: {str(e)}\nPlease check the logs for more details.", None, None
481
 
 
 
 
482
  def trim_tree(trim_cutoff, tree_data):
483
  if tree_data is None:
484
  return None
@@ -522,14 +526,24 @@ def gradio_interface():
522
 
523
  inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode]
524
 
525
- generate_btn.click(update_output, inputs=inputs, outputs=[output_text, output_image, tree_data_state])
 
 
 
 
 
 
 
 
526
  trim_btn.click(trim_tree, inputs=[trim_slider, tree_data_state], outputs=[output_image])
527
 
528
  mode.change(update_visibility, inputs=[mode], outputs=[output_image, trim_slider, trim_btn])
529
-
530
  selected_sae.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])
531
  feature_number.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])
532
-
 
 
533
  return demo
534
 
535
 
 
420
  url = get_neuronpedia_url(layer_number, feature_number)
421
  return f'<iframe src="{url}" width="100%" height="300px"></iframe>'
422
 
423
+
424
  @spaces.GPU
425
  def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
426
  global w_enc_dict, w_dec_dict, model, tokenizer, token_embeddings
 
480
  logger.error(traceback.format_exc())
481
  return f"Error: {str(e)}\nPlease check the logs for more details.", None, None
482
 
483
+
484
+
485
+
486
  def trim_tree(trim_cutoff, tree_data):
487
  if tree_data is None:
488
  return None
 
526
 
527
  inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode]
528
 
529
+
530
+ generate_btn.click(
531
+ update_output,
532
+ inputs=inputs,
533
+ outputs=[output_text, output_image, tree_data_state],
534
+ show_progress="full"
535
+ )
536
+
537
+
538
  trim_btn.click(trim_tree, inputs=[trim_slider, tree_data_state], outputs=[output_image])
539
 
540
  mode.change(update_visibility, inputs=[mode], outputs=[output_image, trim_slider, trim_btn])
541
+
542
  selected_sae.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])
543
  feature_number.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])
544
+
545
+
546
+
547
  return demo
548
 
549