Nu Appleblossom committed
Commit b988e93
1 Parent(s): acdf4b5

app.py refactored

Files changed (1)
  1. app.py +62 -25
app.py CHANGED
@@ -403,6 +403,65 @@ def process_input(selected_sae, feature_number, weight_type, use_token_centroid,
 
 
 def gradio_interface():
+    @spaces.GPU
+    def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
+        global w_enc_dict, w_dec_dict, model, tokenizer, token_embeddings
+
+        try:
+            if selected_sae not in w_enc_dict or selected_sae not in w_dec_dict:
+                w_enc, w_dec = load_sae_weights(selected_sae)
+                if w_enc is None or w_dec is None:
+                    return f"Failed to load SAE weights for {selected_sae}. Please try a different SAE or check your connection.", None, None
+                w_enc_dict[selected_sae] = w_enc
+                w_dec_dict[selected_sae] = w_dec
+            else:
+                w_enc, w_dec = w_enc_dict[selected_sae], w_dec_dict[selected_sae]
+
+            token_centroid = torch.mean(token_embeddings, dim=0)
+            feature_vector = create_feature_vector(w_enc, w_dec, int(feature_number), weight_type, token_centroid, use_token_centroid, scaling_factor)
+
+            if use_pca:
+                pca_direction = perform_pca(token_embeddings)
+                feature_vector = create_ghost_token(feature_vector, token_centroid, pca_direction, scaling_factor, pca_weight)
+
+            if mode == "cosine distance token lists":
+                closest_tokens_with_values = find_closest_tokens(
+                    feature_vector, token_embeddings, tokenizer,
+                    top_k=500, num_exp=num_exp, denom_exp=denom_exp
+                )
+
+                token_list = [token for token, _ in closest_tokens_with_values]
+                result = f"100 tokens whose embeddings produce the smallest ratio:\n\n"
+                result += f"[{', '.join(repr(token) for token in token_list[:100])}]\n\n"
+                result += "Top 500 list:\n"
+                result += "\n".join([f"{token!r}: {value:.4f}" for token, value in closest_tokens_with_values])
+
+                return result, None, None  # Return the result, no image, and no tree data
+            elif mode == "definition tree generation":
+                base_prompt = f'A typical definition of "{tokenizer.decode([config.SUB_TOKEN_ID], skip_special_tokens=True)}" would be "'
+                tree_generator = generate_definition_tree(base_prompt, feature_vector, model, tokenizer, config)
+
+                # Generate the tree
+                tree_text = ""
+                tree_dict = None
+                for item in tree_generator:
+                    if isinstance(item, str):
+                        tree_text += item
+                        yield tree_text, None, None  # Yield the updated text, no image, and no tree data
+                    else:
+                        tree_dict = item
+
+                # Generate the tree visualization
+                max_weight, min_weight = find_max_min_cumulative_weight(tree_dict)
+                tree_image = create_tree_diagram(tree_dict, config, max_weight, min_weight)
+
+                return tree_text, tree_image, tree_dict  # Return the final text, tree image, and tree data
+
+        except Exception as e:
+            logger.error(f"Error in update_output: {str(e)}")
+            logger.error(traceback.format_exc())
+            return f"Error: {str(e)}\nPlease check the logs for more details.", None, None
+
     with gr.Blocks() as demo:
         gr.Markdown("# Gemma-2B SAE Feature Explorer")
 
@@ -434,34 +493,12 @@ def gradio_interface():
         trim_slider = gr.Slider(minimum=0.00001, maximum=0.1, value=0.00001, label="Trim cutoff for cumulative probability")
         trim_btn = gr.Button("Trim Tree")
 
-        def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
-            result = process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode)
-
-            if mode == "definition tree generation":
-                tree_text = ""
-                tree_image = None
-                tree_data = None
-                for i, item in enumerate(result):
-                    progress(i / 100)  # Assuming max 100 iterations, adjust as needed
-                    if isinstance(item, str):
-                        tree_text += item
-                        yield tree_text, tree_image, tree_data
-                    else:
-                        tree_data = item
-                        tree_image = create_tree_diagram(tree_data, config, *find_max_min_cumulative_weight(tree_data))
-                        yield tree_text, tree_image, tree_data
-            else:
-                yield result, None, None
-
-        def trim_tree(trim_cutoff, tree_data):
-            max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
-            trimmed_tree_image = create_tree_diagram(tree_data, config, max_weight, min_weight, trim_cutoff=float(trim_cutoff))
-            return trimmed_tree_image
+        tree_data_state = gr.State()
 
         inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode]
 
-        generate_btn.click(update_output, inputs=inputs, outputs=[output_text, output_image, gr.State(key="tree_data")])
-        trim_btn.click(trim_tree, inputs=[trim_slider, gr.State(key="tree_data")], outputs=[output_image])
+        generate_btn.click(update_output, inputs=inputs, outputs=[output_text, output_image, tree_data_state])
+        trim_btn.click(trim_tree, inputs=[trim_slider, tree_data_state], outputs=[output_image])
 
     return demo
 
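Two notes on the refactor, with minimal sketches. The component and function names in the sketches are illustrative stand-ins, not identifiers from app.py.

First, update_output now lives inside gradio_interface() as a @spaces.GPU-decorated generator: in "definition tree generation" mode it yields partial tree text as it arrives and only then produces the final image and tree data, rather than routing everything through process_input. The streaming part of that pattern, reduced to a single textbox and assuming a recent Gradio that accepts generator functions as event handlers, looks roughly like this:

import time
import gradio as gr

def stream_text(prompt):
    # Generator handler: each yield pushes the partial string to the output textbox.
    text = ""
    for word in ["A", "typical", "definition", "of", prompt, "would", "be", "..."]:
        time.sleep(0.1)  # stand-in for per-token model generation
        text += word + " "
        yield text

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    output_box = gr.Textbox(label="Streamed output")
    gr.Button("Generate").click(stream_text, inputs=prompt_box, outputs=output_box)

if __name__ == "__main__":
    demo.launch()

In the committed handler the yielded tuple has three slots (text, image, tree data) to match the three outputs wired up below it.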
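Second, the tree data now travels between the two buttons through a single shared State component: tree_data_state = gr.State() is created once inside the Blocks and used as the third output of generate_btn.click and the second input of trim_btn.click, replacing the separately constructed gr.State(key="tree_data") instances in the old wiring. A minimal sketch of that pattern, with hypothetical generate/trim handlers:

import gradio as gr

def generate(n):
    # First handler: compute something and stash it in the shared State output.
    data = list(range(int(n)))
    return f"Generated {len(data)} values.", data

def trim(cutoff, data):
    # Second handler: reuse the stashed data instead of recomputing it.
    if not data:
        return "Nothing generated yet."
    kept = [v for v in data if v >= cutoff]
    return f"Kept {len(kept)} values >= {cutoff}."

with gr.Blocks() as demo:
    n_box = gr.Number(value=10, label="How many values")
    cutoff_slider = gr.Slider(0, 10, value=2, label="Trim cutoff")
    status = gr.Textbox(label="Status")
    data_state = gr.State()  # one shared holder, wired into both events below

    gr.Button("Generate").click(generate, inputs=n_box, outputs=[status, data_state])
    gr.Button("Trim").click(trim, inputs=[cutoff_slider, data_state], outputs=status)

if __name__ == "__main__":
    demo.launch()

Sharing one State instance is what lets the trim handler see the tree produced by the generate handler; two independently constructed State components would not share a value.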