import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    BeamSearchScorer,
    StoppingCriteriaList,
    MaxLengthCriteria,
    T5ForConditionalGeneration,
    T5Tokenizer,
)


class EncoderDecoderCalibrator(nn.Module):
    """Wraps a seq2seq model: generates candidate sequences with beam search and
    combines a candidate-level loss with a token-level regularization term."""

    def __init__(self, model, loss, regularization, beam_size, num_candidates,
                 max_length=16, alpha=0.01):
        super().__init__()
        self.model = model
        self.loss = loss
        self.regularization = regularization
        self.alpha = alpha

        assert beam_size >= num_candidates, \
            "num_candidates must be less than or equal to beam_size"
        self.beam_size = beam_size
        self.num_candidates = num_candidates

        self.min_length = 0
        self.max_length = max_length
        self.length_penalty = 1.0

        self.eos_token_id = self.model.config.eos_token_id
        self.decoder_start_token_id = self.model.config.decoder_start_token_id
        self.pad_token_id = self.model.config.pad_token_id

    def generate_candidates(self, encoder_outputs):
        """Run beam search on pre-computed encoder outputs and keep the top
        `num_candidates` hypotheses per input."""
        B, L = encoder_outputs.last_hidden_state.shape[:2]
        beam_scorer = BeamSearchScorer(
            batch_size=B,
            num_beams=self.beam_size,
            device=encoder_outputs.last_hidden_state.device,
            length_penalty=self.length_penalty,
            do_early_stopping=False,
            num_beam_hyps_to_keep=self.num_candidates,
            max_length=self.max_length,
        )
        stopping_criteria = StoppingCriteriaList()
        stopping_criteria.append(
            MaxLengthCriteria(
                max_length=self.max_length,
                max_position_embeddings=self.max_length,
            )
        )
        logits_processor = LogitsProcessorList(
            [
                MinLengthLogitsProcessor(self.min_length, eos_token_id=self.eos_token_id),
            ]
        )

        # Expand the encoder states so that every beam sees its source sequence.
        encoder_outputs.last_hidden_state = \
            encoder_outputs.last_hidden_state.repeat_interleave(self.beam_size, 0)
        # All beams start from the decoder start token.
        input_ids = torch.full(
            (B * self.beam_size, 1),
            self.decoder_start_token_id,
            device=self.model.device,
            dtype=torch.long,
        )
        return self.model.beam_search(
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id,
            output_scores=True,
            output_logits=True,
            output_hidden_states=True,
            stopping_criteria=stopping_criteria,
            return_dict_in_generate=True,
            encoder_outputs=encoder_outputs,
        )

    def forward(self, input_ids, labels, **kwargs):
        B = input_ids.shape[0]  # batch size; input_ids is expected to be (B, L) token ids

        # Encode the inputs once and reuse the encoder states for candidate
        # generation and for the rescoring pass below.
        encoder_outputs = self.model.get_encoder()(input_ids, return_dict=True)
        encoder_hidden = encoder_outputs.last_hidden_state  # un-expanded copy

        candidates = self.generate_candidates(encoder_outputs)
        sequences = candidates.sequences  # (B * num_candidates, seq_len)
        sequences_len = (sequences != self.pad_token_id).sum(-1)

        # Length-normalized sequence scores of the kept candidates.
        transition_scores = self.model.compute_transition_scores(
            sequences, candidates.scores, candidates.beam_indices, normalize_logits=False
        )
        sequences_scores = transition_scores.sum(-1) / sequences_len

        loss = self.loss(
            sequences.view(B, self.num_candidates, -1),
            labels,
            sequences_scores.view((B, -1)),
        )
        del candidates

        # TODO: investigate if we can use the scores returned by the beam search
        # scores_reg = torch.stack(candidates.scores, dim=1)
        # Re-expand the original encoder states per kept candidate so that the
        # rescoring batch matches `sequences` even when beam_size > num_candidates.
        encoder_outputs.last_hidden_state = encoder_hidden.repeat_interleave(self.num_candidates, 0)
        scores_reg = F.log_softmax(
            self.model(decoder_input_ids=sequences, encoder_outputs=encoder_outputs).logits,
            dim=-1,
        )
        loss = loss + self.alpha * self.regularization(
            sequences, scores_reg, labels, encoder_outputs=encoder_outputs
        )
        return {"loss": loss}
    # def generate(self, input_ids, max_length=None, num_return_sequences=1, **kwargs):
    #     if max_length is None:
    #         max_length = self.max_length
    #     encoder_outputs = self.model.get_encoder()(input_ids, return_dict=True)
    #     print(encoder_outputs)
    #     output_ids = self.model.generate(
    #         encoder_outputs=encoder_outputs,
    #         max_length=max_length,
    #         num_return_sequences=num_return_sequences,
    #         do_sample=True,        # Enable sampling
    #         top_k=50,              # Set the top-k sampling parameter
    #         top_p=0.95,            # Set the top-p (nucleus) sampling parameter
    #         num_beams=4,           # Set the number of beams for beam search
    #         early_stopping=True,   # Enable early stopping
    #         **kwargs
    #     )
    #     return output_ids
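

# ------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the training pipeline).
# It assumes a T5 checkpoint and placeholder `loss` / `regularization` callables
# matching the signatures used in `forward` above; plug in the real calibration
# loss and regularizer in practice. Note that `model.beam_search` is a legacy
# transformers API, so this requires a transformers version that still exposes it.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    # Placeholder objectives (hypothetical): a real setup would rank candidates
    # against the labels (loss) and constrain the rescored token distribution
    # (regularization).
    def dummy_loss(candidate_ids, labels, candidate_scores):
        return candidate_scores.mean()

    def dummy_regularization(candidate_ids, log_probs, labels, **kwargs):
        return -log_probs.max(-1).values.mean()

    calibrator = EncoderDecoderCalibrator(
        model,
        loss=dummy_loss,
        regularization=dummy_regularization,
        beam_size=4,
        num_candidates=2,
        max_length=16,
    )

    batch = tokenizer(
        ["translate English to German: The house is wonderful."],
        return_tensors="pt",
    )
    labels = tokenizer(["Das Haus ist wunderbar."], return_tensors="pt").input_ids

    out = calibrator(batch.input_ids, labels)
    out["loss"].backward()
    print(out["loss"].item())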