Spaces:
Sleeping
Sleeping
Nu Appleblossom
commited on
Commit
•
c9aa04e
1
Parent(s):
95a8028
next attempt at tree functionality 6
Browse files
app.py
CHANGED
@@ -194,12 +194,26 @@ def produce_next_token_ids(input_ids, model, topk, sub_token_id):
|
|
194 |
with torch.no_grad():
|
195 |
outputs = model(input_ids)
|
196 |
logits = outputs.logits
|
|
|
|
|
|
|
|
|
|
|
197 |
last_logits = logits[:, -1, :]
|
|
|
|
|
|
|
|
|
|
|
198 |
last_logits[:, sub_token_id] = float('-inf')
|
199 |
softmax_probs = torch.softmax(last_logits, dim=-1)
|
200 |
top_k_probs, top_k_ids = torch.topk(softmax_probs, k=topk, dim=-1)
|
201 |
-
return top_k_ids[0], top_k_probs[0]
|
202 |
|
|
|
|
|
|
|
|
|
|
|
203 |
|
204 |
def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth=0, max_depth=25, cumulative_prob=1.0, progress_callback=None):
|
205 |
if depth >= max_depth or cumulative_prob < config.CUTOFF:
|
@@ -209,10 +223,10 @@ def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth
|
|
209 |
if progress_callback:
|
210 |
progress_callback(f"Depth {depth}: {current_prompt} PROB: {cumulative_prob}\n")
|
211 |
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
logger.error(f"
|
216 |
return
|
217 |
|
218 |
for idx, token_id in enumerate(top_k_ids.tolist()):
|
@@ -247,7 +261,6 @@ def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth
|
|
247 |
|
248 |
build_def_tree(new_input_ids, new_child, base_prompt, model, tokenizer, config, depth=depth+1, max_depth=max_depth, cumulative_prob=new_cumulative_prob, progress_callback=progress_callback)
|
249 |
|
250 |
-
|
251 |
def generate_definition_tree(base_prompt, embedding, model, tokenizer, config, progress_callback=None):
|
252 |
results_dict = {"token": "", "cumulative_prob": 1, "children": []}
|
253 |
|
@@ -267,6 +280,7 @@ def generate_definition_tree(base_prompt, embedding, model, tokenizer, config, p
|
|
267 |
|
268 |
|
269 |
|
|
|
270 |
def find_max_min_cumulative_weight(node, current_max=0, current_min=float('inf')):
|
271 |
current_max = max(current_max, node.get('cumulative_prob', 0))
|
272 |
if node.get('cumulative_prob', 1) > 0:
|
|
|
194 |
with torch.no_grad():
|
195 |
outputs = model(input_ids)
|
196 |
logits = outputs.logits
|
197 |
+
|
198 |
+
if logits.size(1) == 0: # Check if there are logits to process
|
199 |
+
logger.error("Logits are empty. Cannot produce next token IDs.")
|
200 |
+
return None, None
|
201 |
+
|
202 |
last_logits = logits[:, -1, :]
|
203 |
+
|
204 |
+
if last_logits.size(0) == 0 or last_logits.size(1) == 0: # Check if last logits are valid
|
205 |
+
logger.error("Last logits are empty. Cannot produce next token IDs.")
|
206 |
+
return None, None
|
207 |
+
|
208 |
last_logits[:, sub_token_id] = float('-inf')
|
209 |
softmax_probs = torch.softmax(last_logits, dim=-1)
|
210 |
top_k_probs, top_k_ids = torch.topk(softmax_probs, k=topk, dim=-1)
|
|
|
211 |
|
212 |
+
if top_k_ids.size(0) == 0 or top_k_probs.size(0) == 0: # Check if we successfully got top-k IDs and probabilities
|
213 |
+
logger.error("Top-k IDs or probabilities are empty. Cannot produce next token IDs.")
|
214 |
+
return None, None
|
215 |
+
|
216 |
+
return top_k_ids[0], top_k_probs[0]
|
217 |
|
218 |
def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, depth=0, max_depth=25, cumulative_prob=1.0, progress_callback=None):
|
219 |
if depth >= max_depth or cumulative_prob < config.CUTOFF:
|
|
|
223 |
if progress_callback:
|
224 |
progress_callback(f"Depth {depth}: {current_prompt} PROB: {cumulative_prob}\n")
|
225 |
|
226 |
+
top_k_ids, top_k_probs = produce_next_token_ids(input_ids, model, config.TOPK, config.SUB_TOKEN_ID)
|
227 |
+
|
228 |
+
if top_k_ids is None or top_k_probs is None: # Ensure that top_k_ids and top_k_probs are valid before proceeding
|
229 |
+
logger.error(f"Failed to generate next token IDs at depth {depth}.")
|
230 |
return
|
231 |
|
232 |
for idx, token_id in enumerate(top_k_ids.tolist()):
|
|
|
261 |
|
262 |
build_def_tree(new_input_ids, new_child, base_prompt, model, tokenizer, config, depth=depth+1, max_depth=max_depth, cumulative_prob=new_cumulative_prob, progress_callback=progress_callback)
|
263 |
|
|
|
264 |
def generate_definition_tree(base_prompt, embedding, model, tokenizer, config, progress_callback=None):
|
265 |
results_dict = {"token": "", "cumulative_prob": 1, "children": []}
|
266 |
|
|
|
280 |
|
281 |
|
282 |
|
283 |
+
|
284 |
def find_max_min_cumulative_weight(node, current_max=0, current_min=float('inf')):
|
285 |
current_max = max(current_max, node.get('cumulative_prob', 0))
|
286 |
if node.get('cumulative_prob', 1) > 0:
|