kunwarsaaim committed
Commit 92fd951
1 Parent(s): 3abae59

initial commit

Files changed (3)
  1. app.py +53 -0
  2. generation.py +252 -0
  3. modeling.py +257 -0
app.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ from modeling import GPT2Wrapper
+ import gradio as gr
+
+
+ DEBIASING_PREFIXES = [
+     "The following text contains rude, disrespectful, or unreasonable language:\n",
+     "The following text contains very hateful, aggressive, disrespectful language:\n",
+     "The following text contains sexually explicit language:\n",
+     "The following text contains a threat:\n",
+     "The following text contains negative or hateful language targeting someone because of their identity:\n",
+     "The following text contains swear words, curse words, or other obscene or profane language:\n"
+ ]
+
+ DEBIASING_KEYWORDS = [
+     "(rude) ", "(sexually explicit) ", "(hateful) ", "(aggressive) ", "(racist) ", "(threat) ", "(violent) ", "(sexist) "
+ ]
+
+ def debias(prompt, model, use_prefix, max_length=50, num_beam=3):
+     """
+     Debiasing inference function.
+     :param prompt: the prompt to be continued with and without self-debiasing
+     :param model: the name of the GPT-2 model to load
+     :param use_prefix: whether to use debiasing prefixes or debiasing keywords
+     :param max_length: the maximum length of the generated continuation
+     :param num_beam: the number of beams for beam search
+     :return: the debiased output text and the regular (biased) output text
+     """
+     wrapper = GPT2Wrapper(model_name=str(model), use_cuda=False)
+     if use_prefix == 'Prefixes':
+         debiasing_prefixes = DEBIASING_PREFIXES
+     else:
+         debiasing_prefixes = DEBIASING_KEYWORDS
+
+     output_text = wrapper.generate_self_debiasing([prompt], debiasing_prefixes=debiasing_prefixes, min_length=20, max_length=max_length, num_beam=num_beam, no_repeat_ngram_size=2)
+     output_text = output_text[0]
+
+     debiasing_prefixes = []
+     biased_text = wrapper.generate_self_debiasing([prompt], debiasing_prefixes=debiasing_prefixes, min_length=20, max_length=max_length, num_beam=num_beam, no_repeat_ngram_size=2)
+     biased_text = biased_text[0]
+     return output_text, biased_text
+
+
+ demo = gr.Interface(
+     debias,
+     inputs=[gr.Textbox(),
+             gr.Radio(choices=['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'], value='gpt2'),
+             gr.Radio(choices=['Prefixes', 'Keywords'], value='Prefixes', label='Use Debiasing Prefixes or Keywords'),
+             gr.Number(value=50, label='Max output length'),
+             gr.Number(value=3, label='Number of beams for beam search')],
+     outputs=[gr.Textbox(label="Debiased text"), gr.Textbox(label="Biased text")]
+ )
+
+ if __name__ == '__main__':
+     demo.launch()
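
Note: the debias function above can also be exercised without the Gradio UI. The snippet below is a minimal sketch under the assumption that this repository's app.py and modeling.py are importable; the prompt string is made up, and importing app builds (but does not launch) the Gradio interface.

    from app import DEBIASING_PREFIXES
    from modeling import GPT2Wrapper

    # Load the small GPT-2 checkpoint on CPU, as the Space does.
    wrapper = GPT2Wrapper(model_name='gpt2', use_cuda=False)
    prompt = "The weather today is"  # illustrative prompt

    # Debiased continuation: the toxicity descriptions act as self-debiasing prefixes.
    debiased = wrapper.generate_self_debiasing(
        [prompt], debiasing_prefixes=DEBIASING_PREFIXES,
        min_length=20, max_length=50, no_repeat_ngram_size=2)[0]

    # Regular continuation: an empty prefix list disables the debiasing logits processor.
    regular = wrapper.generate_self_debiasing(
        [prompt], debiasing_prefixes=[],
        min_length=20, max_length=50, no_repeat_ngram_size=2)[0]

    print("Debiased:", debiased)
    print("Regular:", regular)
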
generation.py ADDED
@@ -0,0 +1,252 @@
+ from typing import List, Optional, Union, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from transformers import GPT2LMHeadModel, LogitsProcessorList, LogitsProcessor, PreTrainedTokenizer
+ from transformers.generation_utils import GenerationMixin, SampleOutput, SampleEncoderDecoderOutput, SampleDecoderOnlyOutput
+
+
+ class SelfDebiasingLogitsProcessor(LogitsProcessor):
+     """This class represents a logits processor that applies self-debiasing."""
+
+     def __init__(self, num_debiasing_prefixes: int, decay_constant: float = 50, epsilon: float = 0.01, debug: bool = False,
+                  tokenizer: Optional[PreTrainedTokenizer] = None):
+         """
+         :param num_debiasing_prefixes: the number of debiasing prefixes used
+         :param decay_constant: the decay constant (lambda in the paper)
+         :param epsilon: the minimum factor by which each probability is multiplied
+         :param debug: whether to print additional debugging output
+         :param tokenizer: a tokenizer used to print debugging output
+         """
+         assert not debug or tokenizer, "If debug=True, a tokenizer must be passed to SelfDebiasingLogitsProcessor()"
+         self.num_debiasing_prefixes = num_debiasing_prefixes
+         self.decay_constant = decay_constant
+         self.epsilon = epsilon
+         self.debug = debug
+         self.tokenizer = tokenizer
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+         batch_size = scores.shape[0] // (1 + self.num_debiasing_prefixes)
+         regular_sentence_indices = range(batch_size)
+         for regular_sentence_idx in regular_sentence_indices:
+             bias_indices = self._get_bias_indices(regular_sentence_idx, batch_size)
+             if bias_indices:
+                 self._debias_scores(scores, regular_sentence_idx, bias_indices)
+         return scores
+
+     def _get_bias_indices(self, regular_sentence_idx: int, batch_size: int) -> List[int]:
+         """Returns the indices of all self-debiasing inputs for a regular input"""
+         return [regular_sentence_idx + (prefix_idx + 1) * batch_size for prefix_idx in range(self.num_debiasing_prefixes)]
+
+     def _debias_scores(self, scores: torch.FloatTensor, regular_sent_idx: int, bias_indices: List[int]) -> None:
+         """Partially debiases the given scores considering a single sentence and the corresponding self-debiasing inputs"""
+         logits_biased = [scores[bias_idx] for bias_idx in bias_indices]
+
+         mask = self._generate_decay_mask(scores[regular_sent_idx], logits_biased)
+         scores[regular_sent_idx] = torch.log(self._apply_decay_mask(scores[regular_sent_idx], mask))
+
+         for debiasing_sent_idx in bias_indices:
+             scores[debiasing_sent_idx] = scores[regular_sent_idx]
+
+     def _apply_decay_mask(self, logits: torch.Tensor, decay_mask: torch.Tensor) -> torch.Tensor:
+         """Applies exponential decay to a tensor of logits"""
+         probabilities = logits.softmax(dim=-1)
+         decay_mask = torch.exp(-decay_mask * self.decay_constant)
+         decay_mask = torch.max(decay_mask, torch.tensor([self.epsilon], device=decay_mask.device))
+         probabilities = probabilities * decay_mask
+         probabilities = probabilities / probabilities.sum(dim=-1)
+         return probabilities
+
+     def _generate_decay_mask(self, logits_regular: torch.FloatTensor, logits_biased_list: List[torch.FloatTensor]) -> torch.Tensor:
+         """Computes the alpha values (see paper) for each token and stores them in a mask tensor"""
+         p_regular = logits_regular.softmax(dim=-1)
+         p_biased = None
+
+         for logits_biased in logits_biased_list:
+             if p_biased is None:
+                 p_biased = logits_biased.softmax(dim=-1)
+             else:
+                 p_biased = torch.max(p_biased, logits_biased.softmax(dim=-1))
+
+         if self.debug:
+             print(f'== Before Debiasing ==\n'
+                   f'Top 5 predictions (regular): {self._get_most_likely_tokens(p_regular, k=5)}\n'
+                   f'Top 5 predictions (biased): {self._get_most_likely_tokens(p_biased, k=5)}')
+
+         mask = torch.max(p_biased - p_regular, torch.tensor([0.], device=p_regular.device))
+
+         if self.debug:
+             p_regular = self._apply_decay_mask(logits_regular, mask)
+             print(f'== After Debiasing ==\n'
+                   f'Top 5 predictions (regular): {self._get_most_likely_tokens(p_regular, k=5)}')
+
+         return mask
+
+     def _get_most_likely_tokens(self, probabilities_tensor: torch.Tensor, k: int) -> List[Tuple[str, float]]:
+         """Returns the most likely tokens according to a tensor of probabilities"""
+         assert len(probabilities_tensor.shape) == 1
+         values, indices = torch.topk(probabilities_tensor, k=k, dim=-1)
+         tokens = self.tokenizer.convert_ids_to_tokens(indices)
+         return list(zip(tokens, [pv.item() for pv in values]))
+
+
+ class SelfDebiasingGPT2LMHeadModel(GPT2LMHeadModel, GenerationMixin):
+     """
+     This class represents a regular GPT2LMHeadModel that additionally has the capacity to perform self-debiasing. For self-debiasing, the
+     init_logits_processor function must be called. Otherwise, this model just performs regular language modeling.
+     """
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.logits_processor = None  # type: Optional[SelfDebiasingLogitsProcessor]
+
+     def init_logits_processor(self, *args, **kwargs):
+         """Initialize the logits processor. For a list of arguments, see the self-debiasing logits processor's init function."""
+         self.logits_processor = SelfDebiasingLogitsProcessor(*args, **kwargs)
+
+     def _get_logits_processor(self, *args, **kwargs) -> LogitsProcessorList:
+         logits_processor = super()._get_logits_processor(*args, **kwargs)
+         if self.logits_processor is not None:
+             logits_processor.append(self.logits_processor)
+         return logits_processor
+
+     def beam_sample(self, *args, **kwargs):
+         raise NotImplementedError("Beam sampling is not implemented for self-debiasing models")
+
+     def sample(self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None,
+                logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None,
+                eos_token_id: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
+                output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None,
+                **model_kwargs) -> Union[SampleOutput, torch.LongTensor]:
+         """
+         This is a verbatim copy of the original implementation by Hugging Face, with a single modification to ensure that a text and all
+         corresponding self-debiasing inputs always choose the same token to generate next. This modification is enclosed by the comments
+         "BEGIN MODIFICATIONS" and "END MODIFICATIONS", respectively.
+         """
+         # init values
+         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
+         max_length = max_length if max_length is not None else self.config.max_length
+         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+         output_scores = output_scores if output_scores is not None else self.config.output_scores
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict_in_generate = (
+             return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
+         )
+
+         # init attention / hidden states / scores tuples
+         scores = () if (return_dict_in_generate and output_scores) else None
+         decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+         decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+         if return_dict_in_generate and self.config.is_encoder_decoder:
+             encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+             encoder_hidden_states = (
+                 model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+             )
+
+         # init sequence length tensors
+         sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
+             input_ids, max_length
+         )
+
+         # auto-regressive generation
+         while cur_len < max_length:
+             # prepare model inputs
+             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+             # forward pass to get next token
+             outputs = self(
+                 **model_inputs,
+                 return_dict=True,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+             )
+             next_token_logits = outputs.logits[:, -1, :]
+
+             # pre-process distribution
+             next_token_scores = logits_processor(input_ids, next_token_logits)
+             next_token_scores = logits_warper(input_ids, next_token_scores)
+
+             # Store scores, attentions and hidden_states when required
+             if return_dict_in_generate:
+                 if output_scores:
+                     scores += (next_token_scores,)
+                 if output_attentions:
+                     decoder_attentions += (
+                         (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+                     )
+
+                 if output_hidden_states:
+                     decoder_hidden_states += (
+                         (outputs.decoder_hidden_states,)
+                         if self.config.is_encoder_decoder
+                         else (outputs.hidden_states,)
+                     )
+
+             # sample
+             probs = F.softmax(next_token_scores, dim=-1)
+             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+             # =========================
+             # BEGIN MODIFICATIONS
+             # the following modification to the sample method is necessary to ensure that each debiasing sentence is continued in the same
+             # way as the original sentence
+             if self.logits_processor is not None:
+                 batch_size = next_tokens.shape[0] // (1 + self.logits_processor.num_debiasing_prefixes)
+                 regular_sentence_indices = range(batch_size)
+                 for regular_sentence_idx in regular_sentence_indices:
+                     debiasing_sentence_indices = self.logits_processor._get_bias_indices(regular_sentence_idx, batch_size)
+                     for debiasing_sentence_idx in debiasing_sentence_indices:
+                         next_tokens[debiasing_sentence_idx] = next_tokens[regular_sentence_idx]
+             # END MODIFICATIONS
+             # =========================
+
+             # replace the next token of finished sequences with the padding token
+             if eos_token_id is not None:
+                 assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
+                 next_tokens = next_tokens * unfinished_sequences + (pad_token_id) * (1 - unfinished_sequences)
+
+             # add token and increase length by one
+             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+             cur_len = cur_len + 1
+
+             # update sequence length
+             if eos_token_id is not None:
+                 sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
+                     sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
+                 )
+
+             # stop when there is a </s> in each sentence, or if we exceed the maximum length
+             if unfinished_sequences.max() == 0:
+                 break
+
+             # update model kwargs
+             model_kwargs = self._update_model_kwargs_for_generation(
+                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+             )
+
+         if return_dict_in_generate:
+             if self.config.is_encoder_decoder:
+                 return SampleEncoderDecoderOutput(
+                     sequences=input_ids,
+                     scores=scores,
+                     encoder_attentions=encoder_attentions,
+                     encoder_hidden_states=encoder_hidden_states,
+                     decoder_attentions=decoder_attentions,
+                     decoder_hidden_states=decoder_hidden_states,
+                 )
+             else:
+                 return SampleDecoderOnlyOutput(
+                     sequences=input_ids,
+                     scores=scores,
+                     attentions=decoder_attentions,
+                     hidden_states=decoder_hidden_states,
+                 )
+         else:
+             return input_ids
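
Note: the arithmetic inside SelfDebiasingLogitsProcessor can be checked in isolation. For each token, _generate_decay_mask measures how much more probable the token is under the debiasing-prefixed inputs than under the regular input, and _apply_decay_mask turns that difference into a multiplicative penalty exp(-decay_constant * delta), floored at epsilon, before renormalizing. The sketch below reproduces those two steps on a made-up three-token vocabulary (all numbers are purely illustrative):

    import torch

    p_regular = torch.tensor([0.5, 0.3, 0.2])   # next-token distribution of the regular input
    p_biased = torch.tensor([0.2, 0.7, 0.1])    # elementwise max over the prefixed inputs

    decay_constant, epsilon = 50.0, 0.01
    delta = torch.clamp(p_biased - p_regular, min=0.0)                    # mask from _generate_decay_mask
    scale = torch.clamp(torch.exp(-decay_constant * delta), min=epsilon)  # decay from _apply_decay_mask
    p_debiased = p_regular * scale
    p_debiased = p_debiased / p_debiased.sum()

    print(p_debiased)  # token 1, favoured only by the biased inputs, is pushed close to zero
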
modeling.py ADDED
@@ -0,0 +1,257 @@
+ import itertools
+ from abc import ABC, abstractmethod
+ from typing import List, Optional, Tuple
+
+ import torch
+ from torch.nn import CrossEntropyLoss
+ from transformers import T5Tokenizer, T5ForConditionalGeneration, GPT2Tokenizer, PreTrainedTokenizer, PreTrainedModel
+
+ from generation import SelfDebiasingGPT2LMHeadModel
+
+
+ class ModelWrapper(ABC):
+     """
+     This class represents a wrapper for a pretrained language model that provides some high-level functions, including zero-shot
+     classification using cloze questions and the generation of texts with self-debiasing.
+     """
+
+     def __init__(self, use_cuda: bool = True):
+         """
+         :param use_cuda: whether to use CUDA
+         """
+         self._device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
+         self._tokenizer = None  # type: Optional[PreTrainedTokenizer]
+         self._model = None  # type: Optional[PreTrainedModel]
+
+     def query_model(self, input_text: str) -> torch.FloatTensor:
+         """For a given input text, returns the probability distribution over possible next tokens."""
+         return self.query_model_batch([input_text])[0]
+
+     @abstractmethod
+     def query_model_batch(self, input_texts: List[str]) -> torch.FloatTensor:
+         """For a batch of input texts, returns the probability distribution over possible next tokens."""
+         pass
+
+     @abstractmethod
+     def generate(self, input_text: str, **kwargs) -> str:
+         """Generates a continuation for a given input text."""
+         pass
+
+     @abstractmethod
+     def generate_self_debiasing(self, input_texts: List[str], debiasing_prefixes: List[str], decay_constant: float = 50,
+                                 epsilon: float = 0.01, debug: bool = False, **kwargs) -> List[str]:
+         """
+         Generates continuations for the given input texts with self-debiasing.
+         :param input_texts: the input texts to generate continuations for
+         :param debiasing_prefixes: the debiasing prefixes to be used
+         :param decay_constant: the decay constant (lambda in the paper)
+         :param epsilon: the minimum factor by which each probability is multiplied
+         :param debug: whether to print additional debugging output
+         :param kwargs: further arguments are passed on to the original generate function
+         :return: the list of generated continuations
+         """
+         pass
+
+     @abstractmethod
+     def compute_loss(self, input_ids: torch.LongTensor, labels: torch.LongTensor) -> torch.Tensor:
+         """Computes cross-entropy loss for the given input ids and corresponding labels."""
+         pass
+
+     @abstractmethod
+     def compute_loss_self_debiasing(self, input_ids: torch.Tensor, trg_len: int, debiasing_prefixes: List[str], decay_constant: float = 50,
+                                     epsilon: float = 0.01, debug: bool = False) -> torch.Tensor:
+         """
+         Computes cross-entropy loss for the given input ids with self-debiasing.
+         :param input_ids: the input ids
+         :param trg_len: only the last trg_len tokens are considered for computing the loss
+         :param debiasing_prefixes: the debiasing prefixes to be used
+         :param decay_constant: the decay constant (lambda in the paper)
+         :param epsilon: the minimum factor by which each probability is multiplied
+         :param debug: whether to print additional debugging output
+         :return: the cross-entropy loss
+         """
+         pass
+
+     def get_token_probability_distribution(self, input_texts: List[str], output_choices: List[str]) -> List[List[Tuple[str, float]]]:
+         """
+         For a batch of input texts, returns the probability distribution over possible next tokens considering only the given list of
+         output choices.
+         :param input_texts: the input texts
+         :param output_choices: the allowed output choices (must correspond to single tokens in the model's vocabulary)
+         :return: a list of lists, where output[i][j] is an (output, probability) tuple for the ith input and jth output choice
+         """
+         output_choice_ids = []
+         kwargs = {'add_prefix_space': True} if isinstance(self, GPT2Wrapper) else {}
+         for word in output_choices:
+             tokens = self._tokenizer.tokenize(word, **kwargs)
+             assert len(tokens) == 1, f"Word {word} consists of multiple tokens: {tokens}"
+             assert tokens[0] not in self._tokenizer.all_special_tokens, f"Word {word} corresponds to a special token: {tokens[0]}"
+             token_id = self._tokenizer.convert_tokens_to_ids(tokens)[0]
+             output_choice_ids.append(token_id)
+
+         logits = self.query_model_batch(input_texts)
+         result = []
+
+         for idx, _ in enumerate(input_texts):
+             output_probabilities = logits[idx][output_choice_ids].softmax(dim=0)
+             choices_with_probabilities = list(zip(output_choices, (prob.item() for prob in output_probabilities)))
+             result.append(choices_with_probabilities)
+
+         return result
+
+
+ class T5Wrapper(ModelWrapper):
+     """A wrapper for the T5 model"""
+
+     def __init__(self, model_name: str = "google/t5-v1_1-xl", use_cuda: bool = True):
+         """
+         :param model_name: the name of the pretrained T5 model (default: "google/t5-v1_1-xl")
+         :param use_cuda: whether to use CUDA
+         """
+         super().__init__(use_cuda=use_cuda)
+         self._tokenizer = T5Tokenizer.from_pretrained(model_name)
+         self._model = T5ForConditionalGeneration.from_pretrained(model_name)
+         if use_cuda:
+             self._model.parallelize()
+
+     def query_model_batch(self, input_texts: List[str]):
+         assert all('<extra_id_0>' in input_text for input_text in input_texts)
+         output_texts = ['<extra_id_0>'] * len(input_texts)
+         inputs = self._tokenizer.batch_encode_plus(input_texts, padding=True, return_tensors='pt')
+         inputs = {key: val.to(self._device) for key, val in inputs.items()}
+         output_ids = self._tokenizer.batch_encode_plus(output_texts, return_tensors='pt')['input_ids'].to(self._device)
+         return self._model(labels=output_ids, **inputs)['logits'][:, 1, :]
+
+     def generate(self, input_text: str, **kwargs):
+         assert '<extra_id_0>' in input_text
+         input_ids = self._tokenizer.encode(input_text, return_tensors='pt').to(self._device)
+         output_ids = self._model.generate(input_ids, **kwargs)[0]
+         return self._tokenizer.decode(output_ids)
+
+     def generate_self_debiasing(self, input_texts: List[str], debiasing_prefixes: List[str], decay_constant: float = 50,
+                                 epsilon: float = 0.01, debug: bool = False, **kwargs) -> List[str]:
+         raise NotImplementedError()
+
+     def compute_loss(self, input_ids: torch.LongTensor, labels: torch.LongTensor) -> torch.Tensor:
+         raise NotImplementedError()
+
+     def compute_loss_self_debiasing(self, input_ids: torch.Tensor, trg_len: int, debiasing_prefixes: List[str], decay_constant: float = 50,
+                                     epsilon: float = 0.01, debug: bool = False) -> torch.Tensor:
+         raise NotImplementedError()
+
+
+
143
+ class GPT2Wrapper(ModelWrapper):
144
+
145
+ def __init__(self, model_name: str = "gpt2-xl", use_cuda: bool = True):
146
+ """
147
+ :param model_name: the name of the pretrained GPT2 model (default: "gpt2-xl")
148
+ :param use_cuda: whether to use CUDA
149
+ """
150
+ super().__init__(use_cuda=use_cuda)
151
+ self._tokenizer = GPT2Tokenizer.from_pretrained(model_name)
152
+ self._model = SelfDebiasingGPT2LMHeadModel.from_pretrained(model_name) # type: SelfDebiasingGPT2LMHeadModel
153
+ if use_cuda:
154
+ self._model.parallelize()
155
+ self._tokenizer.pad_token = self._tokenizer.eos_token
156
+ self._model.config.pad_token_id = self._tokenizer.eos_token_id
157
+
158
+ def query_model_batch(self, input_texts: List[str]):
159
+ inputs = self._tokenizer.batch_encode_plus(input_texts, padding=True, return_tensors='pt')
160
+ inputs = {key: val.to(self._device) for key, val in inputs.items()}
161
+ output_indices = inputs['attention_mask'].sum(dim=1) - 1
162
+ output = self._model(**inputs)['logits']
163
+ return torch.stack([output[example_idx, last_word_idx, :] for example_idx, last_word_idx in enumerate(output_indices)])
164
+
165
+ def generate(self, input_text: str, **kwargs):
166
+ input_ids = self._tokenizer.encode(input_text, return_tensors='pt').to(self._device)
167
+ output_ids = self._model.generate(input_ids, **kwargs)[0]
168
+ return self._tokenizer.decode(output_ids)
169
+
170
+ def generate_self_debiasing(self, input_texts: List[str], debiasing_prefixes: List[str], decay_constant: float = 50,
171
+ epsilon: float = 0.01, debug: bool = False, min_length: int = None, max_length: int = None,
172
+ **kwargs) -> List[str]:
173
+
174
+ self._model.init_logits_processor(num_debiasing_prefixes=len(debiasing_prefixes), decay_constant=decay_constant, epsilon=epsilon,
175
+ debug=debug, tokenizer=self._tokenizer)
176
+ inputs = input_texts.copy()
177
+ for debiasing_prefix in debiasing_prefixes:
178
+ for input_text in input_texts:
179
+ inputs += [debiasing_prefix + input_text]
180
+
181
+ inputs = self._tokenizer.batch_encode_plus(inputs, padding=True, return_tensors='pt')
182
+ inputs['attention_mask'] = torch.flip(inputs['attention_mask'], dims=[1])
183
+ shifts = inputs['attention_mask'].shape[-1] - inputs['attention_mask'].sum(dim=-1)
184
+ for batch_idx in range(inputs['input_ids'].shape[0]):
185
+ inputs['input_ids'][batch_idx] = inputs['input_ids'][batch_idx].roll(shifts[batch_idx].item())
186
+
187
+ inputs = {k: v.to(self._device) for k, v in inputs.items()}
188
+ input_length = inputs['input_ids'].shape[1]
189
+ if min_length is not None:
190
+ min_length = min_length + input_length
191
+ if max_length is not None:
192
+ max_length = max_length + input_length
193
+
194
+ output_ids = self._model.generate(**inputs, min_length=min_length, max_length=max_length, **kwargs)
195
+
196
+ batch_size = output_ids.shape[0] // (1 + len(debiasing_prefixes))
197
+ output_ids = output_ids[:batch_size, inputs['input_ids'].shape[1]:]
198
+ return self._tokenizer.batch_decode(output_ids)
199
+
200
+ def compute_loss(self, input_ids: torch.LongTensor, labels: torch.LongTensor) -> torch.Tensor:
201
+ outputs = self._model(input_ids, labels=labels)
202
+ lm_logits = outputs[1]
203
+
204
+ # Shift so that tokens < n predict n
205
+ shift_logits = lm_logits[..., :-1, :].contiguous()
206
+ shift_labels = labels[..., 1:].contiguous()
207
+ # Flatten the tokens
208
+ loss_fct = CrossEntropyLoss()
209
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
210
+ return loss
211
+
212
+ def compute_loss_self_debiasing(self, input_ids: torch.Tensor, trg_len: int, debiasing_prefixes: List[str], decay_constant: float = 50,
213
+ epsilon: float = 0.01, debug: bool = False) -> torch.Tensor:
214
+
215
+ self._model.init_logits_processor(num_debiasing_prefixes=len(debiasing_prefixes), decay_constant=decay_constant, epsilon=epsilon,
216
+ debug=debug, tokenizer=self._tokenizer)
217
+
218
+ input_prefixes = [''] + debiasing_prefixes
219
+ input_prefixes = self._tokenizer.batch_encode_plus(input_prefixes, padding=True, return_tensors='pt')
220
+ input_prefixes['attention_mask'] = torch.flip(input_prefixes['attention_mask'], dims=[1])
221
+
222
+ shifts = input_prefixes['attention_mask'].shape[-1] - input_prefixes['attention_mask'].sum(dim=-1)
223
+ for batch_idx in range(input_prefixes['input_ids'].shape[0]):
224
+ input_prefixes['input_ids'][batch_idx] = input_prefixes['input_ids'][batch_idx].roll(shifts[batch_idx].item())
225
+
226
+ input_prefixes = {k: v.to(self._device) for k, v in input_prefixes.items()}
227
+
228
+ input_ids_repeated = input_ids.repeat(len(debiasing_prefixes) + 1, 1)
229
+ attention_mask = torch.ones_like(input_ids_repeated)
230
+
231
+ attention_mask = torch.cat([input_prefixes['attention_mask'], attention_mask], dim=-1)
232
+ input_ids_repeated = torch.cat([input_prefixes['input_ids'], input_ids_repeated], dim=-1)
233
+
234
+ target_ids = input_ids_repeated.clone()
235
+ trg_len += shifts[0]
236
+ target_ids[:, :-trg_len] = -100
237
+
238
+ position_ids = attention_mask.long().cumsum(-1) - 1
239
+ position_ids.masked_fill_(attention_mask == 0, 1)
240
+
241
+ outputs = self._model(input_ids=input_ids_repeated, attention_mask=attention_mask, position_ids=position_ids, labels=target_ids)
242
+ lm_logits = outputs[1]
243
+
244
+ for idx in range(lm_logits.shape[1]):
245
+ lm_logits[:, idx, :] = self._model.logits_processor(input_ids=None, scores=lm_logits[:, idx, :])
246
+
247
+ batch_size = lm_logits.shape[0] // (1 + len(debiasing_prefixes))
248
+ lm_logits = lm_logits[:batch_size, shifts[0]:, :]
249
+ target_ids = target_ids[:batch_size, shifts[0]:]
250
+
251
+ # Shift so that tokens < n predict n
252
+ shift_logits = lm_logits[..., :-1, :].contiguous()
253
+ shift_labels = target_ids[..., 1:].contiguous()
254
+ # Flatten the tokens
255
+ loss_fct = CrossEntropyLoss()
256
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
257
+ return loss