Nu Appleblossom committed on
Commit a5bba05
1 Parent(s): 6d779d0

back to last promising version with treebuild crashlog, revising whole tree gen thing with Claude 5

Files changed (1)
  1. app.py +15 -4
app.py CHANGED
@@ -19,6 +19,14 @@ from graphviz import Digraph
 from PIL import Image, ImageDraw, ImageFont
 from io import BytesIO
 import functools
+import logging
+
+# Set up custom logger
+custom_logger = logging.getLogger("custom_logger")
+custom_logger.setLevel(logging.INFO)
+custom_handler = logging.StreamHandler()
+custom_handler.setFormatter(logging.Formatter('%(message)s'))
+custom_logger.addHandler(custom_handler)
 
 # Load environment variables
 load_dotenv()
@@ -203,13 +211,16 @@ def produce_next_token_ids(input_ids, model, topk, sub_token_id):
 
 
 def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth=0, max_depth=25, cumulative_prob=1.0):
-    logger.info(f"build_def_tree called with depth={depth}, cumulative_prob={cumulative_prob:.4f}")
     if depth >= max_depth or cumulative_prob < config.CUTOFF:
         return
 
     current_prompt = tokenizer.decode(input_ids[0], skip_special_tokens=True)
 
-    yield f"\nDepth {depth}: {current_prompt} PROB: {cumulative_prob:.4f}"
+    # Extract only the part that extends the base prompt
+    extended_prompt = current_prompt[len(base_prompt):].strip()
+
+    # Use a custom logger to avoid the "INFO:__main__:" prefix
+    custom_logger.info(f"Depth {depth}: {extended_prompt} PROB: {cumulative_prob:.4f}")
 
     top_k_ids, top_k_probs = produce_next_token_ids(input_ids, model, config.TOPK, config.SUB_TOKEN_ID)
 
@@ -238,6 +249,7 @@ def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth
         yield from build_def_tree(new_input_ids, new_child, base_prompt, model, tokenizer, config, depth=depth+1, max_depth=max_depth, cumulative_prob=new_cumulative_prob)
 
 
+
 def generate_definition_tree(base_prompt, embedding, model, tokenizer, config):
     logger.info(f"Starting generate_definition_tree with base_prompt: {base_prompt}")
     results_dict = {"token": "", "cumulative_prob": 1, "children": []}
@@ -421,10 +433,8 @@ def process_input(selected_sae, feature_number, weight_type, use_token_centroid,
         for item in tree_generator:
             if isinstance(item, str):
                 log_output.append(item)
-                logger.info(f"Tree generation step: {item}") # Log each step
             else:
                 tree_data = item
-                logger.info("Received tree data")
 
         # Join the log output into a single string
         log_text = "\n".join(log_output)
@@ -440,6 +450,7 @@ def process_input(selected_sae, feature_number, weight_type, use_token_centroid,
             logger.error("Failed to generate tree data")
             return "Error: Failed to generate tree data.", None
 
+
        return "Mode not recognized or not implemented in this step.", None
 
     except Exception as e:
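
Note on the logging change above: the added block creates a dedicated logger whose handler formats records as the bare '%(message)s', so the per-depth progress lines appear without the "INFO:__main__:" prefix produced by the module-level logger. A minimal standalone sketch of that setup follows; the propagate=False line and the sample message are additions for illustration and are not part of the diff.

    import logging

    # Same handler/formatter arrangement as in the diff: records are emitted
    # as the bare message text, with no level or logger-name prefix.
    custom_logger = logging.getLogger("custom_logger")
    custom_logger.setLevel(logging.INFO)
    custom_handler = logging.StreamHandler()
    custom_handler.setFormatter(logging.Formatter('%(message)s'))
    custom_logger.addHandler(custom_handler)
    custom_logger.propagate = False  # assumption: keep the root logger from printing the same record again

    custom_logger.info("Depth 3: example extension PROB: 0.0421")
    # output: Depth 3: example extension PROB: 0.0421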
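
For context on why process_input checks isinstance(item, str): the tree generator interleaves human-readable progress strings with a final (non-string) tree dict, and the caller separates the two by type. The sketch below mirrors that consumption loop; the generator name and yielded values are illustrative, not taken from app.py.

    def example_tree_generator():
        # Progress lines are yielded as plain strings...
        yield "Depth 0:  PROB: 1.0000"
        yield "Depth 1: to be PROB: 0.4200"
        # ...and the finished tree is yielded as a dict.
        yield {"token": "", "cumulative_prob": 1, "children": []}

    log_output = []
    tree_data = None
    for item in example_tree_generator():
        if isinstance(item, str):
            log_output.append(item)   # collected into the log text shown to the user
        else:
            tree_data = item          # anything non-string is treated as the tree itself

    log_text = "\n".join(log_output)
    print(log_text)
    print(tree_data)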