Nu Appleblossom committed • Commit 7b3b0f0
Parent(s): 1d40a54

back to last promising version with treebuild crashlog
app.py
CHANGED
@@ -4,7 +4,6 @@ import torch.nn.functional as F
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from safetensors import safe_open
 import os
-import io
 import requests
 import json
 import math
@@ -199,15 +198,14 @@ def produce_next_token_ids(input_ids, model, topk, sub_token_id):
     top_k_probs, top_k_ids = torch.topk(softmax_probs, k=topk, dim=-1)
     return top_k_ids[0], top_k_probs[0]
 
-
-def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth=0, max_depth=25, cumulative_prob=1.0, output_buffer=None):
+def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth=0, max_depth=25, cumulative_prob=1.0):
     if depth >= max_depth or cumulative_prob < config.CUTOFF:
         return
 
     current_prompt = tokenizer.decode(input_ids[0], skip_special_tokens=True)
 
-    # Print the current
-    print(f"Depth {depth}: {current_prompt} PROB: {cumulative_prob}", file=output_buffer)
+    # Print the current cumulative definition being built
+    print("\n" + f"Depth {depth}: {current_prompt} PROB: {cumulative_prob}")
 
     top_k_ids, top_k_probs = produce_next_token_ids(input_ids, model, config.TOPK, config.SUB_TOKEN_ID)
 
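Note on the hunk above: `build_def_tree` grows a definition tree by taking the model's top-k next tokens at every step and recursing until a depth limit or cumulative-probability cutoff (`config.CUTOFF`) is hit. A minimal self-contained sketch of that pattern, with a hypothetical `next_token_distribution()` stub standing in for the model-backed `produce_next_token_ids()`:

```python
# Minimal sketch of the tree-building recursion; not the app's exact code.
def next_token_distribution(prefix, k=3):
    # Hypothetical stub: return k (token, probability) pairs. The real code
    # runs the model and takes a softmax top-k.
    return [(f"tok{i}", 1.0 / k) for i in range(k)]

def build_tree(prefix, node, depth=0, max_depth=4, cum_prob=1.0, cutoff=0.01):
    # Prune branches that are too deep or too improbable.
    if depth >= max_depth or cum_prob < cutoff:
        return
    for token, prob in next_token_distribution(prefix):
        child = {"token": token, "cumulative_prob": cum_prob * prob, "children": []}
        node["children"].append(child)
        build_tree(f"{prefix} {token}", child, depth + 1, max_depth, cum_prob * prob, cutoff)

root = {"token": "", "cumulative_prob": 1, "children": []}
build_tree("A typical definition of this concept is", root)
```

Each child's `cumulative_prob` is the product of branch probabilities along its path, which is exactly what the cutoff prunes against.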
@@ -224,6 +222,10 @@ def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth
             continue
 
         token_str = tokenizer.decode([token_id], skip_special_tokens=True)
+
+        # Add the token to the current definition being built and print it
+        updated_prompt = f"{current_prompt} {token_str}"
+        print(f"Token: {token_str}, Updated Definition: {updated_prompt}, Cumulative Probability: {new_cumulative_prob}")
 
         new_child = {
             "token_id": token_id,
@@ -233,11 +235,10 @@ def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth
         }
         data['children'].append(new_child)
 
-        build_def_tree(new_input_ids, new_child, base_prompt, model, tokenizer, config, depth=depth+1, max_depth=max_depth, cumulative_prob=new_cumulative_prob, output_buffer=output_buffer)
+        build_def_tree(new_input_ids, new_child, base_prompt, model, tokenizer, config, depth=depth+1, max_depth=max_depth, cumulative_prob=new_cumulative_prob)
 
 def generate_definition_tree(base_prompt, embedding, model, tokenizer, config):
     results_dict = {"token": "", "cumulative_prob": 1, "children": []}
-    output_buffer = io.StringIO()
 
     # Reset the token embedding
     token_embedding = torch.unsqueeze(embedding, dim=0).to(model.device)
@@ -248,14 +249,9 @@ def generate_definition_tree(base_prompt, embedding, model, tokenizer, config):
     model.reset_cache()
 
     input_ids = tokenizer.encode(base_prompt, return_tensors="pt").to(model.device)
-    build_def_tree(input_ids, results_dict, base_prompt, model, tokenizer, config, output_buffer=output_buffer)
-
-    tree_output = output_buffer.getvalue()
-    output_buffer.close()
-
-    return results_dict, tree_output
-
+    build_def_tree(input_ids, results_dict, base_prompt, model, tokenizer, config)
 
+    return results_dict
 
 
 def find_max_min_cumulative_weight(node, current_max=0, current_min=float('inf')):
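This hunk is the heart of the revert: `generate_definition_tree` no longer captures the trace in a `StringIO` buffer and returns it alongside the tree; it returns only `results_dict`, and the `print()` calls in `build_def_tree` go straight to stdout, i.e. into the Space's runtime log (the "treebuild crashlog" of the commit message). For reference, a generic sketch of the buffer-capture idea being abandoned; the removed code threaded an explicit `output_buffer` argument rather than redirecting stdout, so this is illustrative only:

```python
import io
from contextlib import redirect_stdout

# One common way to capture print() output in a string, similar in spirit
# to the output_buffer mechanism this commit removes.
def traced_build():
    buffer = io.StringIO()
    with redirect_stdout(buffer):
        print("Depth 0: <base prompt> PROB: 1.0")  # stands in for the tree trace
    return buffer.getvalue()

trace = traced_build()
print(trace, end="")  # after the revert, the trace goes straight to stdout instead
```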
@@ -370,7 +366,7 @@ def process_input(selected_sae, feature_number, weight_type, use_token_centroid,
             if w_enc is None or w_dec is None:
                 error_message = f"Failed to load SAE weights for {selected_sae}. Please try a different SAE or check your connection."
                 logger.error(error_message)
-                return error_message
+                return error_message, None
             w_enc_dict[selected_sae] = w_enc
             w_dec_dict[selected_sae] = w_dec
         else:
@@ -398,29 +394,30 @@ def process_input(selected_sae, feature_number, weight_type, use_token_centroid,
                 # Generate the top 500 list
                 result = ", ".join([f"'{token}': {value:.4f}" for token, value in closest_tokens_with_values])
                 logger.info("Returning top 500 list")
-                return result
+                return result, None
             else:
                 # Generate the top 100 list
                 token_list = [token for token, _ in closest_tokens_with_values[:100]]
                 result = f"100 tokens whose embeddings produce the smallest ratio (cos distance to feature vector)^m/(cos distance to token centroid)^n:\n\n"
-
+
+                result += f"[{', '.join(repr(token) for token in token_list)}]\n"
                 logger.info("Returning top 100 tokens")
-                return result
+                return result, None
 
         elif mode == "definition tree generation":
             logger.info("Generating definition tree")
-            tree_data, tree_text = generate_definition_tree("definition tree", feature_vector, model, tokenizer, config)
+            tree_data = generate_definition_tree("definition tree", feature_vector, model, tokenizer, config)
 
             max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
            tree_image = create_tree_diagram(tree_data, config, max_weight, min_weight)
 
-            return
+            return None, tree_image
 
-        return "Mode not recognized or not implemented in this step."
+        return "Mode not recognized or not implemented in this step.", None
 
     except Exception as e:
         logger.error(f"Error in process_input: {str(e)}")
-        return f"Error: {str(e)}"
+        return f"Error: {str(e)}", None
     finally:
         del feature_vector
         del token_centroid
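The `process_input` changes above make every exit path return a `(text, image)` pair, so one handler can drive both Gradio output components positionally. A minimal sketch of that convention, with made-up component names:

```python
import gradio as gr

# Each return value maps positionally onto outputs=[text_out, image_out].
def handle(mode):
    if mode == "text":
        return "some text result", None  # fill the textbox, clear the image
    return None, "tree.png"              # path to a rendered diagram (assumed to exist)

with gr.Blocks() as demo:
    mode = gr.Dropdown(["text", "tree"], value="text")
    text_out = gr.Textbox()
    image_out = gr.Image()
    btn = gr.Button("Run")
    btn.click(handle, inputs=[mode], outputs=[text_out, image_out])
```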
@@ -428,7 +425,6 @@ def process_input(selected_sae, feature_number, weight_type, use_token_centroid,
             del pca_direction
         torch.cuda.empty_cache()
 
-
 
 def trim_tree(trim_cutoff, tree_data):
     max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
@@ -450,17 +446,8 @@ def gradio_interface():
 
     @spaces.GPU
     def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, progress=gr.Progress()):
-
-
-        if mode == "definition tree generation":
-            for item in result:
-                if isinstance(item, tuple):
-                    yield item[0], item[1], None  # tree_text, tree_data, None
-                else:
-                    yield item, None, None  # Intermediate updates
-        else:
-            # For cosine distance token lists, result is not a generator
-            yield result, None, None
+        # Call process_input without generating the top 500 list initially
+        return process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=False, progress=progress)
 
     @spaces.GPU
     def generate_top_500(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
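The old `update_output` treated `process_input`'s result as a generator and yielded three-slot updates; the new one is a plain wrapper that returns once. For contrast, a small sketch of the generator-handler pattern that was dropped; in Gradio, a generator handler makes each yielded value stream to the outputs as it arrives:

```python
import time
import gradio as gr

# Sketch of a streaming handler: each yield updates [out_text, out_img].
def stream_updates(prompt):
    partial = ""
    for word in ("building", "definition", "tree"):
        partial += word + " "
        time.sleep(0.1)          # stand-in for real per-node work
        yield partial, None      # progressive text, no image yet
    yield partial.strip(), None  # final update

with gr.Blocks() as demo:
    inp = gr.Textbox(label="prompt")
    out_text = gr.Textbox()
    out_img = gr.Image()
    gr.Button("Go").click(stream_updates, inputs=[inp], outputs=[out_text, out_img])
```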
@@ -475,7 +462,7 @@ def gradio_interface():
         return trimmed_tree_image
 
     with gr.Blocks() as demo:
-        gr.Markdown("# Gemma-2B SAE Feature Explorer")
+        gr.Markdown("# Gemma-2B SAE Feature Explorer (back2crashlogs)")
 
         with gr.Row():
             with gr.Column(scale=2):
@@ -516,7 +503,7 @@ def gradio_interface():
         generate_btn.click(
             update_output,
             inputs=inputs,
-            outputs=[output_text, output_image
+            outputs=[output_text, output_image],
             show_progress="full"
         ).then(lambda: gr.update(visible=False, value=""), None, [output_500_text])
 
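The `.then()` chain in the final hunk registers a follow-up listener that fires after `update_output` completes, hiding and clearing the top-500 textbox. A stripped-down sketch of that event chaining, with assumed component names:

```python
import gradio as gr

# .then() runs its handler only after the first click handler finishes.
with gr.Blocks() as demo:
    out = gr.Textbox()
    extra = gr.Textbox(visible=True)
    btn = gr.Button("Generate")
    btn.click(lambda: "main result", None, [out]).then(
        lambda: gr.update(visible=False, value=""),  # hide and clear the extra box
        None,
        [extra],
    )
```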