import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Tuple


def cot_decode_speculative(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    messages: List[Dict[str, str]],
    k: int = 3,
    max_new_tokens: int = 512,
) -> Tuple[str, float]:
    """
    Generate text using a speculative, Chain-of-Thought (CoT) style decoding strategy with
    confidence- and entropy-based metrics.

    The function considers the top-k candidate first tokens (further restricted by a top-p
    cutoff), extends each candidate by a single step, and commits to the candidate whose
    attention pattern has the lowest combined entropy/varentropy score. The full answer is
    then generated from that token, and a path score is computed from the average token
    confidence, the entropy score, and the generation length.

    Args:
        model: The pre-trained causal language model.
        tokenizer: The corresponding tokenizer.
        messages: A list of dictionaries, each with "role" and "content" keys.
        k: The number of top candidate first tokens to consider.
        max_new_tokens: The maximum number of tokens to generate.

    Returns:
        A tuple containing the generated text and the calculated path score.
    """

    # Build the prompt either from the tokenizer's chat template or from a simple
    # "role: content" fallback format.
    if hasattr(tokenizer, 'chat_template') and tokenizer.chat_template is not None:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        input_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
        input_text += "\nassistant:"

    input_ids = tokenizer.encode(input_text, return_tensors="pt").to("cuda")

    # Some tokenizers ship without a pad token; fall back to the EOS token so the
    # attention mask can be built below.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    attention_mask = (input_ids != tokenizer.pad_token_id).long().to("cuda")

    # One forward pass over the prompt to obtain the next-token distribution (and a KV cache).
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
        past_key_values = outputs.past_key_values
        first_token_probs = F.softmax(outputs.logits[0, -1, :], dim=-1)
        top_k_probs, top_k_indices = torch.topk(first_token_probs, k)
        # Nucleus (top-p) filtering of the k candidates: keep the smallest prefix whose
        # cumulative probability stays within 0.9, but always keep at least one token.
        cumulative_probs = torch.cumsum(top_k_probs, dim=0)
        top_p_mask = cumulative_probs <= 0.9
        if not torch.any(top_p_mask):
            top_p_mask[0] = True
        top_k_probs, top_k_indices = top_k_probs[top_p_mask], top_k_indices[top_p_mask]

    # past_key_values is reset to None here, so the prompt cache computed above is not
    # reused when the candidates are evaluated below.
    min_diff, best_idx, past_key_values = float('inf'), None, None
    new_attention_mask = torch.cat([attention_mask, torch.ones(1, 1).long().to("cuda")], dim=-1)

    # Evaluate each candidate first token: append it to the prompt, generate a single step,
    # and score the attention pattern of that step. The candidate with the lowest
    # entropy-based score is kept.
    for idx in top_k_indices:
        new_token = idx.unsqueeze(0).unsqueeze(0)
        new_tokens = torch.cat([input_ids, new_token], dim=-1)
        with torch.no_grad():
            output = model.generate(
                new_tokens,
                attention_mask=new_attention_mask,
                max_new_tokens=1,
                output_scores=True,
                output_attentions=True,
                return_dict_in_generate=True,
                past_key_values=past_key_values
            )
        # Last-layer attention weights from the generated step, re-normalized with a softmax;
        # their entropy (log base 2, with an epsilon guarding against log(0)) and its variance
        # ("varentropy") are blended into a single selection score, where lower is better.
        all_attentions = output.attentions[0][-1]
        attn_probs = F.softmax(all_attentions[:, -1, :], dim=-1)
        entropy = -torch.sum(attn_probs * torch.log2(attn_probs + 1e-12), dim=-1)
        avg_entropy, avg_varentropy = torch.mean(entropy), torch.var(entropy)
        diff = avg_entropy * 0.8 + avg_varentropy * 0.2

        if diff < min_diff:
            min_diff, best_idx = diff, idx

    # Commit to the best-scoring first token and generate the full answer from it.
    new_token = best_idx.unsqueeze(0).unsqueeze(0)
    new_tokens = torch.cat([input_ids, new_token], dim=-1)

    output = model.generate(
        new_tokens,
        attention_mask=new_attention_mask,
        max_new_tokens=max_new_tokens,
        output_scores=True,
        return_dict_in_generate=True,
        past_key_values=past_key_values
    )

    # Decode only the newly generated portion (the chosen first token plus everything after it).
    answer_ids = output.sequences[0][len(input_ids[0]):]
    answer_text = tokenizer.decode(answer_ids, skip_special_tokens=True)

    # Per-step confidence: the probability gap between the top-1 and top-2 tokens.
    # output.scores only covers tokens produced by this generate() call, so the injected
    # first token contributes no score but still counts toward the average's denominator.
    sum_confidence = 0
    for step in range(len(output.scores)):
        logits_step = output.scores[step][0]
        probs_step = F.softmax(logits_step, dim=-1)
        top_probs, _ = torch.topk(probs_step, k=2, dim=-1)
        confidence = top_probs[0] - top_probs[1]
        sum_confidence += confidence
    avg_confidence = sum_confidence / len(answer_ids)
    # Heuristically damp very high average confidence.
    avg_confidence = avg_confidence - 0.2 if avg_confidence >= 0.9 else avg_confidence

    # Path score: higher average confidence and longer answers increase it, while a larger
    # entropy score (min_diff) shrinks it; the 0.33 exponent compresses the final value.
    path_score = avg_confidence ** min_diff * (len(answer_ids) / max_new_tokens)
    return answer_text, round(path_score.item() ** 0.33, 4)
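
# Minimal usage sketch. The model name below is only an illustrative assumption: any
# chat-tuned causal LM should work, and a CUDA device is required because the function
# moves its tensors to "cuda".
if __name__ == "__main__":
    model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative; swap in your own model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")
    model.eval()

    messages = [{"role": "user", "content": "If a train travels 60 km in 45 minutes, what is its average speed in km/h?"}]
    answer, score = cot_decode_speculative(model, tokenizer, messages, k=3, max_new_tokens=256)
    print(f"Path score: {score}\n{answer}")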