Nu Appleblossom commited on
Commit
cc8d300
·
1 Parent(s): 32e7790

back to last promising version with treebuild crashlog trying to move text to interface with 4o help AGAIN2

Browse files
Files changed (1) hide show
  1. app.py +24 -17
app.py CHANGED
@@ -187,6 +187,11 @@ def update_token_embedding(model, token_id, new_embedding):
187
  new_embedding = new_embedding.to(model.get_input_embeddings().weight.device)
188
  model.get_input_embeddings().weight.data[token_id] = new_embedding
189
 
 
 
 
 
 
190
  def produce_next_token_ids(input_ids, model, topk, sub_token_id):
191
  input_ids = input_ids.to(model.device)
192
  with torch.no_grad():
@@ -445,23 +450,30 @@ def trim_tree(trim_cutoff, tree_data):
445
  def gradio_interface():
446
  def update_visibility(mode):
447
  if mode == "definition tree generation":
448
- return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
 
449
  else:
450
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
 
 
 
 
 
 
 
 
451
 
452
  with gr.Blocks() as demo:
453
  gr.Markdown("# Gemma-2B SAE Feature Explorer")
454
 
455
  with gr.Row():
456
  with gr.Column(scale=2):
457
- selected_sae = gr.Dropdown(choices=["Gemma-2B layer 0", "Gemma-2B layer 6", "Gemma-2B layer 10", "Gemma-2B layer 12"], label="Select SAE")
 
458
  feature_number = gr.Number(label="Select feature number", minimum=0, maximum=16383, value=0)
459
 
460
- mode = gr.Radio(
461
- choices=["cosine distance token lists", "definition tree generation"],
462
- label="Select mode",
463
- value="cosine distance token lists"
464
- )
465
 
466
  weight_type = gr.Radio(["encoder", "decoder"], label="Select weight type for feature vector construction", value="encoder")
467
  use_token_centroid = gr.Checkbox(label="Use token centroid offset", value=True)
@@ -485,16 +497,12 @@ def gradio_interface():
485
 
486
  tree_data_state = gr.State()
487
 
488
- inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode]
 
489
 
490
- generate_btn.click(
491
- update_output,
492
- inputs=inputs,
493
- outputs=[output_text, output_image], # Now the text output will be displayed in the Textbox
494
- show_progress="full"
495
- )
496
 
497
- # other buttons and changes remain the same...
498
 
499
  return demo
500
 
@@ -502,7 +510,6 @@ def gradio_interface():
502
 
503
 
504
 
505
-
506
  if __name__ == "__main__":
507
  try:
508
  logger.info("Starting application initialization...")
 
187
  new_embedding = new_embedding.to(model.get_input_embeddings().weight.device)
188
  model.get_input_embeddings().weight.data[token_id] = new_embedding
189
 
190
+ @spaces.GPU
191
+ def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, progress=gr.Progress()):
192
+ # Call process_input to generate the output
193
+ return process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=False, progress=progress)
194
+
195
  def produce_next_token_ids(input_ids, model, topk, sub_token_id):
196
  input_ids = input_ids.to(model.device)
197
  with torch.no_grad():
 
450
  def gradio_interface():
451
  def update_visibility(mode):
452
  if mode == "definition tree generation":
453
+ return (gr.update(visible=True), gr.update(visible=True), gr.update(visible=True),
454
+ gr.update(visible=False), gr.update(visible=False))
455
  else:
456
+ return (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
457
+ gr.update(visible=True), gr.update(visible=True))
458
+
459
+ @spaces.GPU
460
+ def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor,
461
+ use_pca, pca_weight, num_exp, denom_exp, mode, progress=gr.Progress()):
462
+ # Call process_input to generate the output
463
+ return process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor,
464
+ use_pca, pca_weight, num_exp, denom_exp, mode, top_500=False, progress=progress)
465
 
466
  with gr.Blocks() as demo:
467
  gr.Markdown("# Gemma-2B SAE Feature Explorer")
468
 
469
  with gr.Row():
470
  with gr.Column(scale=2):
471
+ selected_sae = gr.Dropdown(choices=["Gemma-2B layer 0", "Gemma-2B layer 6", "Gemma-2B layer 10", "Gemma-2B layer 12"],
472
+ label="Select SAE")
473
  feature_number = gr.Number(label="Select feature number", minimum=0, maximum=16383, value=0)
474
 
475
+ mode = gr.Radio(choices=["cosine distance token lists", "definition tree generation"],
476
+ label="Select mode", value="cosine distance token lists")
 
 
 
477
 
478
  weight_type = gr.Radio(["encoder", "decoder"], label="Select weight type for feature vector construction", value="encoder")
479
  use_token_centroid = gr.Checkbox(label="Use token centroid offset", value=True)
 
497
 
498
  tree_data_state = gr.State()
499
 
500
+ inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor,
501
+ use_pca, pca_weight, num_exp, denom_exp, mode]
502
 
503
+ generate_btn.click(update_output, inputs=inputs, outputs=[output_text, output_image], show_progress="full")
 
 
 
 
 
504
 
505
+ # Add other button functionality as needed...
506
 
507
  return demo
508
 
 
510
 
511
 
512
 
 
513
  if __name__ == "__main__":
514
  try:
515
  logger.info("Starting application initialization...")