import copy
import inspect
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn

from transformers.generation.logits_process import (
    LogitsProcessorList,
)
from transformers.generation.stopping_criteria import (
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
import transformers
# SampleEncoderDecoderOutput / SampleDecoderOnlyOutput are required by the
# `return_dict_in_generate` branches at the end of `sample`.
from transformers.generation.utils import (
    SampleOutput,
    SampleEncoderDecoderOutput,
    SampleDecoderOnlyOutput,
)

if TYPE_CHECKING:
    from transformers.generation.streamers import BaseStreamer
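

# `sample` below is a drop-in replacement for `GenerationMixin.sample` that adds visual
# contrastive decoding (VCD): when `images_cd` is passed through `model_kwargs`, a second
# forward pass is run on the distorted image and the next-token logits are combined as
# (1 + cd_alpha) * logits - cd_alpha * logits_cd, restricted to tokens whose original
# probability is at least cd_beta times the maximum (adaptive plausibility constraint).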
def sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )
    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

    this_peer_finished = False  # used by synced_gpus only
    # auto-regressive generation
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # forward pass to get next token
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]
        ## For contrastive decoding initial
        use_cd = model_kwargs.get("images_cd") is not None
        output_attentions_wo_img = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states_wo_img = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
        model_kwargs_cd = model_kwargs.copy()
        if use_cd:
            ## cd_comments: forward pass of the model with distorted image input
            model_inputs_cd = self.prepare_inputs_for_generation_cd(input_ids, **model_kwargs_cd)
            outputs_cd = self(
                **model_inputs_cd,
                return_dict=True,
                output_attentions=output_attentions_wo_img,
                output_hidden_states=output_hidden_states_wo_img,
            )
            next_token_logits_cd = outputs_cd.logits[:, -1, :]

            ## cd_comments: pre-process logits from contrastive inputs
            cd_alpha = model_kwargs.get("cd_alpha") if model_kwargs.get("cd_alpha") is not None else 0.5
            cd_beta = model_kwargs.get("cd_beta") if model_kwargs.get("cd_beta") is not None else 0.1

            # version 1 set cutoff for Adaptive Plausibility Constraints
            # probs = nn.functional.softmax(next_token_logits, dim=-1)
            # cutoff = cd_beta * probs.max(dim=-1, keepdim=True).values

            # version 2 set cutoff for Adaptive Plausibility Constraints
            # (log-space equivalent: keep only tokens whose probability is at least
            # cd_beta times the probability of the most likely token)
            cutoff = torch.log(torch.tensor(cd_beta)) + next_token_logits.max(dim=-1, keepdim=True).values

            # contrastive combination: (1 + alpha) * logits(original image) - alpha * logits(distorted image),
            # with implausible tokens (below the cutoff) masked to -inf
            diffs = (1 + cd_alpha) * next_token_logits - cd_alpha * next_token_logits_cd
            cd_logits = diffs.masked_fill(next_token_logits < cutoff, -float("inf"))

            ## cd_comments: apply temperature warping and top-k filtering in contrastive decoding
            cd_logits = logits_processor(input_ids, cd_logits)
            cd_logits = logits_warper(input_ids, cd_logits)

            next_token_scores = cd_logits
            cd_probs = nn.functional.softmax(cd_logits, dim=-1)
            next_tokens = torch.multinomial(cd_probs, num_samples=1).squeeze(1)
        else:
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )
        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        ## cd_comments: update model_kwargs_cd for contrastive decoding
        if use_cd:
            model_kwargs_cd = self._update_model_kwargs_for_generation(
                outputs_cd, model_kwargs_cd, is_encoder_decoder=self.config.is_encoder_decoder
            )
        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )

            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                this_peer_finished = True

        # stop if we exceed the maximum length
        if stopping_criteria(input_ids, scores):
            this_peer_finished = True

        if this_peer_finished and not synced_gpus:
            break

    if streamer is not None:
        streamer.end()
    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return SampleEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return SampleDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return input_ids

def evolve_vcd_sampling():
    # Monkey-patch Hugging Face generation so that `model.generate(..., do_sample=True)`
    # routes through the contrastive-decoding-aware `sample` defined above.
    transformers.generation.utils.GenerationMixin.sample = sample
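

# ----------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not executed on import). It assumes a
# LLaVA-style multimodal model that defines `prepare_inputs_for_generation_cd`, and a
# hypothetical `add_diffusion_noise` helper that produces the distorted image view;
# adapt the names and kwargs to your own setup.
#
#   evolve_vcd_sampling()                                     # patch GenerationMixin.sample
#   images_cd = add_diffusion_noise(images, noise_step=500)   # distorted copy of `images`
#   output_ids = model.generate(
#       input_ids,
#       images=images,
#       images_cd=images_cd,   # enables the contrastive branch in `sample`
#       cd_alpha=1.0,          # contrast strength (defaults to 0.5 when unset)
#       cd_beta=0.1,           # plausibility cutoff (defaults to 0.1 when unset)
#       do_sample=True,
#       max_new_tokens=64,
#   )
# ----------------------------------------------------------------------------------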