import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Tuple


def _build_input_ids(
    tokenizer: AutoTokenizer,
    messages: List[Dict[str, str]],
    device: torch.device,
) -> torch.Tensor:
    """Render ``messages`` into a (1, seq_len) tensor of prompt token ids on ``device``."""
    # ``hasattr`` is not enough here: recent tokenizers always define the
    # attribute and leave it as None when no template is configured.
    if getattr(tokenizer, "chat_template", None):
        input_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The rendered template already contains its special tokens (e.g. BOS);
        # adding them again while encoding would duplicate them.
        add_special_tokens = False
    else:
        # Plain "role: content" transcript for tokenizers without a chat template.
        input_text = "\n".join(f"{msg['role']}: {msg['content']}" for msg in messages)
        input_text += "\nassistant:"
        add_special_tokens = True
    input_ids = tokenizer.encode(
        input_text, add_special_tokens=add_special_tokens, return_tensors="pt"
    )
    return input_ids.to(device)


def _candidate_first_tokens(probs: torch.Tensor, k: int, top_p: float) -> torch.Tensor:
    """Return the top-``k`` token ids of ``probs`` that lie inside the ``top_p`` nucleus."""
    top_probs, top_indices = torch.topk(probs, min(k, probs.numel()))
    keep = torch.cumsum(top_probs, dim=0) <= top_p
    # The cumulative sum is monotone, so forcing the first entry on only matters
    # when the single most likely token already exceeds top_p; it is always kept.
    keep[0] = True
    return top_indices[keep]


def _attention_entropy_score(
    model: AutoModelForCausalLM,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
) -> float:
    """Score the last layer's attention for the final position; lower means more focused."""
    with torch.no_grad():
        outputs = model(
            input_ids,
            attention_mask=attention_mask,
            output_attentions=True,
            use_cache=False,
        )
    if outputs.attentions is None:
        raise RuntimeError(
            "Model did not return attention weights; load it with "
            "attn_implementation='eager' to use cot_decode_speculative."
        )
    # Per HF docs each layer's entry is (batch, num_heads, seq_len, seq_len) and
    # holds post-softmax weights, so no further normalisation is applied.
    # Select: last layer, the only batch item, every head, query = final position.
    attn_probs = outputs.attentions[-1][0, :, -1, :].float()
    entropy = -torch.sum(attn_probs * torch.log2(attn_probs + 1e-12), dim=-1)  # bits, one per head
    return (0.8 * entropy.mean() + 0.2 * entropy.var()).item()


def _average_top2_margin(scores: Tuple[torch.Tensor, ...]) -> float:
    """Mean gap between the two most probable tokens across all generation steps."""
    if not scores:
        return 0.0
    margins = []
    for step_logits in scores:
        top2 = torch.topk(F.softmax(step_logits[0].float(), dim=-1), k=2).values
        margins.append(top2[0] - top2[1])
    return torch.stack(margins).mean().item()


def cot_decode_speculative(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    messages: List[Dict[str, str]],
    k: int = 3,
    max_new_tokens: int = 512,
    top_p: float = 0.9,
) -> Tuple[str, float]:
    """Decode one answer whose first token is chosen by last-layer attention entropy.

    Steps:

    1. Candidate selection. The prompt's next-token distribution is narrowed
       to the top-``k`` tokens that fit inside the ``top_p`` nucleus. The
       single most likely token is always kept.
    2. Candidate scoring. For each candidate, the prompt plus that token is
       run through the model once. The final position's last-layer attention
       is scored as ``0.8 * mean + 0.2 * variance`` of the per-head entropies
       (in bits).
    3. Generation. The lowest-scoring candidate is appended to the prompt and
       the rest of the answer is produced with ``model.generate``, using the
       model's own generation settings.

    Inputs are placed on ``model.device``.

    Args:
        model: The causal language model.
        tokenizer: The matching tokenizer. If its ``pad_token_id`` is unset it
            is set to ``eos_token_id`` (the tokenizer is mutated).
        messages: Chat messages, each a dict with "role" and "content" keys.
        k: Maximum number of first-token candidates to score.
        max_new_tokens: Maximum tokens ``generate`` produces after the forced
            first token.
        top_p: Cumulative-probability cutoff applied to the top-k candidates.

    Returns:
        A tuple ``(answer_text, path_score)``.

        ``answer_text`` is the decoded answer, including the forced first
        token.

        ``path_score`` is ``(c ** s * n / max_new_tokens) ** 0.33`` rounded to
        4 decimals, where:

        - ``c`` is the mean gap between the two highest step probabilities
          over the generated steps, reduced by 0.2 when it is at least 0.9.
        - ``s`` is the winning attention score.
        - ``n`` is the number of answer tokens, including the forced one.

    Raises:
        ValueError: If ``k`` is smaller than 1.
        RuntimeError: If the model does not return attention weights.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    device = model.device
    input_ids = _build_input_ids(tokenizer, messages, device)

    # Kept for compatibility with callers that rely on pad being set afterwards.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    # A single unpadded sequence: every position is real. Comparing against the
    # pad id would wrongly mask genuine eos tokens when pad aliases eos.
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask, use_cache=False).logits
    first_token_probs = F.softmax(logits[0, -1, :], dim=-1)
    candidates = _candidate_first_tokens(first_token_probs, k, top_p)

    # Mask for prompt + one appended token.
    extended_mask = torch.cat([attention_mask, attention_mask.new_ones((1, 1))], dim=-1)

    best_token, min_diff = None, float("inf")
    for token_id in candidates:
        candidate_ids = torch.cat([input_ids, token_id.reshape(1, 1)], dim=-1)
        diff = _attention_entropy_score(model, candidate_ids, extended_mask)
        if diff < min_diff:
            best_token, min_diff = token_id, diff
    if best_token is None:
        # Every score was NaN: fall back to the most probable candidate and use
        # a neutral exponent so the path score reduces to confidence * length.
        best_token, min_diff = candidates[0], 1.0

    prompt_ids = torch.cat([input_ids, best_token.reshape(1, 1)], dim=-1)
    output = model.generate(
        prompt_ids,
        attention_mask=extended_mask,
        max_new_tokens=max_new_tokens,
        output_scores=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.pad_token_id,
    )

    answer_ids = output.sequences[0][input_ids.shape[-1]:]
    answer_text = tokenizer.decode(answer_ids, skip_special_tokens=True)

    # ``output.scores`` has one entry per generated step; the forced first token
    # has none, so the average is taken over the scored steps only.
    avg_confidence = _average_top2_margin(output.scores)
    if avg_confidence >= 0.9:
        avg_confidence -= 0.2
    path_score = avg_confidence ** min_diff * (len(answer_ids) / max_new_tokens)
    return answer_text, round(path_score ** 0.33, 4)