import torch
import torch.nn
from torch.nn import functional as F
from abc import ABC, abstractmethod
from collections import UserDict
from typing import Optional, Tuple, Union, Dict, Any

from transformers.models.bart import BartForConditionalGeneration
from transformers.generation_beam_search import BeamScorer
from transformers.generation_logits_process import LogitsProcessorList
from transformers.generation_utils import BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput
from transformers.file_utils import ModelOutput

BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
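
# This module subclasses BartForConditionalGeneration and overrides `beam_search` /
# `group_beam_search` so that, at every decoding step, the next-token scores are averaged
# across the batch (a "group consensus" score), while scores for token ids listed in
# model_kwargs["decoder_ori_input_ids"] are kept per-example so each input can still copy
# its own original tokens. The import paths above assume an older transformers 4.x release
# that still exposes generation_beam_search / generation_logits_process / generation_utils.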

class BartForConditionalGeneration_GroupBeam(BartForConditionalGeneration):
    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
        r"""
        Generates sequences for models with a language modeling head using beam search decoding.

        Parameters:
            input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
                :obj:`torch.LongTensor` of shape :obj:`(1,)`.
            beam_scorer (:obj:`BeamScorer`):
                A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
                constructed, stored and sorted during generation. For more information, the documentation of
                :class:`~transformers.BeamScorer` should be read.
            logits_processor (:obj:`LogitsProcessorList`, `optional`):
                An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
                :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
                head applied at each generation step.
            max_length (:obj:`int`, `optional`, defaults to 20):
                The maximum length of the sequence to be generated.
            pad_token_id (:obj:`int`, `optional`):
                The id of the `padding` token.
            eos_token_id (:obj:`int`, `optional`):
                The id of the `end-of-sequence` token.
            output_attentions (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
                returned tensors for more details.
            output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
                for more details.
            output_scores (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
            return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
                model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.

        Return:
            :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`,
            :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or :obj:`torch.LongTensor`: A
            :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a
            :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if
            ``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a
            :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if
            ``model.config.is_encoder_decoder=True``.

        Examples::

            >>> from transformers import (
            ...     AutoTokenizer,
            ...     AutoModelForSeq2SeqLM,
            ...     LogitsProcessorList,
            ...     MinLengthLogitsProcessor,
            ...     BeamSearchScorer,
            ... )
            >>> import torch

            >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
            >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

            >>> encoder_input_str = "translate English to German: How old are you?"
            >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

            >>> # lets run beam search using 3 beams
            >>> num_beams = 3
            >>> # define decoder start token ids
            >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
            >>> input_ids = input_ids * model.config.decoder_start_token_id

            >>> # add encoder_outputs to model keyword arguments
            >>> model_kwargs = {
            ...     "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
            ... }

            >>> # instantiate beam scorer
            >>> beam_scorer = BeamSearchScorer(
            ...     batch_size=1,
            ...     max_length=model.config.max_length,
            ...     num_beams=num_beams,
            ...     device=model.device,
            ... )

            >>> # instantiate logits processors
            >>> logits_processor = LogitsProcessorList([
            ...     MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
            ... ])

            >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

            >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
        """
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        assert (
            num_beams * batch_size == batch_beam_size
        ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."

        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))
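
        # The decoding loop below deviates from the stock transformers implementation in one
        # place: at each step the per-example next-token scores are replaced by their mean over
        # the batch, so every example is pushed toward the same ("group consensus") continuation.
        # For token ids listed in model_kwargs["decoder_ori_input_ids"][i] the original
        # per-example score is restored, so example i can still copy tokens from its own
        # original sequence. "decoder_ori_input_ids" is assumed to be an iterable of
        # token-id collections, one per batch element.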
        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            next_token_logits = outputs.logits[:, -1, :]

            # adjust tokens for Bart, *e.g.*
            next_token_logits = self.adjust_logits_during_generation(
                next_token_logits, cur_len=cur_len, max_length=max_length
            )

            next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

            next_token_scores = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
            # m = torch.nn.LayerNorm(num_beams * vocab_size)
            # next_token_scores = m(next_token_scores)

            # average the scores over the batch to obtain a single consensus distribution per beam
            next_token_scores_group = (
                torch.sum(next_token_scores, dim=0, keepdim=True).expand(batch_size, -1) / batch_size
            )
            for i in range(next_token_scores.size(0)):
                '''tmin = torch.min(next_token_scores_group[i])
                for j in range(1,len(model_kwargs['decoder_ori_input_ids'][i])):
                    next_token_scores_group[i][model_kwargs['decoder_ori_input_ids'][i][j]] = tmin'''
                # keep the per-example score for tokens that appear in this example's original decoder input
                for t in model_kwargs['decoder_ori_input_ids'][i]:
                    for j in range(num_beams):
                        # if t not in input_ids[i] or t==1:
                        next_token_scores_group[i][j * vocab_size + t] = next_token_scores[i][j * vocab_size + t]
            next_token_scores, next_tokens = torch.topk(
                next_token_scores_group, 2 * num_beams, dim=1, largest=True, sorted=True
            )
            # earlier experimental variant, kept for reference:
            '''next_token_scores_group = next_token_scores_group.expand(batch_size,-1)
            next_tokens_group = next_tokens_group.expand(batch_size,-1)
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )
            for i in range(next_token_scores.size(0)):
                j1 = 0
                for j in range(next_token_scores.size(1)):
                    if next_tokens[i][j] not in model_kwargs['decoder_ori_input_ids'][i]:
                        next_tokens[i][j] = next_tokens_group[i][j1]
                        j1 += 1
            next_token_scores = next_token_scores_group
            del next_token_scores_group, next_tokens_group'''
            next_indices = next_tokens // vocab_size
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            cur_len = cur_len + 1

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            if beam_scorer.is_done:
                break

        sequence_outputs = beam_scorer.finalize(
            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        **model_kwargs,
    ):
        r"""
        Generates sequences for models with a language modeling head using diverse beam search decoding.

        Parameters:
            input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
                :obj:`torch.LongTensor` of shape :obj:`(1,)`.
            beam_scorer (:obj:`BeamScorer`):
                A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
                constructed, stored and sorted during generation. For more information, the documentation of
                :class:`~transformers.BeamScorer` should be read.
            logits_processor (:obj:`LogitsProcessorList`, `optional`):
                An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
                :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
                head applied at each generation step.
            max_length (:obj:`int`, `optional`, defaults to 20):
                The maximum length of the sequence to be generated.
            pad_token_id (:obj:`int`, `optional`):
                The id of the `padding` token.
            eos_token_id (:obj:`int`, `optional`):
                The id of the `end-of-sequence` token.
            output_attentions (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
                returned tensors for more details.
            output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
                for more details.
            output_scores (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
            return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
            model_kwargs:
                Additional model specific kwargs that will be forwarded to the :obj:`forward` function of the model. If
                model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.

        Return:
            :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`,
            :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or :obj:`torch.LongTensor`: A
            :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a
            :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if
            ``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a
            :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if
            ``model.config.is_encoder_decoder=True``.

        Examples::

            >>> from transformers import (
            ...     AutoTokenizer,
            ...     AutoModelForSeq2SeqLM,
            ...     LogitsProcessorList,
            ...     MinLengthLogitsProcessor,
            ...     HammingDiversityLogitsProcessor,
            ...     BeamSearchScorer,
            ... )
            >>> import torch

            >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
            >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

            >>> encoder_input_str = "translate English to German: How old are you?"
            >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

            >>> # lets run diverse beam search using 6 beams
            >>> num_beams = 6
            >>> # define decoder start token ids
            >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
            >>> input_ids = input_ids * model.config.decoder_start_token_id

            >>> # add encoder_outputs to model keyword arguments
            >>> model_kwargs = {
            ...     "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
            ... }

            >>> # instantiate beam scorer
            >>> beam_scorer = BeamSearchScorer(
            ...     batch_size=1,
            ...     max_length=model.config.max_length,
            ...     num_beams=num_beams,
            ...     device=model.device,
            ...     num_beam_groups=3
            ... )

            >>> # instantiate logits processors
            >>> logits_processor = LogitsProcessorList([
            ...     HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),
            ...     MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
            ... ])

            >>> outputs = model.group_beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

            >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
        """
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        device = input_ids.device

        batch_beam_size, cur_len = input_ids.shape

        assert (
            num_beams * batch_size == batch_beam_size
        ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise the score of the first beam of each group with 0 and the rest with -1e9. This ensures that the
        # beams in the same group don't all produce the same tokens at every step.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))
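
        # As in `beam_search` above, each beam group's next-token scores are averaged over the
        # batch to form a consensus score, while scores for tokens found in
        # model_kwargs["decoder_ori_input_ids"][i] are kept per-example so example i can still
        # reproduce its own original tokens.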
        while cur_len < max_length:
            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            # create the per-step score buffer once, before iterating over the beam groups
            if output_scores:
                processed_score = torch.zeros_like(outputs.logits[:, -1, :]).half()  # .float()

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []
                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs.logits[batch_group_indices, -1, :]

                # adjust tokens for Bart, *e.g.*
                next_token_logits = self.adjust_logits_during_generation(
                    next_token_logits, cur_len=cur_len, max_length=max_length
                )

                next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * group_size, vocab_size)
                vocab_size = next_token_scores.shape[-1]

                next_token_scores = logits_processor(
                    group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(
                    next_token_scores
                )

                if output_scores:
                    processed_score[batch_group_indices] = next_token_scores.half()  # .float()

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                # average the group's scores over the batch to obtain a single consensus distribution
                next_token_scores_group = (
                    torch.sum(next_token_scores, dim=0, keepdim=True).expand(batch_size, -1) / batch_size
                )
                for i in range(next_token_scores.size(0)):
                    '''tmin = torch.min(next_token_scores_group[i])
                    for j in range(1,len(model_kwargs['decoder_ori_input_ids'][i])):
                        next_token_scores_group[i][model_kwargs['decoder_ori_input_ids'][i][j]] = tmin'''
                    # keep the per-example score for tokens that appear in this example's original decoder input
                    for t in model_kwargs['decoder_ori_input_ids'][i]:
                        for j in range(group_size):
                            # if t not in input_ids[i] or t==1:
                            next_token_scores_group[i][j * vocab_size + t] = next_token_scores[i][j * vocab_size + t]
                next_token_scores, next_tokens = torch.topk(
                    next_token_scores_group, 2 * group_size, dim=1, largest=True, sorted=True
                )
                # the unmodified top-k over the per-example scores, kept for reference:
                # next_token_scores, next_tokens = torch.topk(
                #     next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                # )

                next_indices = next_tokens // vocab_size
                next_tokens = next_tokens % vocab_size

                # stateless
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size)
                )
            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (processed_score,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
            cur_len = cur_len + 1

            if beam_scorer.is_done:
                break

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=max_length,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]
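

# ---------------------------------------------------------------------------
# Minimal usage sketch (only runs when this file is executed directly). It mirrors the
# docstring examples above but drives the consensus beam search defined in this module.
# Assumptions are hedged: the checkpoint name ("facebook/bart-base"), the exact
# BeamSearchScorer signature (older transformers releases accept `max_length`), and the
# format of `decoder_ori_input_ids` (one collection of token ids per batch element) depend
# on the transformers version this module was written against.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import BartTokenizer
    from transformers.generation_beam_search import BeamSearchScorer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")  # assumed checkpoint
    model = BartForConditionalGeneration_GroupBeam.from_pretrained("facebook/bart-base")

    # a small batch of related inputs to be decoded toward a consensus continuation
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A fast auburn fox leaps over a sleepy dog.",
    ]
    enc = tokenizer(sentences, return_tensors="pt", padding=True)

    batch_size = len(sentences)
    num_beams = 4

    # every beam of every example starts from the decoder start token
    input_ids = torch.full(
        (batch_size * num_beams, 1), model.config.decoder_start_token_id, dtype=torch.long
    )

    model_kwargs = {
        "encoder_outputs": model.get_encoder()(
            enc.input_ids.repeat_interleave(num_beams, dim=0),
            attention_mask=enc.attention_mask.repeat_interleave(num_beams, dim=0),
            return_dict=True,
        ),
        "attention_mask": enc.attention_mask.repeat_interleave(num_beams, dim=0),
        # per-example token ids whose scores should stay example-specific (assumed format)
        "decoder_ori_input_ids": [set(ids.tolist()) for ids in enc.input_ids],
    }

    beam_scorer = BeamSearchScorer(
        batch_size=batch_size,
        max_length=model.config.max_length,
        num_beams=num_beams,
        device=model.device,
    )

    outputs = model.beam_search(input_ids, beam_scorer, **model_kwargs)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))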