Nu Appleblossom commited on
Commit
6191828
1 Parent(s): 93d356c

app.py trying to fix lists functionality

Browse files
Files changed (1) hide show
  1. app.py +49 -79
app.py CHANGED
@@ -341,12 +341,13 @@ def initialize_resources():
341
 
342
  logger.info("Resources initialized successfully.")
343
 
 
344
  @spaces.GPU
345
- def process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
346
  global w_enc_dict, w_dec_dict, model, tokenizer, token_embeddings
347
 
348
  try:
349
- logger.info("Processing input: SAE={}, feature_number={}, mode={}".format(selected_sae, feature_number, mode))
350
 
351
  # Load the SAE weights if they are not already loaded
352
  if selected_sae not in w_enc_dict or selected_sae not in w_dec_dict:
@@ -373,16 +374,20 @@ def process_input(selected_sae, feature_number, weight_type, use_token_centroid,
373
  top_k=500, num_exp=num_exp, denom_exp=denom_exp
374
  )
375
 
376
- token_list = [token for token, _ in closest_tokens_with_values]
377
- result = f"100 tokens whose embeddings produce the smallest ratio:\n\n"
378
- result += f"[{', '.join(repr(token) for token in token_list[:100])}]\n\n"
379
- result += "Top 500 list:\n"
380
- result += "\n".join([f"{token!r}: {value:.4f}" for token, value in closest_tokens_with_values])
381
-
382
- logger.info("Returning result for cosine distance token list")
383
- return result, None # Return the result and no image
 
 
 
 
 
384
 
385
- # In case the mode is not "cosine distance token lists", return a default message
386
  return "Mode not recognized or not implemented in this step.", None
387
 
388
  except Exception as e:
@@ -391,6 +396,8 @@ def process_input(selected_sae, feature_number, weight_type, use_token_centroid,
391
 
392
 
393
 
 
 
394
  def trim_tree(trim_cutoff, tree_data):
395
  max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
396
  trimmed_tree_image = create_tree_diagram(tree_data, config, max_weight, min_weight, trim_cutoff=float(trim_cutoff))
@@ -398,80 +405,28 @@ def trim_tree(trim_cutoff, tree_data):
398
 
399
 
400
 
 
401
  def gradio_interface():
402
  def update_visibility(mode):
403
  if mode == "definition tree generation":
404
  return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
405
  else:
406
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
407
 
408
  def update_neuronpedia(selected_sae, feature_number):
409
  layer_number = int(selected_sae.split()[-1])
410
  url = get_neuronpedia_url(layer_number, feature_number)
411
  return f'<iframe src="{url}" width="100%" height="300px"></iframe>'
412
 
413
-
414
  @spaces.GPU
415
  def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
416
- global w_enc_dict, w_dec_dict, model, tokenizer, token_embeddings
417
-
418
- try:
419
- if selected_sae not in w_enc_dict or selected_sae not in w_dec_dict:
420
- w_enc, w_dec = load_sae_weights(selected_sae)
421
- if w_enc is None or w_dec is None:
422
- return f"Failed to load SAE weights for {selected_sae}. Please try a different SAE or check your connection.", None, None
423
- w_enc_dict[selected_sae] = w_enc
424
- w_dec_dict[selected_sae] = w_dec
425
- else:
426
- w_enc, w_dec = w_enc_dict[selected_sae], w_dec_dict[selected_sae]
427
-
428
- token_centroid = torch.mean(token_embeddings, dim=0)
429
- feature_vector = create_feature_vector(w_enc, w_dec, int(feature_number), weight_type, token_centroid, use_token_centroid, scaling_factor)
430
-
431
- if use_pca:
432
- pca_direction = perform_pca(token_embeddings)
433
- feature_vector = create_ghost_token(feature_vector, token_centroid, pca_direction, scaling_factor, pca_weight)
434
-
435
- if mode == "cosine distance token lists":
436
- closest_tokens_with_values = find_closest_tokens(
437
- feature_vector, token_embeddings, tokenizer,
438
- top_k=500, num_exp=num_exp, denom_exp=denom_exp
439
- )
440
-
441
- token_list = [token for token, _ in closest_tokens_with_values]
442
- result = f"100 tokens whose embeddings produce the smallest ratio:\n\n"
443
- result += f"[{', '.join(repr(token) for token in token_list[:100])}]\n\n"
444
- result += "Top 500 list:\n"
445
- result += "\n".join([f"{token!r}: {value:.4f}" for token, value in closest_tokens_with_values])
446
-
447
- return result, None, None # Return the result, no image, and no tree data
448
- elif mode == "definition tree generation":
449
- base_prompt = f'A typical definition of "{tokenizer.decode([config.SUB_TOKEN_ID], skip_special_tokens=True)}" would be "'
450
- tree_generator = generate_definition_tree(base_prompt, feature_vector, model, tokenizer, config)
451
-
452
- # Generate the tree
453
- tree_text = ""
454
- tree_dict = None
455
- for item in tree_generator:
456
- if isinstance(item, str):
457
- tree_text += item
458
- yield tree_text, None, None # Yield the updated text, no image, and no tree data
459
- else:
460
- tree_dict = item
461
-
462
- # Generate the tree visualization
463
- max_weight, min_weight = find_max_min_cumulative_weight(tree_dict)
464
- tree_image = create_tree_diagram(tree_dict, config, max_weight, min_weight)
465
-
466
- return tree_text, tree_image, tree_dict # Return the final text, tree image, and tree data
467
-
468
- except Exception as e:
469
- logger.error(f"Error in update_output: {str(e)}")
470
- logger.error(traceback.format_exc())
471
- return f"Error: {str(e)}\nPlease check the logs for more details.", None, None
472
-
473
-
474
 
 
 
 
 
475
 
476
  def trim_tree(trim_cutoff, tree_data):
477
  if tree_data is None:
@@ -481,8 +436,8 @@ def gradio_interface():
481
  return trimmed_tree_image
482
 
483
  with gr.Blocks() as demo:
484
- gr.Markdown("# Gemma-2B SAE Feature Explorer (gradual3)")
485
-
486
  with gr.Row():
487
  with gr.Column(scale=2):
488
  selected_sae = gr.Dropdown(choices=["Gemma-2B layer 0", "Gemma-2B layer 6", "Gemma-2B layer 10", "Gemma-2B layer 12"], label="Select SAE")
@@ -508,6 +463,9 @@ def gradio_interface():
508
  output_text = gr.Textbox(label="Output", lines=20)
509
  output_image = gr.Image(label="Tree Diagram", visible=False)
510
 
 
 
 
511
  trim_slider = gr.Slider(minimum=0.00001, maximum=0.1, value=0.00001, label="Trim cutoff for cumulative probability", visible=False)
512
  trim_btn = gr.Button("Trim Tree", visible=False)
513
 
@@ -515,28 +473,40 @@ def gradio_interface():
515
  neuronpedia_html = gr.HTML(label="Neuronpedia")
516
 
517
  inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode]
518
-
519
 
520
  generate_btn.click(
521
- process_input,
522
  inputs=inputs,
523
- outputs=[output_text, output_image], # Ensure these match your components
524
  show_progress="full"
525
  )
526
 
527
-
528
- trim_btn.click(trim_tree, inputs=[trim_slider, tree_data_state], outputs=[output_image])
 
 
 
 
529
 
 
 
530
  mode.change(update_visibility, inputs=[mode], outputs=[output_image, trim_slider, trim_btn])
531
 
532
  selected_sae.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])
533
  feature_number.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])
534
 
535
-
 
 
 
 
536
 
537
  return demo
538
 
539
 
 
 
 
540
  if __name__ == "__main__":
541
  try:
542
  logger.info("Starting application initialization...")
 
341
 
342
  logger.info("Resources initialized successfully.")
343
 
344
+
345
  @spaces.GPU
346
+ def process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=False):
347
  global w_enc_dict, w_dec_dict, model, tokenizer, token_embeddings
348
 
349
  try:
350
+ logger.info(f"Processing input: SAE={selected_sae}, feature_number={feature_number}, mode={mode}")
351
 
352
  # Load the SAE weights if they are not already loaded
353
  if selected_sae not in w_enc_dict or selected_sae not in w_dec_dict:
 
374
  top_k=500, num_exp=num_exp, denom_exp=denom_exp
375
  )
376
 
377
+ if top_500:
378
+ # Generate the top 500 list
379
+ result = "Top 500 list:\n"
380
+ result += "\n".join([f"{token!r}: {value:.4f}" for token, value in closest_tokens_with_values])
381
+ logger.info("Returning top 500 list")
382
+ return result, None
383
+ else:
384
+ # Generate the top 100 list
385
+ token_list = [token for token, _ in closest_tokens_with_values[:100]]
386
+ result = f"100 tokens whose embeddings produce the smallest ratio:\n\n"
387
+ result += f"[{', '.join(repr(token) for token in token_list)}]\n"
388
+ logger.info("Returning top 100 tokens")
389
+ return result, None
390
 
 
391
  return "Mode not recognized or not implemented in this step.", None
392
 
393
  except Exception as e:
 
396
 
397
 
398
 
399
+
400
+
401
  def trim_tree(trim_cutoff, tree_data):
402
  max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
403
  trimmed_tree_image = create_tree_diagram(tree_data, config, max_weight, min_weight, trim_cutoff=float(trim_cutoff))
 
405
 
406
 
407
 
408
+
409
  def gradio_interface():
410
  def update_visibility(mode):
411
  if mode == "definition tree generation":
412
  return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
413
  else:
414
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible(False))
415
 
416
  def update_neuronpedia(selected_sae, feature_number):
417
  layer_number = int(selected_sae.split()[-1])
418
  url = get_neuronpedia_url(layer_number, feature_number)
419
  return f'<iframe src="{url}" width="100%" height="300px"></iframe>'
420
 
 
421
  @spaces.GPU
422
  def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
423
+ # Call process_input without generating the top 500 list initially
424
+ return process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
 
426
+ @spaces.GPU
427
+ def generate_top_500(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
428
+ # Call process_input with top_500=True to generate the full list
429
+ return process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=True)
430
 
431
  def trim_tree(trim_cutoff, tree_data):
432
  if tree_data is None:
 
436
  return trimmed_tree_image
437
 
438
  with gr.Blocks() as demo:
439
+ gr.Markdown("# Gemma-2B SAE Feature Explorer (almost there?)")
440
+
441
  with gr.Row():
442
  with gr.Column(scale=2):
443
  selected_sae = gr.Dropdown(choices=["Gemma-2B layer 0", "Gemma-2B layer 6", "Gemma-2B layer 10", "Gemma-2B layer 12"], label="Select SAE")
 
463
  output_text = gr.Textbox(label="Output", lines=20)
464
  output_image = gr.Image(label="Tree Diagram", visible=False)
465
 
466
+ generate_top_500_btn = gr.Button("Generate Top 500 Tokens and Power Ratios", visible=False)
467
+ output_500_text = gr.Textbox(label="Top 500 Output", lines=20, visible=False)
468
+
469
  trim_slider = gr.Slider(minimum=0.00001, maximum=0.1, value=0.00001, label="Trim cutoff for cumulative probability", visible=False)
470
  trim_btn = gr.Button("Trim Tree", visible=False)
471
 
 
473
  neuronpedia_html = gr.HTML(label="Neuronpedia")
474
 
475
  inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode]
 
476
 
477
  generate_btn.click(
478
+ update_output,
479
  inputs=inputs,
480
+ outputs=[output_text, output_image],
481
  show_progress="full"
482
  )
483
 
484
+ generate_top_500_btn.click(
485
+ generate_top_500,
486
+ inputs=inputs,
487
+ outputs=[output_500_text],
488
+ show_progress="full"
489
+ )
490
 
491
+ trim_btn.click(trim_tree, inputs=[trim_slider, tree_data_state], outputs=[output_image])
492
+
493
  mode.change(update_visibility, inputs=[mode], outputs=[output_image, trim_slider, trim_btn])
494
 
495
  selected_sae.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])
496
  feature_number.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])
497
 
498
+ output_text.change(
499
+ lambda text: (gr.update(visible=True), gr.update(visible=True)) if "100 tokens" in text else (gr.update(visible(False)), gr.update(visible(False))),
500
+ inputs=[output_text],
501
+ outputs=[generate_top_500_btn, output_500_text]
502
+ )
503
 
504
  return demo
505
 
506
 
507
+
508
+
509
+
510
  if __name__ == "__main__":
511
  try:
512
  logger.info("Starting application initialization...")