from transformers.models.bart import BartForConditionalGeneration
import torch
from transformers.generation_beam_search import BeamScorer
from abc import ABC, abstractmethod
from collections import UserDict
from typing import Optional, Tuple, Union, Dict, Any
from transformers.generation_logits_process import LogitsProcessorList
from transformers.generation_utils import BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput
from torch.nn import functional as F
from transformers.file_utils import ModelOutput
import torch.nn

BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]


class BartForConditionalGeneration_GroupBeam(BartForConditionalGeneration):
    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
        r"""
        Generates sequences for models with a language modeling head using beam search decoding.

        Parameters:
            input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an
                empty :obj:`torch.LongTensor` of shape :obj:`(1,)`.
            beam_scorer (:obj:`BeamScorer`):
                A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
                constructed, stored and sorted during generation. For more information, the documentation of
                :class:`~transformers.BeamScorer` should be read.
            logits_processor (:obj:`LogitsProcessorList`, `optional`):
                An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
                :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
                head applied at each generation step.
            max_length (:obj:`int`, `optional`, defaults to 20):
                The maximum length of the sequence to be generated.
            pad_token_id (:obj:`int`, `optional`):
                The id of the `padding` token.
            eos_token_id (:obj:`int`, `optional`):
                The id of the `end-of-sequence` token.
            output_attentions (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
                returned tensors for more details.
            output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned
                tensors for more details.
            output_scores (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the prediction scores. See ``scores`` under returned tensors for more
                details.
            return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
                model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
        Return:
            :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`,
            :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or :obj:`torch.LongTensor`: A
            :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a
            :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if
            ``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a
            :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if
            ``model.config.is_encoder_decoder=True``.

        Examples::

            >>> from transformers import (
            ...     AutoTokenizer,
            ...     AutoModelForSeq2SeqLM,
            ...     LogitsProcessorList,
            ...     MinLengthLogitsProcessor,
            ...     BeamSearchScorer,
            ... )
            >>> import torch

            >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
            >>> model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large")

            >>> encoder_input_str = "translate English to German: How old are you?"
            >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

            >>> # lets run beam search using 3 beams
            >>> num_beams = 3
            >>> # define decoder start token ids
            >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
            >>> input_ids = input_ids * model.config.decoder_start_token_id

            >>> # add encoder_outputs to model keyword arguments
            >>> model_kwargs = {
            ...     "encoder_outputs": model.get_encoder()(
            ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
            ...     )
            ... }

            >>> # instantiate beam scorer
            >>> beam_scorer = BeamSearchScorer(
            ...     batch_size=1,
            ...     max_length=model.config.max_length,
            ...     num_beams=num_beams,
            ...     device=model.device,
            ... )

            >>> # instantiate logits processors
            >>> logits_processor = LogitsProcessorList([
            ...     MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
            ... ])

            >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

            >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
        """
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        assert (
            num_beams * batch_size == batch_beam_size
        ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            next_token_logits = outputs.logits[:, -1, :]

            # adjust tokens for Bart, *e.g.*
            next_token_logits = self.adjust_logits_during_generation(
                next_token_logits, cur_len=cur_len, max_length=max_length
            )

            next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

            next_token_scores = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            # m = torch.nn.LayerNorm(num_beams * vocab_size)
            # next_token_scores = m(next_token_scores)

            # group scoring: average the candidate scores over all examples in the batch so every example
            # shares one group-level score distribution ...
            next_token_scores_group = (
                torch.sum(next_token_scores, dim=0, keepdim=True).expand(batch_size, -1) / batch_size
            )
            for i in range(next_token_scores.size(0)):
                '''tmin = torch.min(next_token_scores_group[i])
                for j in range(1, len(model_kwargs['decoder_ori_input_ids'][i])):
                    next_token_scores_group[i][model_kwargs['decoder_ori_input_ids'][i][j]] = tmin'''
                # ... but keep each example's own scores for the tokens of its original decoder input
                for t in model_kwargs['decoder_ori_input_ids'][i]:
                    for j in range(num_beams):
                        # if t not in input_ids[i] or t == 1:
                        next_token_scores_group[i][j * vocab_size + t] = next_token_scores[i][j * vocab_size + t]

            next_token_scores, next_tokens = torch.topk(
                next_token_scores_group, 2 * num_beams, dim=1, largest=True, sorted=True
            )
            '''next_token_scores_group = next_token_scores_group.expand(batch_size, -1)
            next_tokens_group = next_tokens_group.expand(batch_size, -1)
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )
            for i in range(next_token_scores.size(0)):
                j1 = 0
                for j in range(next_token_scores.size(1)):
                    if next_tokens[i][j] not in model_kwargs['decoder_ori_input_ids'][i]:
                        next_tokens[i][j] = next_tokens_group[i][j1]
                        j1 += 1
            next_token_scores = next_token_scores_group
            del next_token_scores_group, next_tokens_group'''

            next_indices = next_tokens // vocab_size
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            cur_len = cur_len + 1

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

            if beam_scorer.is_done:
                break
        sequence_outputs = beam_scorer.finalize(
            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        **model_kwargs,
    ):
        r"""
        Generates sequences for models with a language modeling head using beam search decoding.

        Parameters:
            input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an
                empty :obj:`torch.LongTensor` of shape :obj:`(1,)`.
            beam_scorer (:obj:`BeamScorer`):
                A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
                constructed, stored and sorted during generation. For more information, the documentation of
                :class:`~transformers.BeamScorer` should be read.
            logits_processor (:obj:`LogitsProcessorList`, `optional`):
                An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
                :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
                head applied at each generation step.
            max_length (:obj:`int`, `optional`, defaults to 20):
                The maximum length of the sequence to be generated.
            pad_token_id (:obj:`int`, `optional`):
                The id of the `padding` token.
            eos_token_id (:obj:`int`, `optional`):
                The id of the `end-of-sequence` token.
            output_attentions (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
                returned tensors for more details.
            output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned
                tensors for more details.
            output_scores (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the prediction scores. See ``scores`` under returned tensors for more
                details.
            return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
            model_kwargs:
                Additional model specific kwargs that will be forwarded to the :obj:`forward` function of the model.
                If model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
        Return:
            :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput`,
            :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` or :obj:`torch.LongTensor`: A
            :obj:`torch.LongTensor` containing the generated tokens (default behaviour) or a
            :class:`~transformers.generation_utils.BeamSearchDecoderOnlyOutput` if
            ``model.config.is_encoder_decoder=False`` and ``return_dict_in_generate=True`` or a
            :class:`~transformers.generation_utils.BeamSearchEncoderDecoderOutput` if
            ``model.config.is_encoder_decoder=True``.

        Examples::

            >>> from transformers import (
            ...     AutoTokenizer,
            ...     AutoModelForSeq2SeqLM,
            ...     LogitsProcessorList,
            ...     MinLengthLogitsProcessor,
            ...     HammingDiversityLogitsProcessor,
            ...     BeamSearchScorer,
            ... )
            >>> import torch

            >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
            >>> model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large")

            >>> encoder_input_str = "translate English to German: How old are you?"
            >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

            >>> # lets run diverse beam search using 6 beams
            >>> num_beams = 6
            >>> # define decoder start token ids
            >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
            >>> input_ids = input_ids * model.config.decoder_start_token_id

            >>> # add encoder_outputs to model keyword arguments
            >>> model_kwargs = {
            ...     "encoder_outputs": model.get_encoder()(
            ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
            ...     )
            ... }

            >>> # instantiate beam scorer
            >>> beam_scorer = BeamSearchScorer(
            ...     batch_size=1,
            ...     max_length=model.config.max_length,
            ...     num_beams=num_beams,
            ...     device=model.device,
            ...     num_beam_groups=3
            ... )

            >>> # instantiate logits processors
            >>> logits_processor = LogitsProcessorList([
            ...     HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),
            ...     MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
            ... ])

            >>> outputs = model.group_beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

            >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
        """
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        device = input_ids.device

        batch_beam_size, cur_len = input_ids.shape

        assert (
            num_beams * batch_size == batch_beam_size
        ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
        # the same group don't produce the same tokens every time.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while cur_len < max_length:
            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if output_scores:
                # accumulates the processed scores of every beam group at this step
                processed_score = torch.zeros_like(outputs.logits[:, -1, :]).half()  # .float()

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []
                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs.logits[batch_group_indices, -1, :]

                # adjust tokens for Bart, *e.g.*
                next_token_logits = self.adjust_logits_during_generation(
                    next_token_logits, cur_len=cur_len, max_length=max_length
                )

                next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * group_size, vocab_size)
                vocab_size = next_token_scores.shape[-1]

                next_token_scores = logits_processor(
                    group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(
                    next_token_scores
                )

                if output_scores:
                    processed_score[batch_group_indices] = next_token_scores.half()  # .float()

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                ###
                # group scoring: average the candidate scores over all examples in the batch, then restore each
                # example's own scores for the tokens of its original decoder input (`decoder_ori_input_ids`)
                next_token_scores_group = (
                    torch.sum(next_token_scores, dim=0, keepdim=True).expand(batch_size, -1) / batch_size
                )
                for i in range(next_token_scores.size(0)):
                    '''tmin = torch.min(next_token_scores_group[i])
                    for j in range(1, len(model_kwargs['decoder_ori_input_ids'][i])):
                        next_token_scores_group[i][model_kwargs['decoder_ori_input_ids'][i][j]] = tmin'''
                    for t in model_kwargs['decoder_ori_input_ids'][i]:
                        for j in range(group_size):
                            # if t not in input_ids[i] or t == 1:
                            next_token_scores_group[i][j * vocab_size + t] = next_token_scores[i][j * vocab_size + t]

                next_token_scores, next_tokens = torch.topk(
                    next_token_scores_group, 2 * group_size, dim=1, largest=True, sorted=True
                )
                ###
                # next_token_scores, next_tokens = torch.topk(
                #     next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                # )

                next_indices = next_tokens // vocab_size
                next_tokens = next_tokens % vocab_size

                # stateless
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]
                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size)
                )

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (processed_score,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
            cur_len = cur_len + 1

            if beam_scorer.is_done:
                break

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=max_length,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]
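

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hedged example of
# driving the overridden `beam_search` above. It assumes that
# `decoder_ori_input_ids` is a per-example sequence of token ids whose scores
# should stay example-specific while all other candidate scores are averaged
# across the batch; the construction below (tokenizing each source sentence) is
# only an illustration of that assumed format, not a verified recipe.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import BartTokenizer, BeamSearchScorer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    model = BartForConditionalGeneration_GroupBeam.from_pretrained("facebook/bart-large")

    # a small "group" of related inputs (illustrative data)
    sources = ["The cat sat on the mat.", "A cat was sitting on a mat."]
    batch_size, num_beams = len(sources), 3

    # encode the sources once and repeat the encoder states for every beam
    encoder_inputs = tokenizer(sources, return_tensors="pt", padding=True)
    attention_mask = encoder_inputs.attention_mask.repeat_interleave(num_beams, dim=0)
    encoder_outputs = model.get_encoder()(
        encoder_inputs.input_ids.repeat_interleave(num_beams, dim=0),
        attention_mask=attention_mask,
        return_dict=True,
    )

    # one decoder start token per beam hypothesis
    input_ids = torch.full(
        (batch_size * num_beams, 1), model.config.decoder_start_token_id, dtype=torch.long
    )

    beam_scorer = BeamSearchScorer(
        batch_size=batch_size,
        max_length=model.config.max_length,
        num_beams=num_beams,
        device=model.device,
    )

    # assumption: per-example token ids whose scores are kept example-specific
    decoder_ori_input_ids = [tokenizer(s).input_ids for s in sources]

    outputs = model.beam_search(
        input_ids,
        beam_scorer,
        encoder_outputs=encoder_outputs,
        attention_mask=attention_mask,
        decoder_ori_input_ids=decoder_ori_input_ids,
    )
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))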