Nu Appleblossom committed on
Commit 7b3b0f0
1 Parent(s): 1d40a54

back to last promising version with treebuild crashlog

Files changed (1)
  app.py +23 -36
app.py CHANGED
@@ -4,7 +4,6 @@ import torch.nn.functional as F
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from safetensors import safe_open
 import os
-import io
 import requests
 import json
 import math
@@ -199,15 +198,14 @@ def produce_next_token_ids(input_ids, model, topk, sub_token_id):
     top_k_probs, top_k_ids = torch.topk(softmax_probs, k=topk, dim=-1)
     return top_k_ids[0], top_k_probs[0]
 
-
-def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth=0, max_depth=25, cumulative_prob=1.0, output_buffer=None):
+def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth=0, max_depth=25, cumulative_prob=1.0):
     if depth >= max_depth or cumulative_prob < config.CUTOFF:
         return
 
     current_prompt = tokenizer.decode(input_ids[0], skip_special_tokens=True)
 
-    # Print the current node information to the buffer
-    print(f"Depth {depth}: {current_prompt} PROB: {cumulative_prob:.4f}", file=output_buffer)
+    # Print the current cumulative definition being built
+    print("\n" + f"Depth {depth}: {current_prompt} PROB: {cumulative_prob}")
 
     top_k_ids, top_k_probs = produce_next_token_ids(input_ids, model, config.TOPK, config.SUB_TOKEN_ID)
 
@@ -224,6 +222,10 @@ def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth
             continue
 
         token_str = tokenizer.decode([token_id], skip_special_tokens=True)
+
+        # Add the token to the current definition being built and print it
+        updated_prompt = f"{current_prompt} {token_str}"
+        print(f"Token: {token_str}, Updated Definition: {updated_prompt}, Cumulative Probability: {new_cumulative_prob}")
 
         new_child = {
             "token_id": token_id,
@@ -233,11 +235,10 @@ def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth
         }
         data['children'].append(new_child)
 
-        build_def_tree(new_input_ids, new_child, base_prompt, model, tokenizer, config, depth=depth+1, max_depth=max_depth, cumulative_prob=new_cumulative_prob, output_buffer=output_buffer)
+        build_def_tree(new_input_ids, new_child, base_prompt, model, tokenizer, config, depth=depth+1, max_depth=max_depth, cumulative_prob=new_cumulative_prob)
 
 def generate_definition_tree(base_prompt, embedding, model, tokenizer, config):
     results_dict = {"token": "", "cumulative_prob": 1, "children": []}
-    output_buffer = io.StringIO()
 
     # Reset the token embedding
     token_embedding = torch.unsqueeze(embedding, dim=0).to(model.device)
@@ -248,14 +249,9 @@ def generate_definition_tree(base_prompt, embedding, model, tokenizer, config):
     model.reset_cache()
 
     input_ids = tokenizer.encode(base_prompt, return_tensors="pt").to(model.device)
-    build_def_tree(input_ids, results_dict, base_prompt, model, tokenizer, config, output_buffer=output_buffer)
-
-    tree_output = output_buffer.getvalue()
-    output_buffer.close()
-
-    return results_dict, tree_output
-
+    build_def_tree(input_ids, results_dict, base_prompt, model, tokenizer, config)
+
+    return results_dict
 
 
 def find_max_min_cumulative_weight(node, current_max=0, current_min=float('inf')):
@@ -370,7 +366,7 @@ def process_input(selected_sae, feature_number, weight_type, use_token_centroid,
             if w_enc is None or w_dec is None:
                 error_message = f"Failed to load SAE weights for {selected_sae}. Please try a different SAE or check your connection."
                 logger.error(error_message)
-                return error_message
+                return error_message, None
             w_enc_dict[selected_sae] = w_enc
             w_dec_dict[selected_sae] = w_dec
         else:
@@ -398,29 +394,30 @@ def process_input(selected_sae, feature_number, weight_type, use_token_centroid,
                 # Generate the top 500 list
                 result = ", ".join([f"'{token}': {value:.4f}" for token, value in closest_tokens_with_values])
                 logger.info("Returning top 500 list")
-                return result
+                return result, None
             else:
                 # Generate the top 100 list
                 token_list = [token for token, _ in closest_tokens_with_values[:100]]
                 result = f"100 tokens whose embeddings produce the smallest ratio (cos distance to feature vector)^m/(cos distance to token centroid)^n:\n\n"
-                result += f"[{', '.join(repr(token) for token in token_list)}]"
+
+                result += f"[{', '.join(repr(token) for token in token_list)}]\n"
                 logger.info("Returning top 100 tokens")
-                return result
+                return result, None
 
         elif mode == "definition tree generation":
             logger.info("Generating definition tree")
-            tree_data, tree_output = generate_definition_tree("definition tree", feature_vector, model, tokenizer, config)
+            tree_data = generate_definition_tree("definition tree", feature_vector, model, tokenizer, config)
 
             max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
             tree_image = create_tree_diagram(tree_data, config, max_weight, min_weight)
 
-            return tree_output, tree_image
+            return None, tree_image
 
-        return "Mode not recognized or not implemented in this step."
+        return "Mode not recognized or not implemented in this step.", None
 
     except Exception as e:
         logger.error(f"Error in process_input: {str(e)}")
-        return f"Error: {str(e)}"
+        return f"Error: {str(e)}", None
     finally:
         del feature_vector
         del token_centroid
@@ -428,7 +425,6 @@ def process_input(selected_sae, feature_number, weight_type, use_token_centroid,
         del pca_direction
         torch.cuda.empty_cache()
 
-
 
 def trim_tree(trim_cutoff, tree_data):
     max_weight, min_weight = find_max_min_cumulative_weight(tree_data)
@@ -450,17 +446,8 @@ def gradio_interface():
 
     @spaces.GPU
     def update_output(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, progress=gr.Progress()):
-        result = process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=False, progress=progress)
-
-        if mode == "definition tree generation":
-            for item in result:
-                if isinstance(item, tuple):
-                    yield item[0], item[1], None  # tree_text, tree_data, None
-                else:
-                    yield item, None, None  # Intermediate updates
-        else:
-            # For cosine distance token lists, result is not a generator
-            yield result, None, None
+        # Call process_input without generating the top 500 list initially
+        return process_input(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode, top_500=False, progress=progress)
 
     @spaces.GPU
     def generate_top_500(selected_sae, feature_number, weight_type, use_token_centroid, scaling_factor, use_pca, pca_weight, num_exp, denom_exp, mode):
@@ -475,7 +462,7 @@ def gradio_interface():
         return trimmed_tree_image
 
     with gr.Blocks() as demo:
-        gr.Markdown("# Gemma-2B SAE Feature Explorer")
+        gr.Markdown("# Gemma-2B SAE Feature Explorer (back2crashlogs)")
 
         with gr.Row():
             with gr.Column(scale=2):
@@ -516,7 +503,7 @@ def gradio_interface():
         generate_btn.click(
             update_output,
             inputs=inputs,
-            outputs=[output_text, output_image, tree_data_state],
+            outputs=[output_text, output_image],
             show_progress="full"
         ).then(lambda: gr.update(visible=False, value=""), None, [output_500_text])