Nu Appleblossom committed on
Commit
97dc5e7
1 Parent(s): c9aa04e

next attempt at tree functionality 7

Browse files
Files changed (1) hide show
  1. app.py +22 -19
app.py CHANGED
@@ -195,22 +195,31 @@ def produce_next_token_ids(input_ids, model, topk, sub_token_id):
195
  outputs = model(input_ids)
196
  logits = outputs.logits
197
 
198
- if logits.size(1) == 0: # Check if there are logits to process
199
- logger.error("Logits are empty. Cannot produce next token IDs.")
 
200
  return None, None
201
 
202
  last_logits = logits[:, -1, :]
203
 
204
- if last_logits.size(0) == 0 or last_logits.size(1) == 0: # Check if last logits are valid
205
- logger.error("Last logits are empty. Cannot produce next token IDs.")
 
206
  return None, None
207
 
208
  last_logits[:, sub_token_id] = float('-inf')
209
  softmax_probs = torch.softmax(last_logits, dim=-1)
 
 
 
 
 
 
210
  top_k_probs, top_k_ids = torch.topk(softmax_probs, k=topk, dim=-1)
211
 
212
- if top_k_ids.size(0) == 0 or top_k_probs.size(0) == 0: # Check if we successfully got top-k IDs and probabilities
213
- logger.error("Top-k IDs or probabilities are empty. Cannot produce next token IDs.")
 
214
  return None, None
215
 
216
  return top_k_ids[0], top_k_probs[0]
@@ -225,7 +234,7 @@ def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth
225
 
226
  top_k_ids, top_k_probs = produce_next_token_ids(input_ids, model, config.TOPK, config.SUB_TOKEN_ID)
227
 
228
- if top_k_ids is None or top_k_probs is None: # Ensure that top_k_ids and top_k_probs are valid before proceeding
229
  logger.error(f"Failed to generate next token IDs at depth {depth}.")
230
  return
231
 
@@ -233,22 +242,18 @@ def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth
233
  if token_id == config.SUB_TOKEN_ID:
234
  continue # Skip the substitute token to avoid circular definitions
235
 
236
- try:
237
- token_id_tensor = torch.tensor([token_id], dtype=torch.long).to(model.device)
238
- new_input_ids = torch.cat([input_ids, token_id_tensor.view(1, 1)], dim=-1)
239
- except IndexError as e:
240
- logger.error(f"IndexError in processing token ID {token_id} at depth {depth}: {str(e)}")
241
- continue
242
 
243
  new_cumulative_prob = cumulative_prob * top_k_probs[idx].item()
244
 
245
  if new_cumulative_prob < config.CUTOFF:
246
  continue
247
 
248
- try:
249
- token_str = tokenizer.decode([token_id], skip_special_tokens=True)
250
- except Exception as e:
251
- logger.error(f"Error decoding token ID {token_id} at depth {depth}: {str(e)}")
252
  continue
253
 
254
  new_child = {
@@ -279,8 +284,6 @@ def generate_definition_tree(base_prompt, embedding, model, tokenizer, config, p
279
 
280
 
281
 
282
-
283
-
284
  def find_max_min_cumulative_weight(node, current_max=0, current_min=float('inf')):
285
  current_max = max(current_max, node.get('cumulative_prob', 0))
286
  if node.get('cumulative_prob', 1) > 0:
 
def produce_next_token_ids(input_ids, model, topk, sub_token_id):
    """Run the model one step and return the top-k next-token candidates.

    Parameters
    ----------
    input_ids : torch.LongTensor
        Token ids of shape ``(1, seq_len)`` fed directly to ``model``.
    model : callable
        Model whose output exposes a ``.logits`` tensor of shape
        ``(batch, seq_len, vocab_size)``.
    topk : int
        Number of candidates requested (clamped to the vocabulary size so
        ``torch.topk`` never raises).
    sub_token_id : int
        Token id masked to ``-inf`` before softmax so the substitute token
        can never be proposed as part of its own definition.

    Returns
    -------
    tuple
        ``(top_k_ids, top_k_probs)`` as 1-D tensors of length ``k``, or
        ``(None, None)`` when the logits are empty or invalid.
    """
    outputs = model(input_ids)
    logits = outputs.logits

    # Guard: model produced no sequence positions to read logits from.
    if logits is None or logits.size(1) == 0:
        logger.error("Logits are empty or invalid.")
        return None, None

    # Clone so masking below does not write through the view into the
    # caller-visible outputs.logits tensor (the original mutated it in place).
    last_logits = logits[:, -1, :].clone()

    if last_logits.size(0) == 0 or last_logits.size(1) == 0:
        logger.error("Last logits are empty or invalid.")
        return None, None

    # Mask the substitute token to avoid circular definitions.
    last_logits[:, sub_token_id] = float('-inf')
    softmax_probs = torch.softmax(last_logits, dim=-1)

    if softmax_probs is None or softmax_probs.size(0) == 0:
        logger.error("Softmax probabilities are empty or invalid.")
        return None, None

    # Clamp k: torch.topk raises if asked for more elements than exist.
    k = min(topk, softmax_probs.size(-1))
    top_k_probs, top_k_ids = torch.topk(softmax_probs, k=k, dim=-1)

    if top_k_ids.size(0) == 0 or top_k_probs.size(0) == 0:
        logger.error("Top-k IDs or probabilities are empty or invalid.")
        return None, None

    return top_k_ids[0], top_k_probs[0]
 
234
 
235
  top_k_ids, top_k_probs = produce_next_token_ids(input_ids, model, config.TOPK, config.SUB_TOKEN_ID)
236
 
237
+ if top_k_ids is None or top_k_probs is None:
238
  logger.error(f"Failed to generate next token IDs at depth {depth}.")
239
  return
240
 
 
242
  if token_id == config.SUB_TOKEN_ID:
243
  continue # Skip the substitute token to avoid circular definitions
244
 
245
+ token_id_tensor = torch.tensor([token_id], dtype=torch.long).to(model.device)
246
+ new_input_ids = torch.cat([input_ids, token_id_tensor.view(1, 1)], dim=-1)
 
 
 
 
247
 
248
  new_cumulative_prob = cumulative_prob * top_k_probs[idx].item()
249
 
250
  if new_cumulative_prob < config.CUTOFF:
251
  continue
252
 
253
+ token_str = tokenizer.decode([token_id], skip_special_tokens=True)
254
+
255
+ if token_str is None or token_str == "":
256
+ logger.error(f"Token string is empty or invalid at depth {depth} for token ID {token_id}.")
257
  continue
258
 
259
  new_child = {
 
284
 
285
 
286
 
 
 
287
  def find_max_min_cumulative_weight(node, current_max=0, current_min=float('inf')):
288
  current_max = max(current_max, node.get('cumulative_prob', 0))
289
  if node.get('cumulative_prob', 1) > 0: