DeepConf / custom_generate / generate.py

from collections import deque
from typing import Any, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.logits_process import (
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput


def generate(
model: Any,
input_ids: torch.LongTensor,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
generation_config: Optional[GenerationConfig] = None,
synced_gpus: bool = False,
streamer: Optional[Any] = None,
**model_kwargs,
) -> Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, torch.LongTensor]:
"""Custom decoding with DeepCONF (confidence-based early stopping).
Args:
model: PreTrainedModel with a LM head.
input_ids: Prompt ids of shape (batch, seq_len).
logits_processor: Optional logits processors.
stopping_criteria: Optional stopping criteria.
generation_config: GenerationConfig controlling sampling/outputs.
synced_gpus: Keep looping to max length for distributed setups.
streamer: Optional streamer for incremental tokens.
**model_kwargs: Forward pass kwargs (e.g., attention_mask).
Returns:
GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, or LongTensor
depending on `return_dict_in_generate` and model type.
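
    Example (illustrative sketch; the hub repo id, model id, and parameter values below
    are placeholders, not defined by this file):

        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
        >>> tok = AutoTokenizer.from_pretrained("gpt2")
        >>> lm = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> inputs = tok("Question: what is 12 * 7? Answer:", return_tensors="pt")
        >>> cfg = GenerationConfig(
        ...     do_sample=True, max_new_tokens=128, return_dict_in_generate=True,
        ...     enable_conf=True, window_size=64, threshold=17.0,
        ... )
        >>> out = lm.generate(
        ...     **inputs, generation_config=cfg,
        ...     custom_generate="<user>/DeepConf",  # hub repo that ships this generate.py
        ...     trust_remote_code=True,
        ... )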
"""
# Ensure processors/criteria are defined
if logits_processor is None:
logits_processor = LogitsProcessorList()
if stopping_criteria is None:
stopping_criteria = StoppingCriteriaList()
# Get DeepCONF parameters from generation_config or set defaults
enable_conf = getattr(generation_config, "enable_conf", False)
window_size = getattr(generation_config, "window_size", 2048)
threshold = getattr(generation_config, "threshold", 17.0) # Default threshold for confidence (positive value)
# If DeepCONF is not enabled, fall back to standard sampling
if not enable_conf:
return model._sample(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)
# Initialize values
# Handle pad token properly (following HF best practices)
pad_token_id = generation_config.pad_token_id
if pad_token_id is None and hasattr(generation_config, "_pad_token_tensor"):
pad_token_id = generation_config._pad_token_tensor
if pad_token_id is None and hasattr(model.config, "pad_token_id"):
pad_token_id = model.config.pad_token_id
if pad_token_id is None and generation_config.eos_token_id is not None:
# Use eos token as pad token if not set
pad_token_id = generation_config.eos_token_id
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
output_confidences = getattr(generation_config, "output_confidences", False)
# Optional DeepConf variant helpers (compute threshold from warmup confidences)
deepconf_variant = getattr(generation_config, "deepconf_variant", None) # "low" or "high"
deepconf_eta = getattr(generation_config, "deepconf_eta", None) # float in (0,1)
deepconf_warmup_confidences = getattr(generation_config, "deepconf_warmup_confidences", None) # list/1D tensor
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample
    # If `threshold` is explicitly set to None and a variant plus warmup confidences are
    # provided, derive the threshold from the warmup confidence distribution
    # (an explicit numeric `threshold` always takes precedence).
    if enable_conf and threshold is None and deepconf_variant is not None and deepconf_warmup_confidences is not None:
        confs = deepconf_warmup_confidences
        if hasattr(confs, "detach"):
            confs = confs.detach().cpu().numpy()
        elif isinstance(confs, torch.Tensor):
            confs = confs.cpu().numpy()
        confs = np.asarray(confs, dtype=np.float32).ravel()
        eta = deepconf_eta
        if eta is None:
            eta = 0.1 if deepconf_variant == "low" else 0.9 if deepconf_variant == "high" else 0.5
        pct = max(0.0, min(100.0, 100.0 - (eta * 100.0)))
        threshold = float(np.percentile(confs, pct))
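    # Illustrative mapping: deepconf_variant="low" defaults to eta=0.1, i.e. the 90th
    # percentile of the warmup confidences becomes the stopping threshold (aggressive);
    # "high" defaults to eta=0.9, i.e. the 10th percentile (conservative).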
# Initialize attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# If model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and model.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
# Keep track of which sequences are already finished
batch_size, cur_len = input_ids.shape[:2]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
# Use public kv-cache via past_key_values
# Initialize confidence tracking
# Use deque for sliding window with fixed size
conf_group_lists = [deque(maxlen=window_size) for _ in range(batch_size)]
conf_grouped_sums = [0.0 for _ in range(batch_size)] # Running sums for efficient mean calculation
# Optional per-step confidences for debugging/visualization
step_confidences = [] if (return_dict_in_generate and output_confidences) else None
# Main generation loop using public controls
steps = 0
max_new_tokens = getattr(generation_config, "max_new_tokens", None) or 512
# Initialize cache_position for first forward over the full prompt
# Subsequent steps will pass a single position incrementally
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
while steps < max_new_tokens and unfinished_sequences.max() != 0:
# Prepare model inputs (proper KV cache handling)
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
# Prepare variable output controls
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
# Forward pass with proper KV cache handling
with torch.no_grad():
outputs = model(**model_inputs, return_dict=True)
next_token_logits = outputs.logits[:, -1, :].detach()
# Update model kwargs for next iteration (public): carry past_key_values
if hasattr(outputs, "past_key_values") and outputs.past_key_values is not None:
model_kwargs["past_key_values"] = outputs.past_key_values
# Pre-process distribution with logits processors
next_token_scores = logits_processor(input_ids, next_token_logits)
# Apply logits warpers (e.g., temperature, top-k, top-p) from generation_config
warpers = LogitsProcessorList()
# Temperature
temperature = getattr(generation_config, "temperature", 1.0)
if temperature is not None and temperature != 1.0:
warpers.append(TemperatureLogitsWarper(temperature))
# Top-k
top_k = getattr(generation_config, "top_k", None)
if top_k is not None and isinstance(top_k, int) and top_k > 0:
warpers.append(TopKLogitsWarper(top_k))
# Top-p
top_p = getattr(generation_config, "top_p", None)
if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p))
if len(warpers) > 0:
next_token_scores = warpers(input_ids, next_token_scores)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_logits:
raw_logits += (next_token_logits,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
)
if model.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,)
)
# Token selection
if do_sample:
probs = F.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
# Calculate confidence using only top-k/top-p filtered candidates (post-logits processors),
# excluding the sampled token.
# We consider candidates where logits are finite after warpers (e.g., top-k/top-p/temperature).
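    # i.e., conf = -(mean log-probability of the non-sampled candidates that survive the
    # warpers); peaked next-token distributions give large values, flat ones give small values.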
logprobs = F.log_softmax(next_token_scores, dim=-1)
candidate_mask = torch.isfinite(next_token_scores)
deepconf_stopping = torch.ones(batch_size, dtype=torch.bool, device=input_ids.device)
step_conf_values = [0.0] * batch_size # collect per-sequence confidences for this step (full batch)
for i in range(batch_size):
if not unfinished_sequences[i]:
continue
# Count valid candidates
num_candidates = int(candidate_mask[i].sum().item())
if num_candidates <= 1:
conf = 0.0
else:
# Sum logprobs over valid candidates and exclude the sampled token's logprob
total_lp = torch.sum(logprobs[i][candidate_mask[i]])
selected_lp = (
logprobs[i, next_tokens[i]]
if candidate_mask[i, next_tokens[i]]
else torch.tensor(0.0, device=logprobs.device)
)
denom = num_candidates - 1
# Negative mean of non-selected candidate logprobs
conf = -((total_lp - selected_lp) / denom).item()
# Update tracking structures
if len(conf_group_lists[i]) >= window_size:
conf_grouped_sums[i] -= conf_group_lists[i][0]
conf_group_lists[i].append(conf)
conf_grouped_sums[i] += conf
# Apply confidence-based early stopping when window is full
if len(conf_group_lists[i]) >= window_size:
avg_conf = conf_grouped_sums[i] / len(conf_group_lists[i])
if avg_conf < threshold:
deepconf_stopping[i] = False
if step_confidences is not None:
step_conf_values[i] = conf
if step_confidences is not None:
# Store this step's confidences as a tensor of shape (batch,)
step_confidences.append(torch.tensor(step_conf_values, device=input_ids.device))
# Finished sentences should have their next token be a padding token
if has_eos_stopping_criteria and pad_token_id is not None:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# Update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
# Update attention mask if available
if model_kwargs.get("attention_mask") is not None:
attn = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attn, torch.ones((batch_size, 1), dtype=attn.dtype, device=attn.device)], dim=-1
)
# Update cache_position for next step (single next token)
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
if streamer is not None:
streamer.put(next_tokens.cpu())
# Update unfinished sequences with standard stopping criteria (per-sequence if available)
sc = stopping_criteria(input_ids, scores)
if isinstance(sc, torch.Tensor):
unfinished_sequences = unfinished_sequences & ~sc
elif sc:
# global stop
unfinished_sequences = torch.zeros_like(unfinished_sequences)
# Apply DeepCONF stopping
unfinished_sequences = unfinished_sequences & deepconf_stopping
# Early break if all sequences finished and not synchronized
if unfinished_sequences.max() == 0 and not synced_gpus:
break
cur_len += 1
steps += 1
# Clean up outputs to save memory
del outputs
if streamer is not None:
streamer.end()
# Return results
if return_dict_in_generate:
# Prepare confidences tensor if requested
confidences_tensor = None
if step_confidences is not None and len(step_confidences) > 0:
# Shape: (steps, batch) -> (batch, steps)
confidences_tensor = torch.stack(step_confidences, dim=0).transpose(0, 1)
if model.config.is_encoder_decoder:
output = GenerateEncoderDecoderOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
if confidences_tensor is not None:
output["confidences"] = confidences_tensor
try:
setattr(output, "confidences", confidences_tensor)
except Exception:
pass
return output
else:
output = GenerateDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
if confidences_tensor is not None:
output["confidences"] = confidences_tensor
try:
setattr(output, "confidences", confidences_tensor)
except Exception:
pass
return output
else:
return input_ids
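

if __name__ == "__main__":
    # Minimal, illustrative smoke test for calling this decoding loop directly.
    # Assumptions (not part of this file): the "gpt2" checkpoint can be downloaded, and the
    # DeepConf parameter values below are arbitrary examples, not recommended settings.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2")
    enc = tokenizer("The capital of France is", return_tensors="pt")
    cfg = GenerationConfig(
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        max_new_tokens=64,
        return_dict_in_generate=True,
        enable_conf=True,         # turn on DeepConf confidence-based early stopping
        window_size=32,           # sliding-window length for the confidence average
        threshold=10.0,           # stop a sequence when its windowed confidence falls below this
        output_confidences=True,  # also return the per-step confidences
    )
    result = generate(lm, enc["input_ids"], generation_config=cfg, attention_mask=enc["attention_mask"])
    print(tokenizer.decode(result.sequences[0], skip_special_tokens=True))
    print("per-step confidences shape:", tuple(result["confidences"].shape))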