from typing import List, Optional, Union, Tuple

import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, LogitsProcessorList, LogitsProcessor, PreTrainedTokenizer
from transformers.generation_utils import GenerationMixin, SampleOutput, SampleEncoderDecoderOutput, SampleDecoderOnlyOutput


class SelfDebiasingLogitsProcessor(LogitsProcessor):
    """This class represents a logits processor that applies self-debiasing.

    The batch passed to ``__call__`` is expected to consist of ``batch_size`` regular inputs followed by
    ``num_debiasing_prefixes`` consecutive groups of ``batch_size`` self-debiasing inputs each, so that the
    i-th debiasing variant of regular input ``j`` lives at row ``j + (i + 1) * batch_size``
    (see ``_get_bias_indices``).
    """

    def __init__(self, num_debiasing_prefixes: int, decay_constant: float = 50, epsilon: float = 0.01,
                 debug: bool = False, tokenizer: Optional[PreTrainedTokenizer] = None):
        """
        :param num_debiasing_prefixes: the number of debiasing prefixes used
        :param decay_constant: the decay constant (lambda in the paper)
        :param epsilon: the minimum factor by which each probability is multiplied
        :param debug: whether to print additional debugging output
        :param tokenizer: a tokenizer used to print debugging output
        """
        assert not debug or tokenizer is not None, \
            "If debug=True, a tokenizer must be passed to SelfDebiasingLogitsProcessor()"
        self.num_debiasing_prefixes = num_debiasing_prefixes
        self.decay_constant = decay_constant
        self.epsilon = epsilon
        self.debug = debug
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Debiases ``scores`` in place and returns it.

        ``input_ids`` is unused but required by the ``LogitsProcessor`` interface.
        """
        batch_size = scores.shape[0] // (1 + self.num_debiasing_prefixes)
        for regular_sentence_idx in range(batch_size):
            bias_indices = self._get_bias_indices(regular_sentence_idx, batch_size)
            # with zero debiasing prefixes there is nothing to compare against
            if bias_indices:
                self._debias_scores(scores, regular_sentence_idx, bias_indices)
        return scores

    def _get_bias_indices(self, regular_sentence_idx: int, batch_size: int) -> List[int]:
        """Returns the indices of all self-debiasing inputs for a regular input"""
        return [regular_sentence_idx + (prefix_idx + 1) * batch_size for prefix_idx in range(self.num_debiasing_prefixes)]

    def _debias_scores(self, scores: torch.FloatTensor, regular_sent_idx: int, bias_indices: List[int]) -> None:
        """Partially debiases the given scores considering a single sentence and the corresponding self-debiasing inputs.

        The row ``regular_sent_idx`` is overwritten with the log of the debiased probabilities, and every row in
        ``bias_indices`` is then overwritten with that same row so all variants share one distribution.
        """
        logits_biased = [scores[bias_idx] for bias_idx in bias_indices]

        mask = self._generate_decay_mask(scores[regular_sent_idx], logits_biased)
        scores[regular_sent_idx] = torch.log(self._apply_decay_mask(scores[regular_sent_idx], mask))

        for debiasing_sent_idx in bias_indices:
            scores[debiasing_sent_idx] = scores[regular_sent_idx]

    def _apply_decay_mask(self, logits: torch.Tensor, decay_mask: torch.Tensor) -> torch.Tensor:
        """Applies exponential decay to a tensor of logits.

        Each probability is multiplied by ``max(exp(-decay_constant * alpha), epsilon)``, where ``alpha`` is the
        corresponding entry of ``decay_mask``, and the result is renormalized over the last dimension.

        :param logits: unnormalized scores; the softmax is taken over the last dimension
        :param decay_mask: alpha values, broadcastable to ``logits``
        :return: the decayed, renormalized probabilities (probabilities, not logits)
        """
        probabilities = logits.softmax(dim=-1)
        # clamp keeps the mask's dtype/device instead of allocating (and promoting to) an fp32 tensor every call
        decay_factors = torch.exp(-decay_mask * self.decay_constant).clamp(min=self.epsilon)
        probabilities = probabilities * decay_factors
        # keepdim=True makes the row-wise renormalization correct for batched (2-D) inputs as well
        return probabilities / probabilities.sum(dim=-1, keepdim=True)

    def _generate_decay_mask(self, logits_regular: torch.FloatTensor,
                             logits_biased_list: List[torch.FloatTensor]) -> torch.Tensor:
        """Computes the alpha values (see paper) for each token and stores them in a mask tensor.

        alpha is the amount by which the largest probability under any of the biased inputs exceeds the
        probability under the regular input, floored at 0.

        :param logits_regular: scores for the regular input
        :param logits_biased_list: scores for each self-debiasing input of that regular input
        :return: a non-negative tensor with the same shape as ``logits_regular``
        """
        p_regular = logits_regular.softmax(dim=-1)

        if not logits_biased_list:
            # no debiasing inputs: nothing to decay (alpha = 0 everywhere)
            return torch.zeros_like(p_regular)

        p_biased = None
        for logits_biased in logits_biased_list:
            if p_biased is None:
                p_biased = logits_biased.softmax(dim=-1)
            else:
                p_biased = torch.max(p_biased, logits_biased.softmax(dim=-1))

        if self.debug:
            print(f'== Before Debiasing ==\n'
                  f'Top 5 predictions (regular): {self._get_most_likely_tokens(p_regular, k=5)}\n'
                  f'Top 5 predictions (biased): {self._get_most_likely_tokens(p_biased, k=5)}')

        mask = (p_biased - p_regular).clamp(min=0.0)

        if self.debug:
            p_regular = self._apply_decay_mask(logits_regular, mask)
            print(f'== After Debiasing ==\n'
                  f'Top 5 predictions (regular): {self._get_most_likely_tokens(p_regular, k=5)}')

        return mask

    def _get_most_likely_tokens(self, probabilities_tensor: torch.Tensor, k: int) -> List[Tuple[str, float]]:
        """Returns the most likely tokens according to a tensor of probabilities.

        :param probabilities_tensor: a 1-D tensor of probabilities over the vocabulary
        :param k: the number of tokens to return
        :return: ``(token, probability)`` pairs in descending order of probability
        """
        assert len(probabilities_tensor.shape) == 1
        values, indices = torch.topk(probabilities_tensor, k=k, dim=-1)
        # convert_ids_to_tokens is documented for int / List[int], so hand it plain Python ints
        tokens = self.tokenizer.convert_ids_to_tokens(indices.tolist())
        return list(zip(tokens, [pv.item() for pv in values]))


class SelfDebiasingGPT2LMHeadModel(GPT2LMHeadModel, GenerationMixin):
    """
    This class represents a regular GPT2LMHeadModel that additionally has the capacity to perform self-debiasing. For self-debiasing, the
    init_logits_processor function must be called. Otherwise, this model just performs regular language modeling.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logits_processor = None  # type: Optional[SelfDebiasingLogitsProcessor]

    def init_logits_processor(self, *args, **kwargs):
        """Initialize the logits processor. For a list of arguments, see the self-debiasing logit processor's init function."""
        self.logits_processor = SelfDebiasingLogitsProcessor(*args, **kwargs)

    def _get_logits_processor(self, *args, **kwargs) -> LogitsProcessorList:
        """Returns the parent's processor list, with the self-debiasing processor (if initialized) appended last."""
        logits_processor = super()._get_logits_processor(*args, **kwargs)
        if self.logits_processor is not None:
            logits_processor.append(self.logits_processor)
        return logits_processor

    def beam_sample(self, *args, **kwargs):
        raise NotImplementedError("Beam sampling is not implemented for self-debiasing models")

    def sample(self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None,
               logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None,
               pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
               output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
               output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None,
               **model_kwargs) -> Union[SampleOutput, torch.LongTensor]:
        """
        This is a verbatim copy of the original implementation by huggingface, with a single modification to ensure that a text and all
        corresponding self-debiasing inputs always chose the same token to generate next. This modification is enclosed by the texts
        "BEGIN MODIFICATIONS" and "END MODIFICATIONS", respectively.
        """
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # init sequence length tensors
        sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
            input_ids, max_length
        )

        # auto-regressive generation
        while cur_len < max_length:
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # sample
            probs = F.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # =========================
            # BEGIN MODIFICATIONS
            # the following modification to the sample method is necessary to ensure that each debiasing sentence is continued in the same
            # way as the original sentence
            if self.logits_processor is not None:
                batch_size = next_tokens.shape[0] // (1 + self.logits_processor.num_debiasing_prefixes)
                regular_sentence_indices = range(batch_size)
                for regular_sentence_idx in regular_sentence_indices:
                    debiasing_sentence_indices = self.logits_processor._get_bias_indices(regular_sentence_idx, batch_size)
                    for debiasing_sentence_idx in debiasing_sentence_indices:
                        next_tokens[debiasing_sentence_idx] = next_tokens[regular_sentence_idx]
            # END MODIFICATIONS
            # =========================

            # add code that transforms next_tokens to tokens_to_add
            if eos_token_id is not None:
                assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
                next_tokens = next_tokens * unfinished_sequences + (pad_token_id) * (1 - unfinished_sequences)

            # add token and increase length by one
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            cur_len = cur_len + 1

            # update sequence length
            if eos_token_id is not None:
                sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
                    sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
                )

            # stop when there is an end-of-sequence token in each sentence, or if we exceed the maximum length
            if unfinished_sequences.max() == 0:
                break

            # update model kwargs
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return SampleEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return SampleDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids