Nu Appleblossom committed
Commit c9aa04e
1 Parent(s): 95a8028

next attempt at tree functionality 6

Files changed (1)
  1. app.py +20 -6
app.py CHANGED
@@ -194,12 +194,26 @@ def produce_next_token_ids(input_ids, model, topk, sub_token_id):
     with torch.no_grad():
         outputs = model(input_ids)
     logits = outputs.logits
+
+    if logits.size(1) == 0:  # Check if there are logits to process
+        logger.error("Logits are empty. Cannot produce next token IDs.")
+        return None, None
+
     last_logits = logits[:, -1, :]
+
+    if last_logits.size(0) == 0 or last_logits.size(1) == 0:  # Check if last logits are valid
+        logger.error("Last logits are empty. Cannot produce next token IDs.")
+        return None, None
+
     last_logits[:, sub_token_id] = float('-inf')
     softmax_probs = torch.softmax(last_logits, dim=-1)
     top_k_probs, top_k_ids = torch.topk(softmax_probs, k=topk, dim=-1)
-    return top_k_ids[0], top_k_probs[0]

+    if top_k_ids.size(0) == 0 or top_k_probs.size(0) == 0:  # Check if we successfully got top-k IDs and probabilities
+        logger.error("Top-k IDs or probabilities are empty. Cannot produce next token IDs.")
+        return None, None
+
+    return top_k_ids[0], top_k_probs[0]

 def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth=0, max_depth=25, cumulative_prob=1.0, progress_callback=None):
     if depth >= max_depth or cumulative_prob < config.CUTOFF:
@@ -209,10 +223,10 @@ def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth
     if progress_callback:
         progress_callback(f"Depth {depth}: {current_prompt} PROB: {cumulative_prob}\n")

-    try:
-        top_k_ids, top_k_probs = produce_next_token_ids(input_ids, model, config.TOPK, config.SUB_TOKEN_ID)
-    except Exception as e:
-        logger.error(f"Error generating next token IDs at depth {depth}: {str(e)}")
+    top_k_ids, top_k_probs = produce_next_token_ids(input_ids, model, config.TOPK, config.SUB_TOKEN_ID)
+
+    if top_k_ids is None or top_k_probs is None:  # Ensure that top_k_ids and top_k_probs are valid before proceeding
+        logger.error(f"Failed to generate next token IDs at depth {depth}.")
         return

     for idx, token_id in enumerate(top_k_ids.tolist()):
@@ -247,7 +261,6 @@ def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth

         build_def_tree(new_input_ids, new_child, base_prompt, model, tokenizer, config, depth=depth+1, max_depth=max_depth, cumulative_prob=new_cumulative_prob, progress_callback=progress_callback)

-
 def generate_definition_tree(base_prompt, embedding, model, tokenizer, config, progress_callback=None):
     results_dict = {"token": "", "cumulative_prob": 1, "children": []}

@@ -267,6 +280,7 @@ def generate_definition_tree(base_prompt, embedding, model, tokenizer, config, p



+
 def find_max_min_cumulative_weight(node, current_max=0, current_min=float('inf')):
     current_max = max(current_max, node.get('cumulative_prob', 0))
     if node.get('cumulative_prob', 1) > 0:
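
Note on the pattern: this commit drops the try/except around produce_next_token_ids in favor of sentinel returns. The producer returns (None, None) whenever an intermediate tensor is empty, and build_def_tree prunes that branch after an explicit None check instead of catching an exception. Below is a minimal, self-contained sketch of the same pattern; the model's forward pass is stubbed with a random logits tensor, and produce_next_token_ids_sketch plus the fake shapes are illustrative, not part of app.py.

    import logging

    import torch

    logging.basicConfig(level=logging.ERROR)
    logger = logging.getLogger(__name__)

    def produce_next_token_ids_sketch(logits, topk, sub_token_id):
        # Guard style from the commit: return (None, None) on bad input
        # instead of raising, so a recursive caller can bail out cleanly.
        if logits.size(1) == 0:
            logger.error("Logits are empty. Cannot produce next token IDs.")
            return None, None
        last_logits = logits[:, -1, :]                # logits for the final position
        last_logits[:, sub_token_id] = float('-inf')  # ban the substitute token
        softmax_probs = torch.softmax(last_logits, dim=-1)
        top_k_probs, top_k_ids = torch.topk(softmax_probs, k=topk, dim=-1)
        return top_k_ids[0], top_k_probs[0]

    # Caller-side None check, as build_def_tree now does:
    fake_logits = torch.randn(1, 4, 32)  # (batch, seq_len, vocab); stands in for model(input_ids).logits
    ids, probs = produce_next_token_ids_sketch(fake_logits, topk=5, sub_token_id=0)
    if ids is None or probs is None:
        logger.error("Failed to generate next token IDs.")
    else:
        print(list(zip(ids.tolist(), probs.tolist())))

Compared with the old broad `except Exception`, the new guards only cover the empty-tensor cases; any other error raised inside the model call will now propagate instead of silently pruning the branch.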