from collections import deque
from typing import Any, Optional, Union
import torch
import torch.nn.functional as F
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.logits_process import (
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput
def generate(
model: Any,
input_ids: torch.LongTensor,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
generation_config: Optional[GenerationConfig] = None,
synced_gpus: bool = False,
streamer: Optional[Any] = None,
**model_kwargs,
) -> Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, torch.LongTensor]:
"""Custom decoding with DeepCONF (confidence-based early stopping).
Args:
model: PreTrainedModel with a LM head.
input_ids: Prompt ids of shape (batch, seq_len).
logits_processor: Optional logits processors.
stopping_criteria: Optional stopping criteria.
generation_config: GenerationConfig controlling sampling/outputs.
synced_gpus: Keep looping to max length for distributed setups.
streamer: Optional streamer for incremental tokens.
**model_kwargs: Forward pass kwargs (e.g., attention_mask).
Returns:
GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, or LongTensor
depending on `return_dict_in_generate` and model type.
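    Example:
        A minimal usage sketch; the checkpoint name ("gpt2") and the specific DeepCONF
        settings below are illustrative choices, not requirements:

            from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

            tok = AutoTokenizer.from_pretrained("gpt2")
            lm = AutoModelForCausalLM.from_pretrained("gpt2")
            cfg = GenerationConfig(do_sample=True, max_new_tokens=64, return_dict_in_generate=True)
            cfg.enable_conf = True          # turn on confidence-based early stopping
            cfg.window_size = 32            # sliding window for the running mean (this function defaults to 2048)
            cfg.threshold = 17.0            # stop a sequence once its windowed confidence drops below this
            cfg.output_confidences = True   # also return per-step confidences
            enc = tok("2 + 2 =", return_tensors="pt")
            out = generate(lm, enc.input_ids, generation_config=cfg, attention_mask=enc.attention_mask)
            print(tok.decode(out.sequences[0], skip_special_tokens=True))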
"""
    # Normalize optional arguments (they may be None) so the code below can use them directly
    generation_config = generation_config if generation_config is not None else model.generation_config
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    # Get DeepCONF parameters from generation_config or set defaults
enable_conf = getattr(generation_config, "enable_conf", False)
window_size = getattr(generation_config, "window_size", 2048)
threshold = getattr(generation_config, "threshold", 17.0) # Default threshold for confidence (positive value)
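    # DeepCONF sketch: compute a per-token confidence (negative mean log-probability of the
    # non-selected candidates), average it over the last `window_size` tokens, and mark a
    # sequence as finished once that windowed average falls below `threshold` (the check only
    # starts once the window is full).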
# If DeepCONF is not enabled, fall back to standard sampling
if not enable_conf:
return model._sample(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)
# Initialize values
# Handle pad token properly (following HF best practices)
pad_token_id = generation_config.pad_token_id
if pad_token_id is None and hasattr(generation_config, "_pad_token_tensor"):
pad_token_id = generation_config._pad_token_tensor
if pad_token_id is None and hasattr(model.config, "pad_token_id"):
pad_token_id = model.config.pad_token_id
    if pad_token_id is None and generation_config.eos_token_id is not None:
        # Use eos token as pad token if not set (eos may be configured as a list; take the first id)
        eos_token_id = generation_config.eos_token_id
        pad_token_id = eos_token_id[0] if isinstance(eos_token_id, (list, tuple)) else eos_token_id
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
output_confidences = getattr(generation_config, "output_confidences", False)
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample
# Initialize attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# If model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and model.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
# Keep track of which sequences are already finished
batch_size, cur_len = input_ids.shape[:2]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
# Use public kv-cache via past_key_values
# Initialize confidence tracking
# Use deque for sliding window with fixed size
conf_group_lists = [deque(maxlen=window_size) for _ in range(batch_size)]
conf_grouped_sums = [0.0 for _ in range(batch_size)] # Running sums for efficient mean calculation
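    # The deque evicts its oldest entry automatically once full; the running sum is kept in sync
    # further below by subtracting the value about to be evicted before each append, so the
    # windowed mean is just sum / len(window) without rescanning the deque.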
# Initialize via prepare_inputs_for_generation
# Optional per-step confidences for debugging/visualization
step_confidences = [] if (return_dict_in_generate and output_confidences) else None
# Main generation loop using public controls
steps = 0
max_new_tokens = getattr(generation_config, "max_new_tokens", None) or 512
# Initialize cache_position for first forward over the full prompt
# Subsequent steps will pass a single position incrementally
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
while steps < max_new_tokens and unfinished_sequences.max() != 0:
# Prepare model inputs (proper KV cache handling)
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
# Prepare variable output controls
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
# Forward pass with proper KV cache handling
with torch.no_grad():
outputs = model(**model_inputs, return_dict=True)
next_token_logits = outputs.logits[:, -1, :].detach()
# Update model kwargs for next iteration (public): carry past_key_values
if hasattr(outputs, "past_key_values") and outputs.past_key_values is not None:
model_kwargs["past_key_values"] = outputs.past_key_values
# Pre-process distribution with logits processors
next_token_scores = logits_processor(input_ids, next_token_logits)
# Apply logits warpers (e.g., temperature, top-k, top-p) from generation_config
warpers = LogitsProcessorList()
# Temperature
temperature = getattr(generation_config, "temperature", 1.0)
if temperature is not None and temperature != 1.0:
warpers.append(TemperatureLogitsWarper(temperature))
# Top-k
top_k = getattr(generation_config, "top_k", None)
if top_k is not None and isinstance(top_k, int) and top_k > 0:
warpers.append(TopKLogitsWarper(top_k))
# Top-p
top_p = getattr(generation_config, "top_p", None)
if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p))
if len(warpers) > 0:
next_token_scores = warpers(input_ids, next_token_scores)
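        # Design note: the warper list is rebuilt on every step for simplicity; because
        # generation_config does not change during decoding, it could equally be built once
        # before the loop with identical results.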
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_logits:
raw_logits += (next_token_logits,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
)
if model.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,)
)
# Token selection
if do_sample:
probs = F.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
# Calculate confidence using only top-k/top-p filtered candidates (post-logits processors),
# excluding the sampled token.
# We consider candidates where logits are finite after warpers (e.g., top-k/top-p/temperature).
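        # In symbols, with C the candidates whose warped logits are finite and y the chosen token:
        #     conf = -(1 / (|C| - 1)) * sum over j in C, j != y of log p_j
        # so, when more than one candidate survives the warpers, a peaked distribution
        # (low-probability alternatives) yields a large positive conf.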
logprobs = F.log_softmax(next_token_scores, dim=-1)
candidate_mask = torch.isfinite(next_token_scores)
deepconf_stopping = torch.ones(batch_size, dtype=torch.bool, device=input_ids.device)
step_conf_values = [0.0] * batch_size # collect per-sequence confidences for this step (full batch)
for i in range(batch_size):
if not unfinished_sequences[i]:
continue
# Count valid candidates
num_candidates = int(candidate_mask[i].sum().item())
if num_candidates <= 1:
conf = 0.0
else:
# Sum logprobs over valid candidates and exclude the sampled token's logprob
total_lp = torch.sum(logprobs[i][candidate_mask[i]])
selected_lp = (
logprobs[i, next_tokens[i]]
if candidate_mask[i, next_tokens[i]]
else torch.tensor(0.0, device=logprobs.device)
)
denom = num_candidates - 1
# Negative mean of non-selected candidate logprobs
conf = -((total_lp - selected_lp) / denom).item()
# Update tracking structures
if len(conf_group_lists[i]) >= window_size:
conf_grouped_sums[i] -= conf_group_lists[i][0]
conf_group_lists[i].append(conf)
conf_grouped_sums[i] += conf
# Apply confidence-based early stopping when window is full
if len(conf_group_lists[i]) >= window_size:
avg_conf = conf_grouped_sums[i] / len(conf_group_lists[i])
if avg_conf < threshold:
deepconf_stopping[i] = False
if step_confidences is not None:
step_conf_values[i] = conf
if step_confidences is not None:
# Store this step's confidences as a tensor of shape (batch,)
step_confidences.append(torch.tensor(step_conf_values, device=input_ids.device))
        # Finished sequences should have their next token be a padding token
if has_eos_stopping_criteria and pad_token_id is not None:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# Update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
# Update attention mask if available
if model_kwargs.get("attention_mask") is not None:
attn = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attn, torch.ones((batch_size, 1), dtype=attn.dtype, device=attn.device)], dim=-1
)
# Update cache_position for next step (single next token)
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
if streamer is not None:
streamer.put(next_tokens.cpu())
# Update unfinished sequences with standard stopping criteria (per-sequence if available)
sc = stopping_criteria(input_ids, scores)
if isinstance(sc, torch.Tensor):
unfinished_sequences = unfinished_sequences & ~sc
elif sc:
# global stop
unfinished_sequences = torch.zeros_like(unfinished_sequences)
# Apply DeepCONF stopping
unfinished_sequences = unfinished_sequences & deepconf_stopping
# Early break if all sequences finished and not synchronized
if unfinished_sequences.max() == 0 and not synced_gpus:
break
cur_len += 1
steps += 1
# Clean up outputs to save memory
del outputs
if streamer is not None:
streamer.end()
# Return results
if return_dict_in_generate:
# Prepare confidences tensor if requested
confidences_tensor = None
if step_confidences is not None and len(step_confidences) > 0:
# Shape: (steps, batch) -> (batch, steps)
confidences_tensor = torch.stack(step_confidences, dim=0).transpose(0, 1)
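        # "confidences" is not a standard field of the HF generate output classes, so it is
        # attached as an extra entry below; downstream code should treat it as optional.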
if model.config.is_encoder_decoder:
output = GenerateEncoderDecoderOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
if confidences_tensor is not None:
output["confidences"] = confidences_tensor
try:
setattr(output, "confidences", confidences_tensor)
except Exception:
pass
return output
else:
output = GenerateDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
if confidences_tensor is not None:
output["confidences"] = confidences_tensor
try:
setattr(output, "confidences", confidences_tensor)
except Exception:
pass
return output
else:
return input_ids