from collections import deque
from typing import Any, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F

from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.logits_process import (
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)
from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput


def generate(
    model: Any,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    generation_config: Optional[GenerationConfig] = None,
    synced_gpus: bool = False,
    streamer: Optional[Any] = None,
    **model_kwargs,
) -> Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, torch.LongTensor]:
"""Custom decoding with DeepCONF (confidence-based early stopping). |
|
|
|
Args: |
|
model: PreTrainedModel with a LM head. |
|
input_ids: Prompt ids of shape (batch, seq_len). |
|
logits_processor: Optional logits processors. |
|
stopping_criteria: Optional stopping criteria. |
|
generation_config: GenerationConfig controlling sampling/outputs. |
|
synced_gpus: Keep looping to max length for distributed setups. |
|
streamer: Optional streamer for incremental tokens. |
|
**model_kwargs: Forward pass kwargs (e.g., attention_mask). |
|
|
|
Returns: |
|
GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, or LongTensor |
|
depending on `return_dict_in_generate` and model type. |
|
""" |
|
|
|
|
|
if logits_processor is None: |
|
logits_processor = LogitsProcessorList() |
|
if stopping_criteria is None: |
|
stopping_criteria = StoppingCriteriaList() |
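
    # DeepConf settings are read as extra attributes on the generation config.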
    enable_conf = getattr(generation_config, "enable_conf", False)
    window_size = getattr(generation_config, "window_size", 2048)
    threshold = getattr(generation_config, "threshold", 17.0)
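
    # Without DeepConf enabled, defer to the model's stock sampling loop.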
    if not enable_conf:
        return model._sample(
            input_ids,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **model_kwargs,
        )
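
    # Resolve a pad token id, falling back to the EOS token if nothing else is set.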
    pad_token_id = generation_config.pad_token_id
    if pad_token_id is None and hasattr(generation_config, "_pad_token_tensor"):
        pad_token_id = generation_config._pad_token_tensor
    if pad_token_id is None and hasattr(model.config, "pad_token_id"):
        pad_token_id = model.config.pad_token_id
    if pad_token_id is None and generation_config.eos_token_id is not None:
        pad_token_id = generation_config.eos_token_id

    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    output_confidences = getattr(generation_config, "output_confidences", False)

    deepconf_variant = getattr(generation_config, "deepconf_variant", None)
    deepconf_eta = getattr(generation_config, "deepconf_eta", None)
    deepconf_warmup_confidences = getattr(generation_config, "deepconf_warmup_confidences", None)
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    do_sample = generation_config.do_sample
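
    # Resolve the stopping threshold: an explicit `threshold` wins; otherwise it is
    # calibrated from the warm-up confidences as the (100 - eta * 100)-th percentile.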
    if enable_conf and threshold is not None:
        pass
    elif enable_conf and deepconf_variant is not None and deepconf_warmup_confidences is not None:
        confs = deepconf_warmup_confidences
        if hasattr(confs, "detach"):
            confs = confs.detach().cpu().numpy()
        elif isinstance(confs, torch.Tensor):
            confs = confs.cpu().numpy()
        confs = np.asarray(confs, dtype=np.float32).ravel()
        eta = deepconf_eta
        if eta is None:
            eta = 0.1 if deepconf_variant == "low" else 0.9 if deepconf_variant == "high" else 0.5
        pct = max(0.0, min(100.0, 100.0 - (eta * 100.0)))
        threshold = float(np.percentile(confs, pct))
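
    # Per-step output accumulators, allocated only when they will be returned.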
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    if return_dict_in_generate and model.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None

    batch_size, cur_len = input_ids.shape[:2]
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
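
    # Per-sequence sliding window of token confidences plus a running sum, so the
    # windowed average used for early stopping costs O(1) per step.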
    conf_group_lists = [deque(maxlen=window_size) for _ in range(batch_size)]
    conf_grouped_sums = [0.0 for _ in range(batch_size)]

    step_confidences = [] if (return_dict_in_generate and output_confidences) else None

    steps = 0
    max_new_tokens = getattr(generation_config, "max_new_tokens", None) or 512
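
    # Main decoding loop: one forward pass per step, reusing the KV cache via
    # `past_key_values` and an advancing `cache_position`.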
    model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
    while steps < max_new_tokens and unfinished_sequences.max() != 0:
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)

        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        with torch.no_grad():
            outputs = model(**model_inputs, return_dict=True)
            next_token_logits = outputs.logits[:, -1, :].detach()

        if hasattr(outputs, "past_key_values") and outputs.past_key_values is not None:
            model_kwargs["past_key_values"] = outputs.past_key_values

        next_token_scores = logits_processor(input_ids, next_token_logits)
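
        # Rebuild the sampling warpers (temperature, top-k, top-p) each step so the
        # confidence below is computed over the same candidate set used for sampling.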
        warpers = LogitsProcessorList()

        temperature = getattr(generation_config, "temperature", 1.0)
        if temperature is not None and temperature != 1.0:
            warpers.append(TemperatureLogitsWarper(temperature))

        top_k = getattr(generation_config, "top_k", None)
        if top_k is not None and isinstance(top_k, int) and top_k > 0:
            warpers.append(TopKLogitsWarper(top_k))

        top_p = getattr(generation_config, "top_p", None)
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p))
        if len(warpers) > 0:
            next_token_scores = warpers(input_ids, next_token_scores)

        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
                )
                if model.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,)
                )

        if do_sample:
            probs = F.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)
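
        # Token confidence for early stopping: the warpers leave surviving candidates
        # finite, so the confidence is the negative mean log-probability of the
        # candidates other than the sampled token. A sliding window of the last
        # `window_size` values is kept per sequence; once the window is full, a
        # sequence whose windowed average falls below `threshold` is marked to stop.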
        logprobs = F.log_softmax(next_token_scores, dim=-1)
        candidate_mask = torch.isfinite(next_token_scores)

        deepconf_stopping = torch.ones(batch_size, dtype=torch.bool, device=input_ids.device)
        step_conf_values = [0.0] * batch_size

        for i in range(batch_size):
            if not unfinished_sequences[i]:
                continue

            num_candidates = int(candidate_mask[i].sum().item())
            if num_candidates <= 1:
                conf = 0.0
            else:
                total_lp = torch.sum(logprobs[i][candidate_mask[i]])
                selected_lp = (
                    logprobs[i, next_tokens[i]]
                    if candidate_mask[i, next_tokens[i]]
                    else torch.tensor(0.0, device=logprobs.device)
                )
                denom = num_candidates - 1
                conf = -((total_lp - selected_lp) / denom).item()

            if len(conf_group_lists[i]) >= window_size:
                conf_grouped_sums[i] -= conf_group_lists[i][0]
            conf_group_lists[i].append(conf)
            conf_grouped_sums[i] += conf

            if len(conf_group_lists[i]) >= window_size:
                avg_conf = conf_grouped_sums[i] / len(conf_group_lists[i])
                if avg_conf < threshold:
                    deepconf_stopping[i] = False

            if step_confidences is not None:
                step_conf_values[i] = conf

        if step_confidences is not None:
            step_confidences.append(torch.tensor(step_conf_values, device=input_ids.device))

        if has_eos_stopping_criteria and pad_token_id is not None:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        if model_kwargs.get("attention_mask") is not None:
            attn = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attn, torch.ones((batch_size, 1), dtype=attn.dtype, device=attn.device)], dim=-1
            )

        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
        if streamer is not None:
            streamer.put(next_tokens.cpu())
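
        # Fold the standard stopping criteria into `unfinished_sequences` (a per-sequence
        # bool tensor or a single bool, depending on the transformers version), then
        # apply DeepConf early stopping on top.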
        sc = stopping_criteria(input_ids, scores)
        if isinstance(sc, torch.Tensor):
            unfinished_sequences = unfinished_sequences & ~sc
        elif sc:
            unfinished_sequences = torch.zeros_like(unfinished_sequences)

        unfinished_sequences = unfinished_sequences & deepconf_stopping

        if unfinished_sequences.max() == 0 and not synced_gpus:
            break
        cur_len += 1
        steps += 1

    del outputs

    if streamer is not None:
        streamer.end()
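
    # Assemble the return value; per-step confidences are stacked into a
    # (batch, num_steps) tensor and attached to the output as `confidences`.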
    if return_dict_in_generate:
        confidences_tensor = None
        if step_confidences is not None and len(step_confidences) > 0:
            confidences_tensor = torch.stack(step_confidences, dim=0).transpose(0, 1)
        if model.config.is_encoder_decoder:
            output = GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
            if confidences_tensor is not None:
                output["confidences"] = confidences_tensor
                try:
                    setattr(output, "confidences", confidences_tensor)
                except Exception:
                    pass
            return output
        else:
            output = GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
            if confidences_tensor is not None:
                output["confidences"] = confidences_tensor
                try:
                    setattr(output, "confidences", confidences_tensor)
                except Exception:
                    pass
            return output
    else:
        return input_ids
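

# Usage sketch (illustrative only): one way to drive the DeepConf loop above.
# The checkpoint name is a placeholder, and the example assumes GenerationConfig
# keeps unrecognized kwargs (enable_conf, window_size, threshold,
# output_confidences) as plain attributes, which is how `generate` reads them
# via getattr(). No EOS stopping criterion is passed, so decoding runs until
# max_new_tokens or until the windowed confidence drops below the threshold.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "your-model-name"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    inputs = tokenizer("Question: What is 17 * 24? Answer:", return_tensors="pt")

    gen_cfg = GenerationConfig(
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        max_new_tokens=256,
        return_dict_in_generate=True,
        # Custom DeepConf fields consumed by `generate` above. The small window
        # is illustrative (the code's default is 2048) so early stopping can
        # actually trigger within a short generation.
        enable_conf=True,
        window_size=64,
        threshold=17.0,
        output_confidences=True,
    )

    out = generate(
        model,
        inputs["input_ids"],
        generation_config=gen_cfg,
        attention_mask=inputs["attention_mask"],
    )
    print(tokenizer.decode(out.sequences[0], skip_special_tokens=True))
    if getattr(out, "confidences", None) is not None:
        print("per-step confidences:", tuple(out.confidences.shape))  # (batch, steps)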