Spaces:
Sleeping
Sleeping
Nu Appleblossom
committed on
Commit
·
cc8d300
1
Parent(s):
32e7790
back to last promising version with treebuild crashlog trying to move text to interface with 4o help AGAIN2
Browse files
app.py
CHANGED
@@ -187,6 +187,11 @@ def update_token_embedding(model, token_id, new_embedding):
|
|
187 |
new_embedding = new_embedding.to(model.get_input_embeddings().weight.device)
|
188 |
model.get_input_embeddings().weight.data[token_id] = new_embedding
|
189 |
|
|
|
|
|
|
|
|
|
|
|
190 |
def produce_next_token_ids(input_ids, model, topk, sub_token_id):
|
191 |
input_ids = input_ids.to(model.device)
|
192 |
with torch.no_grad():
|
@@ -445,23 +450,30 @@ def trim_tree(trim_cutoff, tree_data):
|
|
445 |
def gradio_interface():
|
446 |
def update_visibility(mode):
|
447 |
if mode == "definition tree generation":
|
448 |
-
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True),
|
|
|
449 |
else:
|
450 |
-
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
|
452 |
with gr.Blocks() as demo:
|
453 |
gr.Markdown("# Gemma-2B SAE Feature Explorer")
|
454 |
|
455 |
with gr.Row():
|
456 |
with gr.Column(scale=2):
|
457 |
-
selected_sae = gr.Dropdown(choices=["Gemma-2B layer 0", "Gemma-2B layer 6", "Gemma-2B layer 10", "Gemma-2B layer 12"],
|
|
|
458 |
feature_number = gr.Number(label="Select feature number", minimum=0, maximum=16383, value=0)
|
459 |
|
460 |
-
mode = gr.Radio(
|
461 |
-
|
462 |
-
label="Select mode",
|
463 |
-
value="cosine distance token lists"
|
464 |
-
)
|
465 |
|
466 |
weight_type = gr.Radio(["encoder", "decoder"], label="Select weight type for feature vector construction", value="encoder")
|
467 |
use_token_centroid = gr.Checkbox(label="Use token centroid offset", value=True)
|
@@ -485,16 +497,12 @@ def gradio_interface():
|
|
485 |
|
486 |
tree_data_state = gr.State()
|
487 |
|
488 |
-
inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor,
|
|
|
489 |
|
490 |
-
generate_btn.click(
|
491 |
-
update_output,
|
492 |
-
inputs=inputs,
|
493 |
-
outputs=[output_text, output_image], # Now the text output will be displayed in the Textbox
|
494 |
-
show_progress="full"
|
495 |
-
)
|
496 |
|
497 |
-
# other
|
498 |
|
499 |
return demo
|
500 |
|
@@ -502,7 +510,6 @@ def gradio_interface():
|
|
502 |
|
503 |
|
504 |
|
505 |
-
|
506 |
if __name__ == "__main__":
|
507 |
try:
|
508 |
logger.info("Starting application initialization...")
|
|
|
187 |
new_embedding = new_embedding.to(model.get_input_embeddings().weight.device)
|
188 |
model.get_input_embeddings().weight.data[token_id] = new_embedding
|
189 |
|
190 |
+
@spaces.GPU
|
191 |
+
def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, progress=gr.Progress()):
|
192 |
+
# Call process_input to generate the output
|
193 |
+
return process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=False, progress=progress)
|
194 |
+
|
195 |
def produce_next_token_ids(input_ids, model, topk, sub_token_id):
|
196 |
input_ids = input_ids.to(model.device)
|
197 |
with torch.no_grad():
|
|
|
450 |
def gradio_interface():
|
451 |
def update_visibility(mode):
|
452 |
if mode == "definition tree generation":
|
453 |
+
return (gr.update(visible=True), gr.update(visible=True), gr.update(visible=True),
|
454 |
+
gr.update(visible=False), gr.update(visible=False))
|
455 |
else:
|
456 |
+
return (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
|
457 |
+
gr.update(visible=True), gr.update(visible=True))
|
458 |
+
|
459 |
+
@spaces.GPU
|
460 |
+
def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor,
|
461 |
+
use_pca, pca_weight, num_exp, denom_exp, mode, progress=gr.Progress()):
|
462 |
+
# Call process_input to generate the output
|
463 |
+
return process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor,
|
464 |
+
use_pca, pca_weight, num_exp, denom_exp, mode, top_500=False, progress=progress)
|
465 |
|
466 |
with gr.Blocks() as demo:
|
467 |
gr.Markdown("# Gemma-2B SAE Feature Explorer")
|
468 |
|
469 |
with gr.Row():
|
470 |
with gr.Column(scale=2):
|
471 |
+
selected_sae = gr.Dropdown(choices=["Gemma-2B layer 0", "Gemma-2B layer 6", "Gemma-2B layer 10", "Gemma-2B layer 12"],
|
472 |
+
label="Select SAE")
|
473 |
feature_number = gr.Number(label="Select feature number", minimum=0, maximum=16383, value=0)
|
474 |
|
475 |
+
mode = gr.Radio(choices=["cosine distance token lists", "definition tree generation"],
|
476 |
+
label="Select mode", value="cosine distance token lists")
|
|
|
|
|
|
|
477 |
|
478 |
weight_type = gr.Radio(["encoder", "decoder"], label="Select weight type for feature vector construction", value="encoder")
|
479 |
use_token_centroid = gr.Checkbox(label="Use token centroid offset", value=True)
|
|
|
497 |
|
498 |
tree_data_state = gr.State()
|
499 |
|
500 |
+
inputs = [selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor,
|
501 |
+
use_pca, pca_weight, num_exp, denom_exp, mode]
|
502 |
|
503 |
+
generate_btn.click(update_output, inputs=inputs, outputs=[output_text, output_image], show_progress="full")
|
|
|
|
|
|
|
|
|
|
|
504 |
|
505 |
+
# Add other button functionality as needed...
|
506 |
|
507 |
return demo
|
508 |
|
|
|
510 |
|
511 |
|
512 |
|
|
|
513 |
if __name__ == "__main__":
|
514 |
try:
|
515 |
logger.info("Starting application initialization...")
|