Nu Appleblossom committed on
Commit 95a8028 · 1 Parent(s): e7b948c

next attempt at tree functionality 5

Files changed (1)
  1. app.py +23 -7
app.py CHANGED
@@ -208,21 +208,34 @@ def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth
     current_prompt = tokenizer.decode(input_ids[0], skip_special_tokens=True)
     if progress_callback:
         progress_callback(f"Depth {depth}: {current_prompt} PROB: {cumulative_prob}\n")
-    top_k_ids, top_k_probs = produce_next_token_ids(input_ids, model, config.TOPK, config.SUB_TOKEN_ID)
+
+    try:
+        top_k_ids, top_k_probs = produce_next_token_ids(input_ids, model, config.TOPK, config.SUB_TOKEN_ID)
+    except Exception as e:
+        logger.error(f"Error generating next token IDs at depth {depth}: {str(e)}")
+        return
 
     for idx, token_id in enumerate(top_k_ids.tolist()):
         if token_id == config.SUB_TOKEN_ID:
-            continue
+            continue  # Skip the substitute token to avoid circular definitions
 
-        token_id_tensor = torch.tensor([token_id], dtype=torch.long).to(model.device)
-        new_input_ids = torch.cat([input_ids, token_id_tensor.view(1, 1)], dim=-1)
+        try:
+            token_id_tensor = torch.tensor([token_id], dtype=torch.long).to(model.device)
+            new_input_ids = torch.cat([input_ids, token_id_tensor.view(1, 1)], dim=-1)
+        except IndexError as e:
+            logger.error(f"IndexError in processing token ID {token_id} at depth {depth}: {str(e)}")
+            continue
 
         new_cumulative_prob = cumulative_prob * top_k_probs[idx].item()
 
         if new_cumulative_prob < config.CUTOFF:
             continue
 
-        token_str = tokenizer.decode([token_id], skip_special_tokens=True)
+        try:
+            token_str = tokenizer.decode([token_id], skip_special_tokens=True)
+        except Exception as e:
+            logger.error(f"Error decoding token ID {token_id} at depth {depth}: {str(e)}")
+            continue
 
         new_child = {
             "token_id": token_id,
@@ -235,14 +248,14 @@ def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth
         build_def_tree(new_input_ids, new_child, base_prompt, model, tokenizer, config, depth=depth+1, max_depth=max_depth, cumulative_prob=new_cumulative_prob, progress_callback=progress_callback)
 
 
-
-
 def generate_definition_tree(base_prompt, embedding, model, tokenizer, config, progress_callback=None):
     results_dict = {"token": "", "cumulative_prob": 1, "children": []}
 
+    # Reset the token embedding
     token_embedding = torch.unsqueeze(embedding, dim=0).to(model.device)
     update_token_embedding(model, config.SUB_TOKEN_ID, token_embedding)
 
+    # Clear the model's cache if it has one
     if hasattr(model, 'reset_cache'):
         model.reset_cache()
 
@@ -251,6 +264,9 @@ def generate_definition_tree(base_prompt, embedding, model, tokenizer, config, p
 
     return results_dict
 
+
+
+
 def find_max_min_cumulative_weight(node, current_max=0, current_min=float('inf')):
     current_max = max(current_max, node.get('cumulative_prob', 0))
     if node.get('cumulative_prob', 1) > 0:
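
Note: the new error handling assumes a module-level logger (e.g. logging.getLogger(__name__)) and calls produce_next_token_ids, a helper defined elsewhere in app.py and not shown in this diff. A minimal sketch of what that helper plausibly does, assuming a Hugging Face causal LM (the role of the sub_token_id parameter is a guess; the calling loop already skips that id itself):

import torch

def produce_next_token_ids(input_ids, model, topk, sub_token_id):
    # Sketch only: score the next token and return the top-k candidates.
    # sub_token_id is accepted but unused here, since build_def_tree
    # filters config.SUB_TOKEN_ID in its own loop.
    with torch.no_grad():
        logits = model(input_ids).logits[:, -1, :]   # logits at the final position
        probs = torch.softmax(logits, dim=-1)
        top_k_probs, top_k_ids = torch.topk(probs, topk, dim=-1)
    # 1-D tensors, matching the top_k_ids.tolist() / top_k_probs[idx].item()
    # usage in build_def_tree above
    return top_k_ids[0], top_k_probs[0]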
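For context, generate_definition_tree swaps the given embedding into the model's embedding table at config.SUB_TOKEN_ID (via the update_token_embedding helper, also defined elsewhere in app.py) before building the tree. An illustrative call, with a hypothetical prompt and token id:

# Hypothetical usage; the actual base_prompt wording and config values
# live in app.py and are not part of this diff.
embedding = model.get_input_embeddings().weight[token_id_of_interest].detach()
tree = generate_definition_tree(
    base_prompt='A typical definition of "<SUB>" is',
    embedding=embedding,
    model=model,
    tokenizer=tokenizer,
    config=config,
    progress_callback=print,
)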
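The final hunk ends mid-function. A plausible completion of find_max_min_cumulative_weight, assuming it recursively tracks the largest and smallest positive cumulative_prob in the tree:

def find_max_min_cumulative_weight(node, current_max=0, current_min=float('inf')):
    # Plausible completion; the hunk above is truncated after the if-line.
    current_max = max(current_max, node.get('cumulative_prob', 0))
    if node.get('cumulative_prob', 1) > 0:
        current_min = min(current_min, node.get('cumulative_prob', 1))
    for child in node.get('children', []):
        current_max, current_min = find_max_min_cumulative_weight(child, current_max, current_min)
    return current_max, current_min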