# Self-Debiasing / generation.py
# kunwarsaaim's picture
# initial commit
# 92fd951
from typing import List, Optional, Union, Tuple
import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, LogitsProcessorList, LogitsProcessor, PreTrainedTokenizer
from transformers.generation_utils import GenerationMixin, SampleOutput, SampleEncoderDecoderOutput, SampleDecoderOnlyOutput
class SelfDebiasingLogitsProcessor(LogitsProcessor):
    """
    This class represents a logits processor that applies self-debiasing.

    The rows of the score tensor are interpreted as ``1 + num_debiasing_prefixes`` consecutive groups of equal size: the first group
    holds the regular inputs, and group ``p + 1`` holds the same inputs combined with debiasing prefix ``p``. The regular input at
    row ``i`` therefore corresponds to the self-debiasing inputs at rows ``i + (p + 1) * batch_size``.
    """

    def __init__(self, num_debiasing_prefixes: int, decay_constant: float = 50, epsilon: float = 0.01, debug: bool = False,
                 tokenizer: Optional[PreTrainedTokenizer] = None):
        """
        :param num_debiasing_prefixes: the number of debiasing prefixes used
        :param decay_constant: the decay constant (lambda in the paper)
        :param epsilon: the minimum factor by which each probability is multiplied
        :param debug: whether to print additional debugging output
        :param tokenizer: a tokenizer used to print debugging output
        """
        assert not debug or tokenizer, "If debug=True, a tokenizer must be passed to SelfDebiasingLogitsProcessor()"
        self.num_debiasing_prefixes = num_debiasing_prefixes
        self.decay_constant = decay_constant
        self.epsilon = epsilon
        self.debug = debug
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """
        Debiases the scores of every regular input using its self-debiasing inputs.

        :param input_ids: not used; only present to match the logits processor call signature
        :param scores: the next-token scores of all regular and self-debiasing inputs; modified in place
        :return: the same ``scores`` tensor that was passed in
        """
        batch_size = scores.shape[0] // (1 + self.num_debiasing_prefixes)
        for regular_sentence_idx in range(batch_size):
            bias_indices = self._get_bias_indices(regular_sentence_idx, batch_size)
            # without any debiasing prefix there is nothing to compare against, so the scores are left untouched
            if bias_indices:
                self._debias_scores(scores, regular_sentence_idx, bias_indices)
        return scores

    def _get_bias_indices(self, regular_sentence_idx: int, batch_size: int) -> List[int]:
        """Returns the indices of all self-debiasing inputs for a regular input"""
        return [regular_sentence_idx + (prefix_idx + 1) * batch_size for prefix_idx in range(self.num_debiasing_prefixes)]

    def _debias_scores(self, scores: torch.FloatTensor, regular_sent_idx: int, bias_indices: List[int]) -> None:
        """
        Partially debiases the given scores considering a single sentence and the corresponding self-debiasing inputs.

        The row of the regular input is replaced by the logarithm of its debiased probabilities, and this row is then copied to
        every row listed in ``bias_indices``. ``scores`` is modified in place.
        """
        logits_biased = [scores[bias_idx] for bias_idx in bias_indices]
        # the mask must be computed before the regular row is overwritten below
        mask = self._generate_decay_mask(scores[regular_sent_idx], logits_biased)
        scores[regular_sent_idx] = torch.log(self._apply_decay_mask(scores[regular_sent_idx], mask))
        for debiasing_sent_idx in bias_indices:
            scores[debiasing_sent_idx] = scores[regular_sent_idx]

    def _apply_decay_mask(self, logits: torch.Tensor, decay_mask: torch.Tensor) -> torch.Tensor:
        """
        Applies exponential decay to a tensor of logits.

        :param logits: the logits to debias; the last dimension is the vocabulary dimension
        :param decay_mask: the alpha values as returned by ``_generate_decay_mask``
        :return: the renormalized probabilities (not logits) after scaling each one by ``max(exp(-decay_constant * alpha), epsilon)``
        """
        probabilities = logits.softmax(dim=-1)
        decay_factors = torch.exp(- decay_mask * self.decay_constant)
        # never scale a probability by less than epsilon
        decay_factors = torch.max(decay_factors, torch.tensor([self.epsilon], device=decay_factors.device))
        probabilities = probabilities * decay_factors
        # keepdim=True lets the normalizer broadcast against the vocabulary dimension for batched inputs as well as for 1-D inputs
        return probabilities / probabilities.sum(dim=-1, keepdim=True)

    def _generate_decay_mask(self, logits_regular: torch.FloatTensor, logits_biased_list: List[torch.FloatTensor]) -> torch.Tensor:
        """
        Computes the alpha values (see paper) for each token and stores them in a mask tensor.

        For each token, alpha is ``max(p_biased - p_regular, 0)`` where ``p_biased`` is the element-wise maximum probability over
        all self-debiasing inputs. If ``logits_biased_list`` is empty, an all-zero mask (i.e., no decay) is returned.
        """
        p_regular = logits_regular.softmax(dim=-1)

        if not logits_biased_list:
            return torch.zeros_like(p_regular)

        p_biased = logits_biased_list[0].softmax(dim=-1)
        for logits_biased in logits_biased_list[1:]:
            p_biased = torch.max(p_biased, logits_biased.softmax(dim=-1))

        if self.debug:
            print(f'== Before Debiasing ==\n'
                  f'Top 5 predictions (regular): {self._get_most_likely_tokens(p_regular, k=5)}\n'
                  f'Top 5 predictions (biased): {self._get_most_likely_tokens(p_biased, k=5)}')

        # only tokens that become more likely under a debiasing prefix are penalized
        mask = torch.max(p_biased - p_regular, torch.tensor([0.], device=p_regular.device))

        if self.debug:
            p_regular = self._apply_decay_mask(logits_regular, mask)
            print(f'== After Debiasing ==\n'
                  f'Top 5 predictions (regular): {self._get_most_likely_tokens(p_regular, k=5)}')

        return mask

    def _get_most_likely_tokens(self, probabilities_tensor: torch.Tensor, k: int) -> List[Tuple[str, float]]:
        """
        Returns the most likely tokens according to a tensor of probabilities.

        :param probabilities_tensor: a one-dimensional tensor of probabilities
        :param k: the number of tokens to return; capped at the number of entries in ``probabilities_tensor``
        :return: a list of (token, probability) pairs in the order returned by ``torch.topk``
        """
        assert len(probabilities_tensor.shape) == 1
        k = min(k, probabilities_tensor.shape[0])
        values, indices = torch.topk(probabilities_tensor, k=k, dim=-1)
        # hand plain Python ints to the tokenizer instead of a tensor
        tokens = self.tokenizer.convert_ids_to_tokens(indices.tolist())
        return list(zip(tokens, values.tolist()))
class SelfDebiasingGPT2LMHeadModel(GPT2LMHeadModel, GenerationMixin):
    """
    This class represents a regular GPT2LMHeadModel that additionally has the capacity to perform self-debiasing. For self-debiasing, the
    init_logits_processor function must be called. Otherwise, this model just performs regular language modeling.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # stays None (regular language modeling) until init_logits_processor() is called
        self.logits_processor: Optional[SelfDebiasingLogitsProcessor] = None

    def init_logits_processor(self, *args, **kwargs):
        """Initialize the logits processor. For a list of arguments, see the self-debiasing logit processor's init function."""
        self.logits_processor = SelfDebiasingLogitsProcessor(*args, **kwargs)

    def _get_logits_processor(self, *args, **kwargs) -> LogitsProcessorList:
        """Returns the logits processors built by the parent class, with the self-debiasing processor (if any) appended last."""
        logits_processor = super()._get_logits_processor(*args, **kwargs)
        if self.logits_processor is not None:
            logits_processor.append(self.logits_processor)
        return logits_processor

    def beam_sample(self, *args, **kwargs):
        """Always raises a NotImplementedError, as beam sampling is not supported by this model."""
        raise NotImplementedError("Beam sampling is not implemented for self-debiasing models")

    def sample(self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None,
               logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None,
               eos_token_id: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
               output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, **model_kwargs) -> Union[
        SampleOutput, torch.LongTensor]:
        """
        This is a verbatim copy of the original implementation by huggingface, with a single modification to ensure that a text and all
        corresponding self-debiasing inputs always choose the same token to generate next. This modification is enclosed by the texts
        "BEGIN MODIFICATIONS" and "END MODIFICATIONS", respectively.

        Arguments that are None fall back to the corresponding value in ``self.config`` (``logits_processor`` and ``logits_warper``
        fall back to empty lists).

        :return: a SampleEncoderDecoderOutput or SampleDecoderOnlyOutput if ``return_dict_in_generate`` is set, otherwise the tensor
                 of generated token ids
        """
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # init sequence length tensors
        sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
            input_ids, max_length
        )

        # auto-regressive generation
        while cur_len < max_length:
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # sample
            probs = F.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # =========================
            # BEGIN MODIFICATIONS
            # the following modification to the sample method is necessary to ensure that each debiasing sentence is continued in the same
            # way as the original sentence
            if self.logits_processor is not None:
                num_prefixes = self.logits_processor.num_debiasing_prefixes
                batch_size = next_tokens.shape[0] // (1 + num_prefixes)
                # Rows [0, batch_size) are the regular sentences; the debiasing sentence for regular sentence i and prefix p is at
                # row i + (p + 1) * batch_size (the layout defined by SelfDebiasingLogitsProcessor._get_bias_indices). Tiling the
                # regular tokens once per prefix thus overwrites every debiasing row with its regular sentence's token in one step.
                next_tokens[batch_size:(1 + num_prefixes) * batch_size] = next_tokens[:batch_size].repeat(num_prefixes)
            # END MODIFICATIONS
            # =========================

            # add code that transforms next_tokens to tokens_to_add
            if eos_token_id is not None:
                assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
                next_tokens = next_tokens * unfinished_sequences + (pad_token_id) * (1 - unfinished_sequences)

            # add token and increase length by one
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            cur_len = cur_len + 1

            # update sequence length
            if eos_token_id is not None:
                sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
                    sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
                )

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if unfinished_sequences.max() == 0:
                break

            # update model kwargs
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return SampleEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return SampleDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids