from collections import deque
from typing import Any, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.logits_process import (
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)
from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput


def generate(
    model: Any,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    generation_config: Optional[GenerationConfig] = None,
    synced_gpus: bool = False,
    streamer: Optional[Any] = None,
    **model_kwargs,
) -> Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, torch.LongTensor]:
    """Custom decoding with DeepCONF (confidence-based early stopping).

    Args:
        model: PreTrainedModel with a LM head.
        input_ids: Prompt ids of shape (batch, seq_len).
        logits_processor: Optional logits processors.
        stopping_criteria: Optional stopping criteria.
        generation_config: GenerationConfig controlling sampling/outputs.
        synced_gpus: Keep looping to max length for distributed setups.
        streamer: Optional streamer for incremental tokens.
        **model_kwargs: Forward pass kwargs (e.g., attention_mask).

    Returns:
        GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, or LongTensor
        depending on `return_dict_in_generate` and model type.
    """
    # Ensure processors/criteria are defined
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    if stopping_criteria is None:
        stopping_criteria = StoppingCriteriaList()

    # Get DeepCONF parameters from generation_config or set defaults
    enable_conf = getattr(generation_config, "enable_conf", False)
    window_size = getattr(generation_config, "window_size", 2048)
    threshold = getattr(generation_config, "threshold", 17.0)  # Default threshold for confidence (positive value)

    # If DeepCONF is not enabled, fall back to standard sampling
    if not enable_conf:
        return model._sample(
            input_ids,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **model_kwargs,
        )

    # Initialize values
    # Handle pad token properly (following HF best practices)
    pad_token_id = generation_config.pad_token_id
    if pad_token_id is None and hasattr(generation_config, "_pad_token_tensor"):
        pad_token_id = generation_config._pad_token_tensor
    if pad_token_id is None and hasattr(model.config, "pad_token_id"):
        pad_token_id = model.config.pad_token_id
    if pad_token_id is None and generation_config.eos_token_id is not None:
        # Use eos token as pad token if not set
        pad_token_id = generation_config.eos_token_id

    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    output_confidences = getattr(generation_config, "output_confidences", False)

    # Optional DeepConf variant helpers (compute threshold from warmup confidences)
    deepconf_variant = getattr(generation_config, "deepconf_variant", None)  # "low" or "high"
    deepconf_eta = getattr(generation_config, "deepconf_eta", None)  # float in (0,1)
    deepconf_warmup_confidences = getattr(generation_config, "deepconf_warmup_confidences", None)  # list/1D tensor

    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    do_sample = generation_config.do_sample

    # If a variant is requested and a warmup set of confidences is provided, derive the threshold
    if enable_conf and threshold is not None:
        pass
    elif enable_conf and deepconf_variant is not None and deepconf_warmup_confidences is not None:
        confs = deepconf_warmup_confidences
        if hasattr(confs, "detach"):
            confs = confs.detach().cpu().numpy()
        elif isinstance(confs, torch.Tensor):
            confs = confs.cpu().numpy()
        confs = np.asarray(confs, dtype=np.float32).ravel()
        eta = deepconf_eta
        if eta is None:
            eta = 0.1 if deepconf_variant == "low" else 0.9 if deepconf_variant == "high" else 0.5
        pct = max(0.0, min(100.0, 100.0 - (eta * 100.0)))
        threshold = float(np.percentile(confs, pct))

    # Initialize attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # If model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and model.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # Keep track of which sequences are already finished
    batch_size, cur_len = input_ids.shape[:2]
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)

    # Use public kv-cache via past_key_values
    # Initialize confidence tracking
    # Use deque for sliding window with fixed size
    conf_group_lists = [deque(maxlen=window_size) for _ in range(batch_size)]
    conf_grouped_sums = [0.0 for _ in range(batch_size)]  # Running sums for efficient mean calculation

    # Optional per-step confidences for debugging/visualization
    step_confidences = [] if (return_dict_in_generate and output_confidences) else None

    # Main generation loop using public controls
    steps = 0
    max_new_tokens = getattr(generation_config, "max_new_tokens", None) or 512

    # Initialize cache_position for first forward over the full prompt
    # Subsequent steps will pass a single position incrementally
    model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

    while steps < max_new_tokens and unfinished_sequences.max() != 0:
        # Prepare model inputs (proper KV cache handling)
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # Prepare variable output controls
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        # Forward pass with proper KV cache handling
        with torch.no_grad():
            outputs = model(**model_inputs, return_dict=True)

        next_token_logits = outputs.logits[:, -1, :].detach()

        # Update model kwargs for next iteration (public): carry past_key_values
        if hasattr(outputs, "past_key_values") and outputs.past_key_values is not None:
            model_kwargs["past_key_values"] = outputs.past_key_values

        # Pre-process distribution with logits processors
        next_token_scores = logits_processor(input_ids, next_token_logits)

        # Apply logits warpers (e.g., temperature, top-k, top-p) from generation_config
        warpers = LogitsProcessorList()
        # Temperature
        temperature = getattr(generation_config, "temperature", 1.0)
        if temperature is not None and temperature != 1.0:
            warpers.append(TemperatureLogitsWarper(temperature))
        # Top-k
        top_k = getattr(generation_config, "top_k", None)
        if top_k is not None and isinstance(top_k, int) and top_k > 0:
            warpers.append(TopKLogitsWarper(top_k))
        # Top-p
        top_p = getattr(generation_config, "top_p", None)
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p))
        if len(warpers) > 0:
            next_token_scores = warpers(input_ids, next_token_scores)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
                )
                if model.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,)
                )

        # Token selection
        if do_sample:
            probs = F.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # Calculate confidence using only top-k/top-p filtered candidates (post-logits processors),
        # excluding the sampled token.
        # We consider candidates where logits are finite after warpers (e.g., top-k/top-p/temperature).
        logprobs = F.log_softmax(next_token_scores, dim=-1)
        candidate_mask = torch.isfinite(next_token_scores)
        deepconf_stopping = torch.ones(batch_size, dtype=torch.bool, device=input_ids.device)
        step_conf_values = [0.0] * batch_size  # collect per-sequence confidences for this step (full batch)

        for i in range(batch_size):
            if not unfinished_sequences[i]:
                continue

            # Count valid candidates
            num_candidates = int(candidate_mask[i].sum().item())
            if num_candidates <= 1:
                conf = 0.0
            else:
                # Sum logprobs over valid candidates and exclude the sampled token's logprob
                total_lp = torch.sum(logprobs[i][candidate_mask[i]])
                selected_lp = (
                    logprobs[i, next_tokens[i]]
                    if candidate_mask[i, next_tokens[i]]
                    else torch.tensor(0.0, device=logprobs.device)
                )
                denom = num_candidates - 1
                # Negative mean of non-selected candidate logprobs
                conf = -((total_lp - selected_lp) / denom).item()

            # Update tracking structures
            if len(conf_group_lists[i]) >= window_size:
                conf_grouped_sums[i] -= conf_group_lists[i][0]
            conf_group_lists[i].append(conf)
            conf_grouped_sums[i] += conf

            # Apply confidence-based early stopping when window is full
            if len(conf_group_lists[i]) >= window_size:
                avg_conf = conf_grouped_sums[i] / len(conf_group_lists[i])
                if avg_conf < threshold:
                    deepconf_stopping[i] = False

            if step_confidences is not None:
                step_conf_values[i] = conf

        if step_confidences is not None:
            # Store this step's confidences as a tensor of shape (batch,)
            step_confidences.append(torch.tensor(step_conf_values, device=input_ids.device))

        # Finished sentences should have their next token be a padding token
        if has_eos_stopping_criteria and pad_token_id is not None:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # Update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        # Update attention mask if available
        if model_kwargs.get("attention_mask") is not None:
            attn = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attn, torch.ones((batch_size, 1), dtype=attn.dtype, device=attn.device)], dim=-1
            )

        # Update cache_position for next step (single next token)
        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1

        if streamer is not None:
            streamer.put(next_tokens.cpu())

        # Update unfinished sequences with standard stopping criteria (per-sequence if available)
        sc = stopping_criteria(input_ids, scores)
        if isinstance(sc, torch.Tensor):
            unfinished_sequences = unfinished_sequences & ~sc
        elif sc:  # global stop
            unfinished_sequences = torch.zeros_like(unfinished_sequences)

        # Apply DeepCONF stopping
        unfinished_sequences = unfinished_sequences & deepconf_stopping

        # Early break if all sequences finished and not synchronized
        if unfinished_sequences.max() == 0 and not synced_gpus:
            break

        cur_len += 1
        steps += 1

        # Clean up outputs to save memory
        del outputs

    if streamer is not None:
        streamer.end()

    # Return results
    if return_dict_in_generate:
        # Prepare confidences tensor if requested
        confidences_tensor = None
        if step_confidences is not None and len(step_confidences) > 0:
            # Shape: (steps, batch) -> (batch, steps)
            confidences_tensor = torch.stack(step_confidences, dim=0).transpose(0, 1)

        if model.config.is_encoder_decoder:
            output = GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
            if confidences_tensor is not None:
                output["confidences"] = confidences_tensor
                try:
                    setattr(output, "confidences", confidences_tensor)
                except Exception:
                    pass
            return output
        else:
            output = GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
            if confidences_tensor is not None:
                output["confidences"] = confidences_tensor
                try:
                    setattr(output, "confidences", confidences_tensor)
                except Exception:
                    pass
            return output
    else:
        return input_ids
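

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one way `generate` above might be called
# with DeepCONF enabled. The checkpoint name, the prompt, and the extra
# attributes attached to GenerationConfig (enable_conf, window_size, threshold,
# output_confidences) are assumptions made for this example; they are not
# standard GenerationConfig fields and are read via getattr() in generate().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "gpt2"  # placeholder checkpoint; any decoder-only LM should work
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    gen_config = GenerationConfig(
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        max_new_tokens=128,
        return_dict_in_generate=True,
        output_scores=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # DeepCONF-specific knobs are looked up with getattr(), so they can simply
    # be attached to the config object after construction.
    gen_config.enable_conf = True
    gen_config.window_size = 64
    gen_config.threshold = 17.0
    gen_config.output_confidences = True

    inputs = tokenizer("The capital of France is", return_tensors="pt")
    result = generate(
        model,
        inputs["input_ids"],
        generation_config=gen_config,
        attention_mask=inputs["attention_mask"],
    )
    print(tokenizer.decode(result.sequences[0], skip_special_tokens=True))
    if getattr(result, "confidences", None) is not None:
        print("Per-step confidences shape:", tuple(result.confidences.shape))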