# orion/src/bart_with_group_beam.py
from transformers.models.bart import BartForConditionalGeneration
import torch
from transformers.generation_beam_search import BeamScorer
from abc import ABC, abstractmethod
from collections import UserDict
from typing import Optional, Tuple, Union, Dict, Any
from transformers.generation_logits_process import LogitsProcessorList
from transformers.generation_utils import BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput
from torch.nn import functional as F
from transformers.file_utils import ModelOutput
import torch.nn
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
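
# BartForConditionalGeneration_GroupBeam overrides `beam_search` and `group_beam_search` so that, at
# every decoding step, the next-token scores are averaged over the batch ("group consensus" scoring),
# while token ids listed in model_kwargs["decoder_ori_input_ids"] keep their per-example scores.
# This presumably assumes the batch holds related variants of a single instance.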
class BartForConditionalGeneration_GroupBeam(BartForConditionalGeneration):
def beam_search(
self,
input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
logits_processor: Optional[LogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
**model_kwargs,
) -> Union[BeamSearchOutput, torch.LongTensor]:
r"""
Generates sequences for models with a language modeling head using beam search decoding.
Parameters:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
:obj:`torch.LongTensor` of shape :obj:`(1,)`.
beam_scorer (:obj:`BeamScorer`):
                A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
                constructed, stored and sorted during generation. For more information, read the
                documentation of :class:`~transformers.BeamScorer`.
logits_processor (:obj:`LogitsProcessorList`, `optional`):
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
head applied at each generation step.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
output_attentions (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more details.
output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
for more details.
output_scores (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
model_kwargs:
                Additional model-specific kwargs will be forwarded to the :obj:`forward` function of the model. If
                the model is an encoder-decoder model, the kwargs should include :obj:`encoder_outputs`.
Return:
            :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`,
            :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or :obj:`torch.LongTensor`: A
:obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a
:class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if
``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a
:class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if
``model.config.is_encoder_decoder=True``.
Examples::
>>> from transformers import (
... AutoTokenizer,
... AutoModelForSeq2SeqLM,
... LogitsProcessorList,
... MinLengthLogitsProcessor,
... BeamSearchScorer,
... )
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> encoder_input_str = "translate English to German: How old are you?"
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
            >>> # let's run beam search using 3 beams
>>> num_beams = 3
>>> # define decoder start token ids
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
>>> input_ids = input_ids * model.config.decoder_start_token_id
>>> # add encoder_outputs to model keyword arguments
>>> model_kwargs = {
... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
... }
>>> # instantiate beam scorer
>>> beam_scorer = BeamSearchScorer(
... batch_size=1,
... max_length=model.config.max_length,
... num_beams=num_beams,
... device=model.device,
... )
>>> # instantiate logits processors
>>> logits_processor = LogitsProcessorList([
... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
... ])
>>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape
assert (
num_beams * batch_size == batch_beam_size
), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
next_token_logits = outputs.logits[:, -1, :]
            # model-specific logits adjustment (for BART this can force special tokens, e.g. BOS/EOS, at given steps)
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
)
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
next_token_scores = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
# reshape for beam search
vocab_size = next_token_scores.shape[-1]
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
#m = torch.nn.LayerNorm(num_beams * vocab_size)
#next_token_scores = m(next_token_scores)
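            # --- group-consensus scoring (modification over vanilla beam search) ---
            # Average the per-token scores over the batch so every example sees the same "consensus"
            # distribution, then restore the per-example scores for the token ids listed in
            # model_kwargs["decoder_ori_input_ids"][i] (presumably the original decoder input of
            # example i), so those tokens are still ranked by how well they fit each example.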
            next_token_scores_group = (
                torch.sum(next_token_scores, dim=0, keepdim=True).expand(batch_size, -1) / batch_size
            )
for i in range(next_token_scores.size(0)):
                # tmin = torch.min(next_token_scores_group[i])
                # for j in range(1, len(model_kwargs['decoder_ori_input_ids'][i])):
                #     next_token_scores_group[i][model_kwargs['decoder_ori_input_ids'][i][j]] = tmin
for t in model_kwargs['decoder_ori_input_ids'][i]:
for j in range(num_beams):
#if t not in input_ids[i] or t==1:
next_token_scores_group[i][j * vocab_size + t] = next_token_scores[i][j * vocab_size + t]
next_token_scores, next_tokens = torch.topk(
next_token_scores_group, 2 * num_beams, dim=1, largest=True, sorted=True)
            # next_token_scores_group = next_token_scores_group.expand(batch_size, -1)
            # next_tokens_group = next_tokens_group.expand(batch_size, -1)
            # next_token_scores, next_tokens = torch.topk(
            #     next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            # )
            # for i in range(next_token_scores.size(0)):
            #     j1 = 0
            #     for j in range(next_token_scores.size(1)):
            #         if next_tokens[i][j] not in model_kwargs['decoder_ori_input_ids'][i]:
            #             next_tokens[i][j] = next_tokens_group[i][j1]
            #             j1 += 1
            # next_token_scores = next_token_scores_group
            # del next_token_scores_group, next_tokens_group
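            # `next_tokens` currently indexes the flattened (num_beams * vocab_size) axis:
            # beam index = flat_index // vocab_size, token id = flat_index % vocab_size.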
next_indices = next_tokens // vocab_size
next_tokens = next_tokens % vocab_size
# stateless
beam_outputs = beam_scorer.process(
input_ids,
next_token_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
)
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
cur_len = cur_len + 1
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
if model_kwargs["past"] is not None:
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
if beam_scorer.is_done:
break
sequence_outputs = beam_scorer.finalize(
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
)
if return_dict_in_generate:
if not output_scores:
sequence_outputs["sequence_scores"] = None
if self.config.is_encoder_decoder:
return BeamSearchEncoderDecoderOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
return BeamSearchDecoderOnlyOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return sequence_outputs["sequences"]
def group_beam_search(
self,
input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
logits_processor: Optional[LogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
**model_kwargs,
):
r"""
        Generates sequences for models with a language modeling head using diverse (group) beam search decoding.
Parameters:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
:obj:`torch.LongTensor` of shape :obj:`(1,)`.
beam_scorer (:obj:`BeamScorer`):
                A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
                constructed, stored and sorted during generation. For more information, read the
                documentation of :class:`~transformers.BeamScorer`.
logits_processor (:obj:`LogitsProcessorList`, `optional`):
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
head applied at each generation step.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
output_attentions (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more details.
output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
for more details.
output_scores (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
model_kwargs:
                Additional model-specific kwargs that will be forwarded to the :obj:`forward` function of the model.
                If the model is an encoder-decoder model, the kwargs should include :obj:`encoder_outputs`.
Return:
:class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`,
            :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or :obj:`torch.LongTensor`: A
:obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a
:class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if
``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a
:class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if
``model.config.is_encoder_decoder=True``.
Examples::
>>> from transformers import (
... AutoTokenizer,
... AutoModelForSeq2SeqLM,
... LogitsProcessorList,
... MinLengthLogitsProcessor,
... HammingDiversityLogitsProcessor,
... BeamSearchScorer,
... )
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> encoder_input_str = "translate English to German: How old are you?"
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
            >>> # let's run diverse beam search using 6 beams
>>> num_beams = 6
>>> # define decoder start token ids
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
>>> input_ids = input_ids * model.config.decoder_start_token_id
>>> # add encoder_outputs to model keyword arguments
>>> model_kwargs = {
... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
... }
>>> # instantiate beam scorer
>>> beam_scorer = BeamSearchScorer(
... batch_size=1,
... max_length=model.config.max_length,
... num_beams=num_beams,
... device=model.device,
... num_beam_groups=3
... )
>>> # instantiate logits processors
>>> logits_processor = LogitsProcessorList([
... HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),
... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
... ])
>>> outputs = model.group_beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
num_sub_beams = num_beams // num_beam_groups
device = input_ids.device
batch_beam_size, cur_len = input_ids.shape
assert (
num_beams * batch_size == batch_beam_size
), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # Initialise the score of the first beam of each group to 0 and the rest to -1e9. This ensures that the
        # beams within the same group do not all produce the same tokens at every step.
beam_scores[:, ::num_sub_beams] = 0
beam_scores = beam_scores.view((batch_size * num_beams,))
while cur_len < max_length:
# predicted tokens in cur_len step
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
# indices which will form the beams in the next time step
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
# do one decoder step on all beams of all sentences in batch
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
for beam_group_idx in range(num_beam_groups):
group_start_idx = beam_group_idx * num_sub_beams
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
group_size = group_end_idx - group_start_idx
# indices of beams of current group among all sentences in batch
batch_group_indices = []
if output_scores:
processed_score = torch.zeros_like(outputs.logits[:, -1, :]).half() # .float()
for batch_idx in range(batch_size):
batch_group_indices.extend(
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
)
group_input_ids = input_ids[batch_group_indices]
# select outputs of beams of current group only
next_token_logits = outputs.logits[batch_group_indices, -1, :]
                # model-specific logits adjustment (for BART this can force special tokens, e.g. BOS/EOS, at given steps)
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
)
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size)
vocab_size = next_token_scores.shape[-1]
next_token_scores = logits_processor(
group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
)
next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(
next_token_scores
)
if output_scores:
processed_score[batch_group_indices] = next_token_scores.half() # .float()
# reshape for beam search
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
###
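                # Same group-consensus trick as in `beam_search`, applied per beam group: average this
                # group's scores over the batch, but keep the per-example scores for the token ids
                # listed in model_kwargs["decoder_ori_input_ids"][i].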
                next_token_scores_group = (
                    torch.sum(next_token_scores, dim=0, keepdim=True).expand(batch_size, -1) / batch_size
                )
for i in range(next_token_scores.size(0)):
                    # tmin = torch.min(next_token_scores_group[i])
                    # for j in range(1, len(model_kwargs['decoder_ori_input_ids'][i])):
                    #     next_token_scores_group[i][model_kwargs['decoder_ori_input_ids'][i][j]] = tmin
for t in model_kwargs['decoder_ori_input_ids'][i]:
for j in range(group_size):
# if t not in input_ids[i] or t==1:
next_token_scores_group[i][j * vocab_size + t] = next_token_scores[i][j * vocab_size + t]
next_token_scores, next_tokens = torch.topk(
next_token_scores_group, 2 * group_size, dim=1, largest=True, sorted=True)
###
#next_token_scores, next_tokens = torch.topk(
# next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
#)
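                # Decompose the flat (group_size * vocab_size) index: the quotient is the beam offset
                # *within this group*, the remainder is the token id.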
next_indices = next_tokens // vocab_size
next_tokens = next_tokens % vocab_size
# stateless
beam_outputs = beam_scorer.process(
group_input_ids,
next_token_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
)
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
input_ids[batch_group_indices] = group_input_ids[beam_idx]
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
current_tokens[batch_group_indices] = group_input_ids[:, -1]
# (beam_idx // group_size) -> batch_idx
# (beam_idx % group_size) -> offset of idx inside the group
reordering_indices[batch_group_indices] = (
num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size)
)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (processed_score,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
if model_kwargs["past"] is not None:
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
cur_len = cur_len + 1
if beam_scorer.is_done:
break
sequence_outputs = beam_scorer.finalize(
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=max_length,
)
if return_dict_in_generate:
if not output_scores:
sequence_outputs["sequence_scores"]
if self.config.is_encoder_decoder:
return BeamSearchEncoderDecoderOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
return BeamSearchDecoderOnlyOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return sequence_outputs["sequences"]
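

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). It mirrors the docstring example above
    # and assumes a transformers version that matches this file's imports (the old `generation_utils`
    # / `generation_beam_search` layout). The checkpoint name, the input sentence, and the content of
    # `decoder_ori_input_ids` below are illustrative assumptions: `decoder_ori_input_ids` is taken to
    # be, for each batch example, the token ids whose scores should stay example-specific instead of
    # being replaced by the batch-averaged ("group") scores.
    from transformers import BartTokenizer
    from transformers.generation_beam_search import BeamSearchScorer
    from transformers.generation_logits_process import MinLengthLogitsProcessor

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = BartForConditionalGeneration_GroupBeam.from_pretrained("facebook/bart-base")

    num_beams = 3
    encoder_input_ids = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt").input_ids

    # One row of decoder start tokens per beam (batch_size == 1 here).
    input_ids = torch.full(
        (num_beams, 1), model.config.decoder_start_token_id, dtype=torch.long, device=model.device
    )
    model_kwargs = {
        "encoder_outputs": model.get_encoder()(
            encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ),
        # Keep per-example scores for the encoder-side token ids of example 0 (illustrative choice).
        "decoder_ori_input_ids": [encoder_input_ids[0].tolist()],
    }
    beam_scorer = BeamSearchScorer(
        batch_size=1, max_length=model.config.max_length, num_beams=num_beams, device=model.device
    )
    logits_processor = LogitsProcessorList(
        [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)]
    )
    outputs = model.beam_search(
        input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
    )
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))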