Spaces:
Sleeping
Sleeping
Nu Appleblossom
commited on
Commit
•
b988e93
1
Parent(s):
acdf4b5
app.py refactored
Browse files
app.py
CHANGED
@@ -403,6 +403,65 @@ def process_input(selected_sae, feature_number, weight_type, use_token_centroid,
|
|
403 |
|
404 |
|
405 |
def gradio_interface():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
406 |
with gr.Blocks() as demo:
|
407 |
gr.Markdown("# Gemma-2B SAE Feature Explorer")
|
408 |
|
@@ -434,34 +493,12 @@ def gradio_interface():
|
|
434 |
trim_slider = gr.Slider(minimum=0.00001, maximum=0.1, value=0.00001, label="Trim cutoff for cumulative probability")
|
435 |
trim_btn = gr.Button("Trim Tree")
|
436 |
|
437 |
-
|
438 |
-
result = process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode)
|
439 |
-
|
440 |
-
if mode == "definition tree generation":
|
441 |
-
tree_text = ""
|
442 |
-
tree_image = None
|
443 |
-
tree_data = None
|
444 |
-
for i, item in enumerate(result):
|
445 |
-
progress(i / 100) # Assuming max 100 iterations, adjust as needed
|
446 |
-
if isinstance(item, str):
|
447 |
-
tree_text += item
|
448 |
-
yield tree_text, tree_image, tree_data
|
449 |
-
else:
|
450 |
-
tree_data = item
|
451 |
-
tree_image = create_tree_diagram(tree_data, config, *find_max_min_cumulative_weight(tree_data))
|
452 |
-
yield tree_text, tree_image, tree_data
|
453 |
-
else:
|
454 |
-
yield result, None, None
|
455 |
-
|
456 |
-
def trim_tree(trim_cutoff, tree_data):
|
457 |
-
max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
|
458 |
-
trimmed_tree_image = create_tree_diagram(tree_data, config, max_weight, min_weight, trim_cutoff=float(trim_cutoff))
|
459 |
-
return trimmed_tree_image
|
460 |
|
461 |
inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode]
|
462 |
|
463 |
-
generate_btn.click(update_output, inputs=inputs, outputs=[output_text, output_image,
|
464 |
-
trim_btn.click(trim_tree, inputs=[trim_slider,
|
465 |
|
466 |
return demo
|
467 |
|
|
|
403 |
|
404 |
|
405 |
def gradio_interface():
|
406 |
+
@spaces.GPU
|
407 |
+
def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
|
408 |
+
global w_enc_dict, w_dec_dict, model, tokenizer, token_embeddings
|
409 |
+
|
410 |
+
try:
|
411 |
+
if selected_sae not in w_enc_dict or selected_sae not in w_dec_dict:
|
412 |
+
w_enc, w_dec = load_sae_weights(selected_sae)
|
413 |
+
if w_enc is None or w_dec is None:
|
414 |
+
return f"Failed to load SAE weights for {selected_sae}. Please try a different SAE or check your connection.", None, None
|
415 |
+
w_enc_dict[selected_sae] = w_enc
|
416 |
+
w_dec_dict[selected_sae] = w_dec
|
417 |
+
else:
|
418 |
+
w_enc, w_dec = w_enc_dict[selected_sae], w_dec_dict[selected_sae]
|
419 |
+
|
420 |
+
token_centroid = torch.mean(token_embeddings, dim=0)
|
421 |
+
feature_vector = create_feature_vector(w_enc, w_dec, int(feature_number), weight_type, token_centroid, use_token_centroid, scaling_factor)
|
422 |
+
|
423 |
+
if use_pca:
|
424 |
+
pca_direction = perform_pca(token_embeddings)
|
425 |
+
feature_vector = create_ghost_token(feature_vector, token_centroid, pca_direction, scaling_factor, pca_weight)
|
426 |
+
|
427 |
+
if mode == "cosine distance token lists":
|
428 |
+
closest_tokens_with_values = find_closest_tokens(
|
429 |
+
feature_vector, token_embeddings, tokenizer,
|
430 |
+
top_k=500, num_exp=num_exp, denom_exp=denom_exp
|
431 |
+
)
|
432 |
+
|
433 |
+
token_list = [token for token, _ in closest_tokens_with_values]
|
434 |
+
result = f"100 tokens whose embeddings produce the smallest ratio:\n\n"
|
435 |
+
result += f"[{', '.join(repr(token) for token in token_list[:100])}]\n\n"
|
436 |
+
result += "Top 500 list:\n"
|
437 |
+
result += "\n".join([f"{token!r}: {value:.4f}" for token, value in closest_tokens_with_values])
|
438 |
+
|
439 |
+
return result, None, None # Return the result, no image, and no tree data
|
440 |
+
elif mode == "definition tree generation":
|
441 |
+
base_prompt = f'A typical definition of "{tokenizer.decode([config.SUB_TOKEN_ID], skip_special_tokens=True)}" would be "'
|
442 |
+
tree_generator = generate_definition_tree(base_prompt, feature_vector, model, tokenizer, config)
|
443 |
+
|
444 |
+
# Generate the tree
|
445 |
+
tree_text = ""
|
446 |
+
tree_dict = None
|
447 |
+
for item in tree_generator:
|
448 |
+
if isinstance(item, str):
|
449 |
+
tree_text += item
|
450 |
+
yield tree_text, None, None # Yield the updated text, no image, and no tree data
|
451 |
+
else:
|
452 |
+
tree_dict = item
|
453 |
+
|
454 |
+
# Generate the tree visualization
|
455 |
+
max_weight, min_weight = find_max_min_cumulative_weight(tree_dict)
|
456 |
+
tree_image = create_tree_diagram(tree_dict, config, max_weight, min_weight)
|
457 |
+
|
458 |
+
return tree_text, tree_image, tree_dict # Return the final text, tree image, and tree data
|
459 |
+
|
460 |
+
except Exception as e:
|
461 |
+
logger.error(f"Error in update_output: {str(e)}")
|
462 |
+
logger.error(traceback.format_exc())
|
463 |
+
return f"Error: {str(e)}\nPlease check the logs for more details.", None, None
|
464 |
+
|
465 |
with gr.Blocks() as demo:
|
466 |
gr.Markdown("# Gemma-2B SAE Feature Explorer")
|
467 |
|
|
|
493 |
trim_slider = gr.Slider(minimum=0.00001, maximum=0.1, value=0.00001, label="Trim cutoff for cumulative probability")
|
494 |
trim_btn = gr.Button("Trim Tree")
|
495 |
|
496 |
+
tree_data_state = gr.State()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
497 |
|
498 |
inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode]
|
499 |
|
500 |
+
generate_btn.click(update_output, inputs=inputs, outputs=[output_text, output_image, tree_data_state])
|
501 |
+
trim_btn.click(trim_tree, inputs=[trim_slider, tree_data_state], outputs=[output_image])
|
502 |
|
503 |
return demo
|
504 |
|