Fabrice-TIERCELIN committed on
Commit
50b940f
1 Parent(s): 1b04590

Rewrite old function from modeling_bloom.py

Browse files
llava/model/language_model/mpt/hf_prefixlm_converter.py CHANGED
@@ -1,415 +1,441 @@
1
- """Converts Huggingface Causal LM to Prefix LM.
2
-
3
- Conversion does lightweight surgery on a HuggingFace
4
- Causal LM to convert it to a Prefix LM.
5
-
6
- Prefix LMs accepts a `bidirectional_mask` input in `forward`
7
- and treat the input prompt as the prefix in `generate`.
8
- """
9
- import math
10
- import warnings
11
- from types import MethodType
12
- from typing import Any, Dict, List, Optional, Tuple, Union
13
- import torch
14
- from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss
15
- from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
16
- from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom
17
- from transformers.models.bloom.modeling_bloom import logging
18
- from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
19
- from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
20
- from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
21
- from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
22
- from transformers.models.opt.modeling_opt import OPTForCausalLM
23
- from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
24
- from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
25
- logger = logging.get_logger(__name__)
26
- _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
27
- CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
28
-
29
- def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
30
- """Converts a GPT-style Causal LM to a Prefix LM.
31
-
32
- Supported HuggingFace model classes:
33
- - `GPT2LMHeadModel`
34
- - `GPTNeoForCausalLM`
35
- - `GPTNeoXForCausalLM`
36
- - `GPTJForCausalLM`
37
-
38
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
39
- """
40
- if hasattr(model, '_prefix_lm_converted'):
41
- return model
42
- assert isinstance(model, _SUPPORTED_GPT_MODELS)
43
- assert model.config.add_cross_attention == False, 'Only supports GPT-style decoder-only models'
44
-
45
- def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
46
- """Helper that gets a list of the model's attention modules.
47
-
48
- Each module has a `bias` buffer used for causal masking. The Prefix LM
49
- conversion adds logic to dynamically manipulate these biases to support
50
- Prefix LM attention masking.
51
- """
52
- attn_modules = []
53
- if isinstance(model, GPTNeoXForCausalLM):
54
- blocks = model.gpt_neox.layers
55
- else:
56
- blocks = model.transformer.h
57
- for block in blocks:
58
- if isinstance(model, GPTNeoForCausalLM):
59
- if block.attn.attention_type != 'global':
60
- continue
61
- attn_module = block.attn.attention
62
- elif isinstance(model, GPTNeoXForCausalLM):
63
- attn_module = block.attention
64
- else:
65
- attn_module = block.attn
66
- attn_modules.append(attn_module)
67
- return attn_modules
68
- setattr(model, '_original_forward', getattr(model, 'forward'))
69
- setattr(model, '_original_generate', getattr(model, 'generate'))
70
-
71
- def forward(self: CAUSAL_GPT_TYPES, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]=None, attention_mask: Optional[torch.FloatTensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, token_type_ids: Optional[torch.LongTensor]=None, position_ids: Optional[torch.LongTensor]=None, head_mask: Optional[torch.FloatTensor]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
72
- """Wraps original forward to enable PrefixLM attention."""
73
-
74
- def call_og_forward():
75
- if isinstance(self, GPTNeoXForCausalLM):
76
- return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
77
- else:
78
- return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
79
- if bidirectional_mask is None:
80
- return call_og_forward()
81
- assert isinstance(bidirectional_mask, torch.Tensor)
82
- attn_modules = _get_attn_modules(model)
83
- (b, s) = bidirectional_mask.shape
84
- max_length = attn_modules[0].bias.shape[-1]
85
- if s > max_length:
86
- raise ValueError(f'bidirectional_mask sequence length (={s}) exceeds the ' + f'max length allowed by the model ({max_length}).')
87
- assert s <= max_length
88
- if s < max_length:
89
- pad = torch.zeros((int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device)
90
- bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
91
- bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
92
- for attn_module in attn_modules:
93
- attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
94
- output = call_og_forward()
95
- for attn_module in attn_modules:
96
- attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
97
- return output
98
-
99
- def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, Any]):
100
- """Wraps original generate to enable PrefixLM attention."""
101
- attn_modules = _get_attn_modules(model)
102
- for attn_module in attn_modules:
103
- attn_module.bias.data[:] = 1
104
- output = self._original_generate(*args, **kwargs)
105
- for attn_module in attn_modules:
106
- attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
107
- return output
108
- setattr(model, 'forward', MethodType(forward, model))
109
- setattr(model, 'generate', MethodType(generate, model))
110
- setattr(model, '_prefix_lm_converted', True)
111
- return model
112
-
113
- def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
114
- """Converts a BLOOM Causal LM to a Prefix LM.
115
-
116
- Supported HuggingFace model classes:
117
- - `BloomForCausalLM`
118
-
119
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
120
- """
121
- if hasattr(model, '_prefix_lm_converted'):
122
- return model
123
- assert isinstance(model, BloomForCausalLM)
124
- assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models'
125
-
126
- def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor:
127
- combined_attention_mask = None
128
- device = attention_mask.device
129
- (_, src_length) = input_shape
130
- if src_length > 1:
131
- combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length)
132
- if bidirectional_mask is not None:
133
- assert attention_mask.shape == bidirectional_mask.shape
134
- expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
135
- combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)
136
- expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
137
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
138
- return combined_attention_mask
139
-
140
- def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
141
- num_heads = self.config.n_head
142
- closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
143
- base = torch.tensor(2 ** (-2 ** (-(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
144
- powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
145
- slopes = torch.pow(base, powers)
146
- if closest_power_of_2 != num_heads:
147
- extra_base = torch.tensor(2 ** (-2 ** (-(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32)
148
- num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
149
- extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
150
- slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
151
- qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
152
- ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
153
- diffs = qa - ka + key_length - query_length
154
- diffs = -diffs.abs()
155
- alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length)
156
- alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)
157
- return alibi.to(dtype)
158
- KeyValueT = Tuple[torch.Tensor, torch.Tensor]
159
-
160
- def forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
161
- if deprecated_arguments.pop('position_ids', False) is not False:
162
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. ' + 'You can safely ignore passing `position_ids`.', FutureWarning)
163
- if len(deprecated_arguments) > 0:
164
- raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
165
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
166
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
167
- use_cache = use_cache if use_cache is not None else self.config.use_cache
168
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
169
- if input_ids is not None and inputs_embeds is not None:
170
- raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
171
- elif input_ids is not None:
172
- (batch_size, seq_length) = input_ids.shape
173
- elif inputs_embeds is not None:
174
- (batch_size, seq_length, _) = inputs_embeds.shape
175
- else:
176
- raise ValueError('You have to specify either input_ids or inputs_embeds')
177
- if past_key_values is None:
178
- past_key_values = tuple([None] * len(self.h))
179
- head_mask = self.get_head_mask(head_mask, self.config.n_layer)
180
- if inputs_embeds is None:
181
- inputs_embeds = self.word_embeddings(input_ids)
182
- hidden_states = self.word_embeddings_layernorm(inputs_embeds)
183
- presents = () if use_cache else None
184
- all_self_attentions = () if output_attentions else None
185
- all_hidden_states = () if output_hidden_states else None
186
- seq_length_with_past = seq_length
187
- past_key_values_length = 0
188
- if past_key_values[0] is not None:
189
- tmp = past_key_values[0][0]
190
- past_key_values_length = tmp.shape[2]
191
- seq_length_with_past = seq_length_with_past + past_key_values_length
192
- if attention_mask is None:
193
- attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
194
- else:
195
- attention_mask = attention_mask.to(hidden_states.device)
196
- alibi = self._build_alibi_tensor(batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device)
197
- causal_mask = self._prepare_attn_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length)
198
- for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)):
199
- if output_hidden_states:
200
- hst = (hidden_states,)
201
- all_hidden_states = all_hidden_states + hst
202
- if self.gradient_checkpointing and self.training:
203
- if use_cache:
204
- logger.warning('`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
205
- use_cache = False
206
-
207
- def create_custom_forward(module):
208
-
209
- def custom_forward(*inputs):
210
- return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
211
- return custom_forward
212
- outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i])
213
- else:
214
- outputs = block(hidden_states, layer_past=layer_past, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi)
215
- hidden_states = outputs[0]
216
- if use_cache is True:
217
- presents = presents + (outputs[1],)
218
- if output_attentions:
219
- oa = (outputs[2 if use_cache else 1],)
220
- all_self_attentions = all_self_attentions + oa
221
- hidden_states = self.ln_f(hidden_states)
222
- if output_hidden_states:
223
- hst = (hidden_states,)
224
- all_hidden_states = all_hidden_states + hst
225
- if not return_dict:
226
- return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None))
227
- return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions)
228
- setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer))
229
- setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer))
230
- setattr(model.transformer, 'forward', MethodType(forward, model.transformer))
231
- KeyValueT = Tuple[torch.Tensor, torch.Tensor]
232
-
233
- def forward(self: BloomForCausalLM, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.Tensor]=None, inputs_embeds: Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
234
- """Replacement forward method for BloomCausalLM."""
235
- if deprecated_arguments.pop('position_ids', False) is not False:
236
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.', FutureWarning)
237
- if len(deprecated_arguments) > 0:
238
- raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
239
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
240
- transformer_outputs = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, bidirectional_mask=bidirectional_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
241
- hidden_states = transformer_outputs[0]
242
- lm_logits = self.lm_head(hidden_states)
243
- loss = None
244
- if labels is not None:
245
- shift_logits = lm_logits[..., :-1, :].contiguous()
246
- shift_labels = labels[..., 1:].contiguous()
247
- (batch_size, seq_length, vocab_size) = shift_logits.shape
248
- loss_fct = CrossEntropyLoss()
249
- loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))
250
- if not return_dict:
251
- output = (lm_logits,) + transformer_outputs[1:]
252
- return (loss,) + output if loss is not None else output
253
- return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions)
254
-
255
- def prepare_inputs_for_generation(self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, **kwargs) -> dict:
256
- if past:
257
- input_ids = input_ids[:, -1].unsqueeze(-1)
258
- bidirectional_mask = None
259
- if past[0][0].shape[0] == input_ids.shape[0]:
260
- past = self._convert_to_bloom_cache(past)
261
- else:
262
- bidirectional_mask = torch.ones_like(input_ids)
263
- return {'input_ids': input_ids, 'past_key_values': past, 'use_cache': True, 'attention_mask': attention_mask, 'bidirectional_mask': bidirectional_mask}
264
- setattr(model, 'forward', MethodType(forward, model))
265
- setattr(model, 'prepare_inputs_for_generation', MethodType(prepare_inputs_for_generation, model))
266
- setattr(model, '_prefix_lm_converted', True)
267
- return model
268
-
269
- def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
270
- """Converts an OPT Causal LM to a Prefix LM.
271
-
272
- Supported HuggingFace model classes:
273
- - `OPTForCausalLM`
274
-
275
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
276
- """
277
- if hasattr(model, '_prefix_lm_converted'):
278
- return model
279
- assert isinstance(model, OPTForCausalLM)
280
- assert model.config.add_cross_attention == False, 'Only supports OPT decoder-only models'
281
- setattr(model, '_original_forward', getattr(model, 'forward'))
282
- setattr(model, '_original_generate', getattr(model, 'generate'))
283
- model.model.decoder.bidirectional_mask = None
284
-
285
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
286
- combined_attention_mask = None
287
- if input_shape[-1] > 1:
288
- if self.bidirectional_mask == 'g':
289
- (bsz, src_length) = input_shape
290
- combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
291
- else:
292
- combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device)
293
- if self.bidirectional_mask is not None:
294
- assert attention_mask.shape == self.bidirectional_mask.shape
295
- expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
296
- combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
297
- if attention_mask is not None:
298
- expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
299
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
300
- return combined_attention_mask
301
- setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder))
302
-
303
- def forward(self: OPTForCausalLM, input_ids: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.ByteTensor]=None, head_mask: Optional[torch.Tensor]=None, past_key_values: Optional[List[torch.FloatTensor]]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
304
-
305
- def call_og_forward():
306
- return self._original_forward(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
307
- if bidirectional_mask is None:
308
- return call_og_forward()
309
- self.model.decoder.bidirectional_mask = bidirectional_mask
310
- try:
311
- outputs = call_og_forward()
312
- except:
313
- self.model.decoder.bidirectional_mask = None
314
- raise
315
- self.model.decoder.bidirectional_mask = None
316
- return outputs
317
-
318
- def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
319
- """Wraps original generate to enable PrefixLM-style attention."""
320
- self.model.decoder.bidirectional_mask = 'g'
321
- try:
322
- output = self._original_generate(*args, **kwargs)
323
- except:
324
- self.model.decoder.bidirectional_mask = None
325
- raise
326
- self.model.decoder.bidirectional_mask = None
327
- return output
328
- setattr(model, 'forward', MethodType(forward, model))
329
- setattr(model, 'generate', MethodType(generate, model))
330
- setattr(model, '_prefix_lm_converted', True)
331
- return model
332
- _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
333
- CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM]
334
-
335
- def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
336
- """Converts a HuggingFace Causal LM to a Prefix LM.
337
-
338
- Supported HuggingFace model classes:
339
- - `GPT2LMHeadModel`
340
- - `GPTNeoForCausalLM`
341
- - `GPTNeoXForCausalLM`
342
- - `GPTJForCausalLM`
343
- - `BloomForCausalLM`
344
- - `OPTForCausalLM`
345
-
346
- Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
347
- `generate` method and/or select underlying methods depending on the model class.
348
-
349
- These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask".
350
-
351
- Notes on training:
352
- To actually train the converted model as a Prefix LM, training batches will need to indicate
353
- the prefix/target structure by including `bidirectional_mask` as part of the batch inputs.
354
-
355
- **This is not a standard input and requires custom layers either within or after your dataloader.**
356
-
357
- In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels`
358
- such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`.
359
- That is, the prefix portion of the sequence should not generate any loss. Loss should only be
360
- generated by the target portion of the sequence.
361
-
362
- Notes on `GPTNeoForCausalLM`:
363
- To simplify the implementation, "global" and "local" attention layers are handled differently.
364
- For "global" layers, we handle conversion as described above. For "local" layers, which use a
365
- causal attention mask within a restricted local window, we do not alter the masking.
366
-
367
- Notes on `forward` method conversion:
368
- After conversion, the `forward` method will handle a new input, `bidirectional_mask`,
369
- which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions
370
- belonging to the prefix (prefix tokens can attend to one another bidirectionally), and
371
- 0 indicates token positions belonging to the target.
372
-
373
- The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing
374
- causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset
375
- the causal masks before returning the result.
376
-
377
- Notes on `generate` method conversion:
378
- After conversion, the `generate` method will have the same signature but will internally
379
- convert all causal masks to be purely bidirectional, call the original `generate` method, and
380
- (where appropriate) reset the causal masks before returning the result.
381
-
382
- This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token
383
- "prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates
384
- each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one
385
- another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and
386
- previously-generated tokens (also as expected in a Prefix LM).
387
-
388
- To preserve the API, the original methods are renamed to `_original_forward` and
389
- `_original_generate`, and replaced with new `forward` and `generate` methods that wrap
390
- them, respectively. Although implementation details vary by model class.
391
- """
392
- if isinstance(model, _SUPPORTED_GPT_MODELS):
393
- return _convert_gpt_causal_lm_to_prefix_lm(model)
394
- elif isinstance(model, BloomForCausalLM):
395
- return _convert_bloom_causal_lm_to_prefix_lm(model)
396
- elif isinstance(model, OPTForCausalLM):
397
- return _convert_opt_causal_lm_to_prefix_lm(model)
398
- else:
399
- raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\n{_SUPPORTED_HF_MODELS}')
400
-
401
- def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
402
- """Attempts to add bidirectional_mask to batch if missing.
403
-
404
- Raises:
405
- KeyError if bidirectional_mask is missing and can't be inferred
406
- """
407
- if 'bidirectional_mask' not in batch:
408
- if batch.get('mode', None) == 'icl_task':
409
- batch['bidirectional_mask'] = batch['attention_mask'].clone()
410
- for (i, continuation_indices) in enumerate(batch['continuation_indices']):
411
- batch['bidirectional_mask'][i, continuation_indices] = 0
412
- elif 'labels' in batch and 'attention_mask' in batch:
413
- batch['bidirectional_mask'] = torch.logical_and(torch.eq(batch['attention_mask'], 1), torch.eq(batch['labels'], -100)).type_as(batch['attention_mask'])
414
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  raise KeyError('No bidirectional_mask in batch and not sure how to construct one.')
 
1
+ """Converts Huggingface Causal LM to Prefix LM.
2
+
3
+ Conversion does lightweight surgery on a HuggingFace
4
+ Causal LM to convert it to a Prefix LM.
5
+
6
+ Prefix LMs accept a `bidirectional_mask` input in `forward`
7
+ and treat the input prompt as the prefix in `generate`.
8
+ """
9
+ import math
10
+ import warnings
11
+ from types import MethodType
12
+ from typing import Any, Dict, List, Optional, Tuple, Union
13
+ import torch
14
+ from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss
15
+ from transformers.models.bloom.modeling_bloom import logging
16
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
17
+ from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
18
+ from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
19
+ from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
20
+ from transformers.models.opt.modeling_opt import OPTForCausalLM
21
+ from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
22
+ from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
23
+ logger = logging.get_logger(__name__)
24
+ _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
25
+ CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
26
+
27
+ def _make_causal_mask_bloom(
28
+ input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
29
+ ) -> torch.BoolTensor:
30
+ """
31
+ Make causal mask used for self-attention.
32
+ """
33
+ batch_size, target_length = input_ids_shape
34
+ mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
35
+ # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
36
+ seq_ids = torch.arange(target_length, device=device)
37
+ mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
38
+
39
+ if past_key_values_length > 0:
40
+ mask[:, :past_key_values_length] = False
41
+
42
+ expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
43
+ return expanded_mask
44
+
45
+ def _expand_mask_bloom(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
46
+ """
47
+ Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
48
+ """
49
+ batch_size, src_length = mask.shape
50
+ tgt_length = tgt_length if tgt_length is not None else src_length
51
+
52
+ expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
53
+ return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
54
+
55
+ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
56
+ """Converts a GPT-style Causal LM to a Prefix LM.
57
+
58
+ Supported HuggingFace model classes:
59
+ - `GPT2LMHeadModel`
60
+ - `GPTNeoForCausalLM`
61
+ - `GPTNeoXForCausalLM`
62
+ - `GPTJForCausalLM`
63
+
64
+ See `convert_hf_causal_lm_to_prefix_lm` for more details.
65
+ """
66
+ if hasattr(model, '_prefix_lm_converted'):
67
+ return model
68
+ assert isinstance(model, _SUPPORTED_GPT_MODELS)
69
+ assert model.config.add_cross_attention == False, 'Only supports GPT-style decoder-only models'
70
+
71
+ def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
72
+ """Helper that gets a list of the model's attention modules.
73
+
74
+ Each module has a `bias` buffer used for causal masking. The Prefix LM
75
+ conversion adds logic to dynamically manipulate these biases to support
76
+ Prefix LM attention masking.
77
+ """
78
+ attn_modules = []
79
+ if isinstance(model, GPTNeoXForCausalLM):
80
+ blocks = model.gpt_neox.layers
81
+ else:
82
+ blocks = model.transformer.h
83
+ for block in blocks:
84
+ if isinstance(model, GPTNeoForCausalLM):
85
+ if block.attn.attention_type != 'global':
86
+ continue
87
+ attn_module = block.attn.attention
88
+ elif isinstance(model, GPTNeoXForCausalLM):
89
+ attn_module = block.attention
90
+ else:
91
+ attn_module = block.attn
92
+ attn_modules.append(attn_module)
93
+ return attn_modules
94
+ setattr(model, '_original_forward', getattr(model, 'forward'))
95
+ setattr(model, '_original_generate', getattr(model, 'generate'))
96
+
97
def forward(
    self: CAUSAL_GPT_TYPES,
    input_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    bidirectional_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    """Wraps original forward to enable PrefixLM attention.

    This is a closure defined inside the GPT converter: it relies on the
    enclosing `model` variable and on the `_get_attn_modules` helper.

    When `bidirectional_mask` is None the call is forwarded unchanged to
    `self._original_forward`. Otherwise the mask is right-padded with zeros
    to the width of the attention `bias` buffers, OR-ed into every buffer,
    the original forward is run, and the buffers are reset to their
    lower-triangular form.

    Args:
        bidirectional_mask: 2-D tensor unpacked as (batch, seq_len); non-zero
            entries are OR-ed into the `bias` buffers.
        (all other arguments are passed through to the original forward;
        `token_type_ids` and `position_ids` are not forwarded for
        `GPTNeoXForCausalLM`.)

    Returns:
        Whatever `self._original_forward` returns.

    Raises:
        ValueError: if the mask's sequence length exceeds the last dimension
            of the `bias` buffer.
    """

    def call_og_forward():
        # Arguments accepted by every supported model class.
        shared_kwargs = dict(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if isinstance(self, GPTNeoXForCausalLM):
            # The GPT-NeoX call omits token_type_ids / position_ids.
            return self._original_forward(**shared_kwargs)
        return self._original_forward(
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            **shared_kwargs,
        )

    if bidirectional_mask is None:
        # Nothing to patch: behave exactly like the unconverted model.
        return call_og_forward()
    assert isinstance(bidirectional_mask, torch.Tensor)

    attn_modules = _get_attn_modules(model)
    (b, s) = bidirectional_mask.shape
    # The first module's buffer is used as the reference width.
    max_length = attn_modules[0].bias.shape[-1]
    if s > max_length:
        raise ValueError(f'bidirectional_mask sequence length (={s}) exceeds the ' + f'max length allowed by the model ({max_length}).')
    if s < max_length:
        # Right-pad with zeros so the mask broadcasts against `bias`.
        pad = torch.zeros(
            (int(b), int(max_length - s)),
            dtype=bidirectional_mask.dtype,
            device=bidirectional_mask.device,
        )
        bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
    bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)

    try:
        for attn_module in attn_modules:
            attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
        return call_og_forward()
    finally:
        # Always restore the buffers, even if the wrapped forward (or the
        # patching loop itself) raised; otherwise later calls would silently
        # run with a non-causal mask. Taking tril of the first slice recovers
        # the (1, 1, L, L) lower-triangular layout.
        for attn_module in attn_modules:
            attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
124
+
125
def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, Any]):
    """Wraps original generate to enable PrefixLM attention.

    This is a closure defined inside the GPT converter: it relies on the
    enclosing `model` variable and on the `_get_attn_modules` helper.

    Every attention `bias` buffer is filled with ones for the duration of
    the call to `self._original_generate`, then reset to its
    lower-triangular form.

    Returns:
        Whatever `self._original_generate` returns.
    """
    attn_modules = _get_attn_modules(model)
    try:
        for attn_module in attn_modules:
            # In-place fill: the buffer keeps its shape, dtype and device.
            attn_module.bias.data[:] = 1
        return self._original_generate(*args, **kwargs)
    finally:
        # Restore the buffers even when generation fails, so the model is not
        # left with an all-ones mask for subsequent calls.
        for attn_module in attn_modules:
            attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
134
+ setattr(model, 'forward', MethodType(forward, model))
135
+ setattr(model, 'generate', MethodType(generate, model))
136
+ setattr(model, '_prefix_lm_converted', True)
137
+ return model
138
+
139
def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
    """Converts a BLOOM Causal LM to a Prefix LM.

    Supported HuggingFace model classes:
        - `BloomForCausalLM`

    The conversion rebinds three methods on `model.transformer`
    (`_prepare_attn_mask`, `_build_alibi_tensor`, `forward`) and two on
    `model` itself (`forward`, `prepare_inputs_for_generation`), then tags
    the model with `_prefix_lm_converted` so a second call is a no-op.

    See `convert_hf_causal_lm_to_prefix_lm` for more details.
    """
    if hasattr(model, '_prefix_lm_converted'):
        # Already converted: nothing to do.
        return model
    assert isinstance(model, BloomForCausalLM)
    assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models'

    # Per-layer cache entry type shared by both replacement forwards.
    KeyValueT = Tuple[torch.Tensor, torch.Tensor]

    def _prepare_attn_mask(
        self: BloomModel,
        attention_mask: torch.Tensor,
        bidirectional_mask: Optional[torch.Tensor],
        input_shape: Tuple[int, int],
        past_key_values_length: int,
    ) -> torch.BoolTensor:
        """Combine the causal, bidirectional (prefix) and attention masks."""
        mask_device = attention_mask.device
        src_length = input_shape[1]
        merged_mask = None
        if src_length > 1:
            merged_mask = _make_causal_mask_bloom(
                input_shape,
                device=mask_device,
                past_key_values_length=past_key_values_length,
            )
            if bidirectional_mask is not None:
                # Both masks must describe the same positions.
                assert attention_mask.shape == bidirectional_mask.shape
                prefix_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
                merged_mask = torch.logical_and(merged_mask, prefix_mask)
        expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
        if merged_mask is None:
            return expanded_attn_mask
        return expanded_attn_mask | merged_mask

    def _build_alibi_tensor(
        self: BloomModel,
        batch_size: int,
        query_length: int,
        key_length: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        """Build ALiBi biases of shape (batch_size * n_head, query_length, key_length)."""
        num_heads = self.config.n_head
        pow2_heads = 2 ** math.floor(math.log2(num_heads))

        def _slope_base(head_count) -> torch.Tensor:
            # Geometric base of the per-head slopes for `head_count` heads.
            return torch.tensor(
                2 ** (-2 ** (-(math.log2(head_count) - 3))),
                device=device,
                dtype=torch.float32,
            )

        exponents = torch.arange(1, 1 + pow2_heads, device=device, dtype=torch.int32)
        slopes = torch.pow(_slope_base(pow2_heads), exponents)
        if pow2_heads != num_heads:
            # Head count is not a power of two: add slopes for the extra heads
            # using odd exponents of the base for twice as many heads.
            extra_base = _slope_base(2 * pow2_heads)
            num_remaining_heads = min(pow2_heads, num_heads - pow2_heads)
            extra_exponents = torch.arange(
                1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
            slopes = torch.cat([slopes, torch.pow(extra_base, extra_exponents)], dim=0)

        query_pos = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
        key_pos = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
        # Negative absolute offset between each query and key position.
        distance = query_pos - key_pos + key_length - query_length
        distance = -distance.abs()
        alibi = slopes.view(1, num_heads, 1, 1) * distance.view(1, 1, query_length, key_length)
        alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)
        return alibi.to(dtype)

    def forward(
        self: BloomModel,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        bidirectional_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        """Replacement `BloomModel.forward` that accepts `bidirectional_mask`."""
        if deprecated_arguments.pop('position_ids', False) is not False:
            warnings.warn('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. ' + 'You can safely ignore passing `position_ids`.', FutureWarning)
        if len(deprecated_arguments) > 0:
            raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')

        # Fall back to the config for every flag the caller left unset.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
        elif input_ids is not None:
            (batch_size, seq_length) = input_ids.shape
        elif inputs_embeds is not None:
            (batch_size, seq_length, _) = inputs_embeds.shape
        else:
            raise ValueError('You have to specify either input_ids or inputs_embeds')

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        hidden_states = self.word_embeddings_layernorm(inputs_embeds)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # Account for any cached positions preceding the current tokens.
        past_key_values_length = 0
        seq_length_with_past = seq_length
        if past_key_values[0] is not None:
            first_cached_key = past_key_values[0][0]
            past_key_values_length = first_cached_key.shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        alibi = self._build_alibi_tensor(
            batch_size=batch_size,
            query_length=seq_length,
            key_length=seq_length_with_past,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        causal_mask = self._prepare_attn_mask(
            attention_mask,
            bidirectional_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
        )

        for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warning('`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
                    use_cache = False

                def create_custom_forward(module):

                    def custom_forward(*inputs):
                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    alibi,
                    causal_mask,
                    head_mask[i],
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=causal_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    alibi=alibi,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)
            if output_attentions:
                # Attention weights follow the cache entry when one is returned.
                attn_index = 2 if use_cache else 1
                all_self_attentions = all_self_attentions + (outputs[attn_index],)

        hidden_states = self.ln_f(hidden_states)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            candidates = [hidden_states, presents, all_hidden_states, all_self_attentions]
            return tuple((item for item in candidates if item is not None))
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    # Rebind the three replacements onto the inner BloomModel.
    setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer))
    setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer))
    setattr(model.transformer, 'forward', MethodType(forward, model.transformer))

    def forward(
        self: BloomForCausalLM,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        bidirectional_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        """Replacement forward method for BloomCausalLM."""
        if deprecated_arguments.pop('position_ids', False) is not False:
            warnings.warn('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.', FutureWarning)
        if len(deprecated_arguments) > 0:
            raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
        if return_dict is None:
            return_dict = self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            bidirectional_mask=bidirectional_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1: drop the last logit and the
            # first label before flattening for the loss.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            (batch_size, seq_length, vocab_size) = shift_logits.shape
            flat_logits = shift_logits.view(batch_size * seq_length, vocab_size)
            flat_labels = shift_labels.view(batch_size * seq_length)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(flat_logits, flat_labels)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self: BloomForCausalLM,
        input_ids: torch.LongTensor,
        past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        """Build the per-step inputs used during generation."""
        if past:
            # A cache exists: feed only the newest token and use no
            # bidirectional mask.
            input_ids = input_ids[:, -1].unsqueeze(-1)
            bidirectional_mask = None
            # NOTE(review): presumably detects a cache in the standard layout
            # (leading dim == batch size) that needs converting - confirm.
            if past[0][0].shape[0] == input_ids.shape[0]:
                past = self._convert_to_bloom_cache(past)
        else:
            # No cache yet: every token of `input_ids` is marked bidirectional.
            bidirectional_mask = torch.ones_like(input_ids)
        return {
            'input_ids': input_ids,
            'past_key_values': past,
            'use_cache': True,
            'attention_mask': attention_mask,
            'bidirectional_mask': bidirectional_mask,
        }

    setattr(model, 'forward', MethodType(forward, model))
    setattr(model, 'prepare_inputs_for_generation', MethodType(prepare_inputs_for_generation, model))
    setattr(model, '_prefix_lm_converted', True)
    return model
294
+
295
def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
    """Converts an OPT Causal LM to a Prefix LM.

    Supported HuggingFace model classes:
        - `OPTForCausalLM`

    The original `forward` / `generate` are kept as `_original_forward` /
    `_original_generate`. The decoder gets a `bidirectional_mask` attribute
    (None, a tensor, or the sentinel string 'g' during generation) that the
    replacement `_prepare_decoder_attention_mask` reads.

    See `convert_hf_causal_lm_to_prefix_lm` for more details.
    """
    if hasattr(model, '_prefix_lm_converted'):
        # Already converted: nothing to do.
        return model
    assert isinstance(model, OPTForCausalLM)
    assert model.config.add_cross_attention == False, 'Only supports OPT decoder-only models'
    setattr(model, '_original_forward', getattr(model, 'forward'))
    setattr(model, '_original_generate', getattr(model, 'generate'))
    model.model.decoder.bidirectional_mask = None

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        """Build the decoder attention mask, honouring `self.bidirectional_mask`."""
        combined_attention_mask = None
        if input_shape[-1] > 1:
            # Check the type first so a real mask tensor is never compared
            # element-wise against the 'g' sentinel string.
            in_generation_mode = isinstance(self.bidirectional_mask, str) and self.bidirectional_mask == 'g'
            if in_generation_mode:
                # Generation: all-zero additive mask instead of the causal one.
                (bsz, src_length) = input_shape
                combined_attention_mask = torch.zeros(
                    (bsz, 1, src_length, src_length + past_key_values_length),
                    dtype=inputs_embeds.dtype,
                    device=inputs_embeds.device,
                )
            else:
                combined_attention_mask = _make_causal_mask_opt(
                    input_shape,
                    inputs_embeds.dtype,
                    past_key_values_length=past_key_values_length,
                ).to(inputs_embeds.device)
                if self.bidirectional_mask is not None:
                    # Both masks must describe the same positions.
                    assert attention_mask.shape == self.bidirectional_mask.shape
                    expanded_bidirectional_mask = _expand_mask_opt(
                        self.bidirectional_mask,
                        inputs_embeds.dtype,
                        tgt_len=input_shape[-1],
                    ).to(inputs_embeds.device)
                    # Element-wise max of the expanded prefix mask and the
                    # causal mask.
                    combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
        if attention_mask is not None:
            expanded_attn_mask = _expand_mask_opt(
                attention_mask,
                inputs_embeds.dtype,
                tgt_len=input_shape[-1],
            ).to(inputs_embeds.device)
            if combined_attention_mask is None:
                combined_attention_mask = expanded_attn_mask
            else:
                combined_attention_mask = expanded_attn_mask + combined_attention_mask
        return combined_attention_mask

    setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder))

    def forward(
        self: OPTForCausalLM,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        bidirectional_mask: Optional[torch.ByteTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Wraps original forward; publishes `bidirectional_mask` to the decoder."""

        def call_og_forward():
            return self._original_forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        if bidirectional_mask is None:
            # No prefix information: behave like the unconverted model.
            return call_og_forward()
        self.model.decoder.bidirectional_mask = bidirectional_mask
        try:
            return call_og_forward()
        finally:
            # Clear the per-call state on success and on failure alike.
            self.model.decoder.bidirectional_mask = None

    def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
        """Wraps original generate to enable PrefixLM-style attention."""
        # 'g' is the sentinel the decoder mask builder reads as generation mode.
        self.model.decoder.bidirectional_mask = 'g'
        try:
            return self._original_generate(*args, **kwargs)
        finally:
            # Clear the sentinel on success and on failure alike.
            self.model.decoder.bidirectional_mask = None

    setattr(model, 'forward', MethodType(forward, model))
    setattr(model, 'generate', MethodType(generate, model))
    setattr(model, '_prefix_lm_converted', True)
    return model
358
# Every HuggingFace model class the converter accepts: the GPT-style family
# plus BLOOM and OPT. Used in the TypeError message for unsupported models.
_SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (
    BloomForCausalLM,
    OPTForCausalLM,
)

# Static-typing counterpart of the supported model classes.
CAUSAL_LM_TYPES = Union[
    GPT2LMHeadModel,
    GPTJForCausalLM,
    GPTNeoForCausalLM,
    GPTNeoXForCausalLM,
    BloomForCausalLM,
    OPTForCausalLM,
]
360
+
361
def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
    """Converts a HuggingFace Causal LM to a Prefix LM.

    Supported HuggingFace model classes:
        - `GPT2LMHeadModel`
        - `GPTNeoForCausalLM`
        - `GPTNeoXForCausalLM`
        - `GPTJForCausalLM`
        - `BloomForCausalLM`
        - `OPTForCausalLM`

    Conversion to a Prefix LM is done by modifying the `forward` method, and
    possibly also the `generate` method and/or select underlying methods,
    depending on the model class.

    These changes preserve the model API, but add a new input to `forward`:
    "bidirectional_mask".

    Notes on training:
        To actually train the converted model as a Prefix LM, training batches
        need to indicate the prefix/target structure by including
        `bidirectional_mask` as part of the batch inputs.

        **This is not a standard input and requires custom layers either
        within or after your dataloader.**

        In addition to adding `bidirectional_mask` to the batch, this custom
        code should modify `labels` such that
        `batch['labels'][batch['bidirectional_mask'] == 1] == -100`.
        That is, the prefix portion of the sequence should not generate any
        loss. Loss should only be generated by the target portion.

    Notes on `GPTNeoForCausalLM`:
        To simplify the implementation, "global" and "local" attention layers
        are handled differently. "Global" layers are converted as described
        above. "Local" layers, which use a causal attention mask within a
        restricted local window, keep their masking unchanged.

    Notes on `forward` method conversion:
        After conversion, `forward` handles a new input, `bidirectional_mask`,
        which should be a [batch_size, seq_length] byte tensor where 1 marks
        token positions belonging to the prefix (prefix tokens can attend to
        one another bidirectionally) and 0 marks positions belonging to the
        target.

        The new `forward` incorporates `bidirectional_mask` (if supplied) into
        the existing causal mask, calls the original `forward`, and (if the
        causal mask is a buffer) resets the causal masks before returning.

    Notes on `generate` method conversion:
        After conversion, `generate` has the same signature but internally
        makes all causal masks purely bidirectional, calls the original
        `generate`, and (where appropriate) resets the causal masks before
        returning the result.

        This works thanks to the logic of the HuggingFace `generate` API,
        which first encodes the token "prompt" passed to `generate` (treated
        as the prefix) and then sequentially generates each new token.
        Encodings are cached as generation happens, so all prefix tokens can
        attend to one another (as expected in a Prefix LM) and generated
        tokens can only attend to prefix tokens and previously-generated
        tokens (also as expected in a Prefix LM).

    To preserve the API, the original methods are renamed to
    `_original_forward` and `_original_generate` and replaced with new
    `forward` and `generate` methods that wrap them, although implementation
    details vary by model class.

    Args:
        model: an instance of one of the supported model classes.

    Returns:
        The value returned by the family-specific converter for `model`.

    Raises:
        TypeError: if `model` is not an instance of a supported class.
    """
    # Ordered (accepted classes, converter) pairs; the first match wins.
    converters = (
        (_SUPPORTED_GPT_MODELS, _convert_gpt_causal_lm_to_prefix_lm),
        (BloomForCausalLM, _convert_bloom_causal_lm_to_prefix_lm),
        (OPTForCausalLM, _convert_opt_causal_lm_to_prefix_lm),
    )
    for accepted_types, convert in converters:
        if isinstance(model, accepted_types):
            return convert(model)
    raise TypeError(
        'Cannot convert model to Prefix LM. '
        'Model does not belong to set of supported HF models:'
        f'\n{_SUPPORTED_HF_MODELS}'
    )
426
+
427
def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
    """Attempts to add bidirectional_mask to batch if missing.

    The batch is modified in place; nothing is returned.

    * If `batch['mode'] == 'icl_task'`, the mask is a copy of
      `attention_mask` with each row's `continuation_indices` set to 0.
    * Otherwise, if both `labels` and `attention_mask` are present, the mask
      is 1 where `attention_mask == 1` and `labels == -100`, cast to the type
      of `attention_mask`.

    Raises:
        KeyError if bidirectional_mask is missing and can't be inferred
    """
    if 'bidirectional_mask' in batch:
        # Caller already supplied a mask; leave the batch untouched.
        return

    if batch.get('mode', None) == 'icl_task':
        prefix_mask = batch['attention_mask'].clone()
        # Store first, then zero rows in place via the same tensor object.
        batch['bidirectional_mask'] = prefix_mask
        for (row, continuation_indices) in enumerate(batch['continuation_indices']):
            prefix_mask[row, continuation_indices] = 0
        return

    if 'labels' in batch and 'attention_mask' in batch:
        attention_mask = batch['attention_mask']
        is_attended = torch.eq(attention_mask, 1)
        is_unlabeled = torch.eq(batch['labels'], -100)
        batch['bidirectional_mask'] = torch.logical_and(is_attended, is_unlabeled).type_as(attention_mask)
        return

    raise KeyError('No bidirectional_mask in batch and not sure how to construct one.')