Spaces:
Sleeping
Sleeping
Nu Appleblossom
committed on
Commit
•
6191828
1
Parent(s):
93d356c
app.py trying to fix lists functionality
Browse files
app.py
CHANGED
@@ -341,12 +341,13 @@ def initialize_resources():
|
|
341 |
|
342 |
logger.info("Resources initialized successfully.")
|
343 |
|
|
|
344 |
@spaces.GPU
|
345 |
-
def process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
|
346 |
global w_enc_dict, w_dec_dict, model, tokenizer, token_embeddings
|
347 |
|
348 |
try:
|
349 |
-
logger.info("Processing input: SAE={}, feature_number={}, mode={}"
|
350 |
|
351 |
# Load the SAE weights if they are not already loaded
|
352 |
if selected_sae not in w_enc_dict or selected_sae not in w_dec_dict:
|
@@ -373,16 +374,20 @@ def process_input(selected_sae, feature_number, weight_type, use_token_centroid,
|
|
373 |
top_k=500, num_exp=num_exp, denom_exp=denom_exp
|
374 |
)
|
375 |
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
|
|
|
|
|
|
|
|
|
|
384 |
|
385 |
-
# In case the mode is not "cosine distance token lists", return a default message
|
386 |
return "Mode not recognized or not implemented in this step.", None
|
387 |
|
388 |
except Exception as e:
|
@@ -391,6 +396,8 @@ def process_input(selected_sae, feature_number, weight_type, use_token_centroid,
|
|
391 |
|
392 |
|
393 |
|
|
|
|
|
394 |
def trim_tree(trim_cutoff, tree_data):
|
395 |
max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
|
396 |
trimmed_tree_image = create_tree_diagram(tree_data, config, max_weight, min_weight, trim_cutoff=float(trim_cutoff))
|
@@ -398,80 +405,28 @@ def trim_tree(trim_cutoff, tree_data):
|
|
398 |
|
399 |
|
400 |
|
|
|
401 |
def gradio_interface():
|
402 |
def update_visibility(mode):
|
403 |
if mode == "definition tree generation":
|
404 |
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
405 |
else:
|
406 |
-
return gr.update(visible=False), gr.update(visible=False), gr.update(visible
|
407 |
|
408 |
def update_neuronpedia(selected_sae, feature_number):
|
409 |
layer_number = int(selected_sae.split()[-1])
|
410 |
url = get_neuronpedia_url(layer_number, feature_number)
|
411 |
return f'<iframe src="{url}" width="100%" height="300px"></iframe>'
|
412 |
|
413 |
-
|
414 |
@spaces.GPU
|
415 |
def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
|
416 |
-
|
417 |
-
|
418 |
-
try:
|
419 |
-
if selected_sae not in w_enc_dict or selected_sae not in w_dec_dict:
|
420 |
-
w_enc, w_dec = load_sae_weights(selected_sae)
|
421 |
-
if w_enc is None or w_dec is None:
|
422 |
-
return f"Failed to load SAE weights for {selected_sae}. Please try a different SAE or check your connection.", None, None
|
423 |
-
w_enc_dict[selected_sae] = w_enc
|
424 |
-
w_dec_dict[selected_sae] = w_dec
|
425 |
-
else:
|
426 |
-
w_enc, w_dec = w_enc_dict[selected_sae], w_dec_dict[selected_sae]
|
427 |
-
|
428 |
-
token_centroid = torch.mean(token_embeddings, dim=0)
|
429 |
-
feature_vector = create_feature_vector(w_enc, w_dec, int(feature_number), weight_type, token_centroid, use_token_centroid, scaling_factor)
|
430 |
-
|
431 |
-
if use_pca:
|
432 |
-
pca_direction = perform_pca(token_embeddings)
|
433 |
-
feature_vector = create_ghost_token(feature_vector, token_centroid, pca_direction, scaling_factor, pca_weight)
|
434 |
-
|
435 |
-
if mode == "cosine distance token lists":
|
436 |
-
closest_tokens_with_values = find_closest_tokens(
|
437 |
-
feature_vector, token_embeddings, tokenizer,
|
438 |
-
top_k=500, num_exp=num_exp, denom_exp=denom_exp
|
439 |
-
)
|
440 |
-
|
441 |
-
token_list = [token for token, _ in closest_tokens_with_values]
|
442 |
-
result = f"100 tokens whose embeddings produce the smallest ratio:\n\n"
|
443 |
-
result += f"[{', '.join(repr(token) for token in token_list[:100])}]\n\n"
|
444 |
-
result += "Top 500 list:\n"
|
445 |
-
result += "\n".join([f"{token!r}: {value:.4f}" for token, value in closest_tokens_with_values])
|
446 |
-
|
447 |
-
return result, None, None # Return the result, no image, and no tree data
|
448 |
-
elif mode == "definition tree generation":
|
449 |
-
base_prompt = f'A typical definition of "{tokenizer.decode([config.SUB_TOKEN_ID], skip_special_tokens=True)}" would be "'
|
450 |
-
tree_generator = generate_definition_tree(base_prompt, feature_vector, model, tokenizer, config)
|
451 |
-
|
452 |
-
# Generate the tree
|
453 |
-
tree_text = ""
|
454 |
-
tree_dict = None
|
455 |
-
for item in tree_generator:
|
456 |
-
if isinstance(item, str):
|
457 |
-
tree_text += item
|
458 |
-
yield tree_text, None, None # Yield the updated text, no image, and no tree data
|
459 |
-
else:
|
460 |
-
tree_dict = item
|
461 |
-
|
462 |
-
# Generate the tree visualization
|
463 |
-
max_weight, min_weight = find_max_min_cumulative_weight(tree_dict)
|
464 |
-
tree_image = create_tree_diagram(tree_dict, config, max_weight, min_weight)
|
465 |
-
|
466 |
-
return tree_text, tree_image, tree_dict # Return the final text, tree image, and tree data
|
467 |
-
|
468 |
-
except Exception as e:
|
469 |
-
logger.error(f"Error in update_output: {str(e)}")
|
470 |
-
logger.error(traceback.format_exc())
|
471 |
-
return f"Error: {str(e)}\nPlease check the logs for more details.", None, None
|
472 |
-
|
473 |
-
|
474 |
|
|
|
|
|
|
|
|
|
475 |
|
476 |
def trim_tree(trim_cutoff, tree_data):
|
477 |
if tree_data is None:
|
@@ -481,8 +436,8 @@ def gradio_interface():
|
|
481 |
return trimmed_tree_image
|
482 |
|
483 |
with gr.Blocks() as demo:
|
484 |
-
gr.Markdown("# Gemma-2B SAE Feature Explorer (
|
485 |
-
|
486 |
with gr.Row():
|
487 |
with gr.Column(scale=2):
|
488 |
selected_sae = gr.Dropdown(choices=["Gemma-2B layer 0", "Gemma-2B layer 6", "Gemma-2B layer 10", "Gemma-2B layer 12"], label="Select SAE")
|
@@ -508,6 +463,9 @@ def gradio_interface():
|
|
508 |
output_text = gr.Textbox(label="Output", lines=20)
|
509 |
output_image = gr.Image(label="Tree Diagram", visible=False)
|
510 |
|
|
|
|
|
|
|
511 |
trim_slider = gr.Slider(minimum=0.00001, maximum=0.1, value=0.00001, label="Trim cutoff for cumulative probability", visible=False)
|
512 |
trim_btn = gr.Button("Trim Tree", visible=False)
|
513 |
|
@@ -515,28 +473,40 @@ def gradio_interface():
|
|
515 |
neuronpedia_html = gr.HTML(label="Neuronpedia")
|
516 |
|
517 |
inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode]
|
518 |
-
|
519 |
|
520 |
generate_btn.click(
|
521 |
-
|
522 |
inputs=inputs,
|
523 |
-
outputs=[output_text, output_image],
|
524 |
show_progress="full"
|
525 |
)
|
526 |
|
527 |
-
|
528 |
-
|
|
|
|
|
|
|
|
|
529 |
|
|
|
|
|
530 |
mode.change(update_visibility, inputs=[mode], outputs=[output_image, trim_slider, trim_btn])
|
531 |
|
532 |
selected_sae.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])
|
533 |
feature_number.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])
|
534 |
|
535 |
-
|
|
|
|
|
|
|
|
|
536 |
|
537 |
return demo
|
538 |
|
539 |
|
|
|
|
|
|
|
540 |
if __name__ == "__main__":
|
541 |
try:
|
542 |
logger.info("Starting application initialization...")
|
|
|
341 |
|
342 |
logger.info("Resources initialized successfully.")
|
343 |
|
344 |
+
|
345 |
@spaces.GPU
|
346 |
+
def process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=False):
|
347 |
global w_enc_dict, w_dec_dict, model, tokenizer, token_embeddings
|
348 |
|
349 |
try:
|
350 |
+
logger.info(f"Processing input: SAE={selected_sae}, feature_number={feature_number}, mode={mode}")
|
351 |
|
352 |
# Load the SAE weights if they are not already loaded
|
353 |
if selected_sae not in w_enc_dict or selected_sae not in w_dec_dict:
|
|
|
374 |
top_k=500, num_exp=num_exp, denom_exp=denom_exp
|
375 |
)
|
376 |
|
377 |
+
if top_500:
|
378 |
+
# Generate the top 500 list
|
379 |
+
result = "Top 500 list:\n"
|
380 |
+
result += "\n".join([f"{token!r}: {value:.4f}" for token, value in closest_tokens_with_values])
|
381 |
+
logger.info("Returning top 500 list")
|
382 |
+
return result, None
|
383 |
+
else:
|
384 |
+
# Generate the top 100 list
|
385 |
+
token_list = [token for token, _ in closest_tokens_with_values[:100]]
|
386 |
+
result = f"100 tokens whose embeddings produce the smallest ratio:\n\n"
|
387 |
+
result += f"[{', '.join(repr(token) for token in token_list)}]\n"
|
388 |
+
logger.info("Returning top 100 tokens")
|
389 |
+
return result, None
|
390 |
|
|
|
391 |
return "Mode not recognized or not implemented in this step.", None
|
392 |
|
393 |
except Exception as e:
|
|
|
396 |
|
397 |
|
398 |
|
399 |
+
|
400 |
+
|
401 |
def trim_tree(trim_cutoff, tree_data):
|
402 |
max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
|
403 |
trimmed_tree_image = create_tree_diagram(tree_data, config, max_weight, min_weight, trim_cutoff=float(trim_cutoff))
|
|
|
405 |
|
406 |
|
407 |
|
408 |
+
|
409 |
def gradio_interface():
|
410 |
def update_visibility(mode):
|
411 |
if mode == "definition tree generation":
|
412 |
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
413 |
else:
|
414 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible(False))
|
415 |
|
416 |
def update_neuronpedia(selected_sae, feature_number):
|
417 |
layer_number = int(selected_sae.split()[-1])
|
418 |
url = get_neuronpedia_url(layer_number, feature_number)
|
419 |
return f'<iframe src="{url}" width="100%" height="300px"></iframe>'
|
420 |
|
|
|
421 |
@spaces.GPU
|
422 |
def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
|
423 |
+
# Call process_input without generating the top 500 list initially
|
424 |
+
return process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
425 |
|
426 |
+
@spaces.GPU
|
427 |
+
def generate_top_500(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
|
428 |
+
# Call process_input with top_500=True to generate the full list
|
429 |
+
return process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=True)
|
430 |
|
431 |
def trim_tree(trim_cutoff, tree_data):
|
432 |
if tree_data is None:
|
|
|
436 |
return trimmed_tree_image
|
437 |
|
438 |
with gr.Blocks() as demo:
|
439 |
+
gr.Markdown("# Gemma-2B SAE Feature Explorer (almost there?)")
|
440 |
+
|
441 |
with gr.Row():
|
442 |
with gr.Column(scale=2):
|
443 |
selected_sae = gr.Dropdown(choices=["Gemma-2B layer 0", "Gemma-2B layer 6", "Gemma-2B layer 10", "Gemma-2B layer 12"], label="Select SAE")
|
|
|
463 |
output_text = gr.Textbox(label="Output", lines=20)
|
464 |
output_image = gr.Image(label="Tree Diagram", visible=False)
|
465 |
|
466 |
+
generate_top_500_btn = gr.Button("Generate Top 500 Tokens and Power Ratios", visible=False)
|
467 |
+
output_500_text = gr.Textbox(label="Top 500 Output", lines=20, visible=False)
|
468 |
+
|
469 |
trim_slider = gr.Slider(minimum=0.00001, maximum=0.1, value=0.00001, label="Trim cutoff for cumulative probability", visible=False)
|
470 |
trim_btn = gr.Button("Trim Tree", visible=False)
|
471 |
|
|
|
473 |
neuronpedia_html = gr.HTML(label="Neuronpedia")
|
474 |
|
475 |
inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode]
|
|
|
476 |
|
477 |
generate_btn.click(
|
478 |
+
update_output,
|
479 |
inputs=inputs,
|
480 |
+
outputs=[output_text, output_image],
|
481 |
show_progress="full"
|
482 |
)
|
483 |
|
484 |
+
generate_top_500_btn.click(
|
485 |
+
generate_top_500,
|
486 |
+
inputs=inputs,
|
487 |
+
outputs=[output_500_text],
|
488 |
+
show_progress="full"
|
489 |
+
)
|
490 |
|
491 |
+
trim_btn.click(trim_tree, inputs=[trim_slider, tree_data_state], outputs=[output_image])
|
492 |
+
|
493 |
mode.change(update_visibility, inputs=[mode], outputs=[output_image, trim_slider, trim_btn])
|
494 |
|
495 |
selected_sae.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])
|
496 |
feature_number.change(update_neuronpedia, inputs=[selected_sae, feature_number], outputs=[neuronpedia_html])
|
497 |
|
498 |
+
output_text.change(
|
499 |
+
lambda text: (gr.update(visible=True), gr.update(visible=True)) if "100 tokens" in text else (gr.update(visible(False)), gr.update(visible(False))),
|
500 |
+
inputs=[output_text],
|
501 |
+
outputs=[generate_top_500_btn, output_500_text]
|
502 |
+
)
|
503 |
|
504 |
return demo
|
505 |
|
506 |
|
507 |
+
|
508 |
+
|
509 |
+
|
510 |
if __name__ == "__main__":
|
511 |
try:
|
512 |
logger.info("Starting application initialization...")
|