xianbin commited on
Commit
c7bcf32
1 Parent(s): b4203f4

Update codes to be in line with LLM-foundry update on October 30, 2023

Browse files

Should fix issues with regards to the refactoring of the attention mask codes on the transformers library

Files changed (1) hide show
  1. hf_prefixlm_converter.py +99 -263
hf_prefixlm_converter.py CHANGED
@@ -6,25 +6,24 @@ Causal LM to convert it to a Prefix LM.
6
  Prefix LMs accepts a `bidirectional_mask` input in `forward`
7
  and treat the input prompt as the prefix in `generate`.
8
  """
9
- import math
10
- import warnings
11
  from types import MethodType
12
  from typing import Any, List, MutableMapping, Optional, Tuple, Union
13
  import torch
14
- from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss
15
- from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
16
- from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom
17
- from transformers.models.bloom.modeling_bloom import logging
18
  from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
19
  from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
20
  from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
21
  from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
22
- from transformers.models.opt.modeling_opt import OPTForCausalLM
23
- from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
24
- from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
25
- logger = logging.get_logger(__name__)
26
- _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
27
- CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
 
 
 
 
 
28
 
29
  def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
30
  """Converts a GPT-style Causal LM to a Prefix LM.
@@ -37,10 +36,12 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
37
 
38
  See `convert_hf_causal_lm_to_prefix_lm` for more details.
39
  """
40
- if hasattr(model, '_prefix_lm_converted'):
41
  return model
42
  assert isinstance(model, _SUPPORTED_GPT_MODELS)
43
- assert model.config.add_cross_attention == False, 'Only supports GPT-style decoder-only models'
 
 
44
 
45
  def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
46
  """Helper that gets a list of the model's attention modules.
@@ -56,7 +57,7 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
56
  blocks = model.transformer.h
57
  for block in blocks:
58
  if isinstance(model, GPTNeoForCausalLM):
59
- if block.attn.attention_type != 'global':
60
  continue
61
  attn_module = block.attn.attention
62
  elif isinstance(model, GPTNeoXForCausalLM):
@@ -65,17 +66,58 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
65
  attn_module = block.attn
66
  attn_modules.append(attn_module)
67
  return attn_modules
68
- setattr(model, '_original_forward', getattr(model, 'forward'))
69
- setattr(model, '_original_generate', getattr(model, 'generate'))
70
 
71
- def forward(self: CAUSAL_GPT_TYPES, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]=None, attention_mask: Optional[torch.FloatTensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, token_type_ids: Optional[torch.LongTensor]=None, position_ids: Optional[torch.LongTensor]=None, head_mask: Optional[torch.FloatTensor]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  """Wraps original forward to enable PrefixLM attention."""
73
 
74
  def call_og_forward():
75
  if isinstance(self, GPTNeoXForCausalLM):
76
- return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
 
 
 
 
 
 
 
 
 
 
 
77
  else:
78
- return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  if bidirectional_mask is None:
80
  return call_og_forward()
81
  assert isinstance(bidirectional_mask, torch.Tensor)
@@ -83,15 +125,24 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
83
  (b, s) = bidirectional_mask.shape
84
  max_length = attn_modules[0].bias.shape[-1]
85
  if s > max_length:
86
- raise ValueError(f'bidirectional_mask sequence length (={s}) exceeds the ' + f'max length allowed by the model ({max_length}).')
 
 
 
87
  assert s <= max_length
88
  if s < max_length:
89
- pad = torch.zeros((int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device)
 
 
 
 
90
  bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
91
  bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
92
  for attn_module in attn_modules:
93
  assert isinstance(attn_module.bias, torch.Tensor)
94
- attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
 
 
95
  output = call_og_forward()
96
  for attn_module in attn_modules:
97
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
@@ -106,236 +157,18 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
106
  for attn_module in attn_modules:
107
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
108
  return output
109
- setattr(model, 'forward', MethodType(forward, model))
110
- setattr(model, 'generate', MethodType(generate, model))
111
- setattr(model, '_prefix_lm_converted', True)
112
- return model
113
-
114
- def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
115
- """Converts a BLOOM Causal LM to a Prefix LM.
116
-
117
- Supported HuggingFace model classes:
118
- - `BloomForCausalLM`
119
 
120
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
121
- """
122
- if hasattr(model, '_prefix_lm_converted'):
123
- return model
124
- assert isinstance(model, BloomForCausalLM)
125
- assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models'
126
-
127
- def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor:
128
- combined_attention_mask = None
129
- device = attention_mask.device
130
- (_, src_length) = input_shape
131
- if src_length > 1:
132
- combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length)
133
- if bidirectional_mask is not None:
134
- assert attention_mask.shape == bidirectional_mask.shape
135
- expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
136
- combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)
137
- expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
138
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
139
- return combined_attention_mask
140
-
141
- def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
142
- num_heads = self.config.n_head
143
- closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
144
- base = torch.tensor(2 ** (-2 ** (-(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
145
- powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
146
- slopes = torch.pow(base, powers)
147
- if closest_power_of_2 != num_heads:
148
- extra_base = torch.tensor(2 ** (-2 ** (-(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32)
149
- num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
150
- extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
151
- slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
152
- qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
153
- ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
154
- diffs = qa - ka + key_length - query_length
155
- diffs = -diffs.abs()
156
- alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length)
157
- alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)
158
- return alibi.to(dtype)
159
- KeyValueT = Tuple[torch.Tensor, torch.Tensor]
160
-
161
- def transformer_forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments: Any) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
162
- if deprecated_arguments.pop('position_ids', False) is not False:
163
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. ' + 'You can safely ignore passing `position_ids`.', FutureWarning)
164
- if len(deprecated_arguments) > 0:
165
- raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
166
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
167
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
168
- use_cache = use_cache if use_cache is not None else self.config.use_cache
169
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
170
- if input_ids is not None and inputs_embeds is not None:
171
- raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
172
- elif input_ids is not None:
173
- (batch_size, seq_length) = input_ids.shape
174
- elif inputs_embeds is not None:
175
- (batch_size, seq_length, _) = inputs_embeds.shape
176
- else:
177
- raise ValueError('You have to specify either input_ids or inputs_embeds')
178
- if past_key_values is None:
179
- past_key_values = tuple([None] * len(self.h))
180
- head_mask = self.get_head_mask(head_mask, self.config.n_layer)
181
- if inputs_embeds is None:
182
- inputs_embeds = self.word_embeddings(input_ids)
183
- hidden_states = self.word_embeddings_layernorm(inputs_embeds)
184
- presents = () if use_cache else None
185
- all_self_attentions = () if output_attentions else None
186
- all_hidden_states = () if output_hidden_states else None
187
- seq_length_with_past = seq_length
188
- past_key_values_length = 0
189
- if past_key_values[0] is not None:
190
- tmp = past_key_values[0][0]
191
- past_key_values_length = tmp.shape[2]
192
- seq_length_with_past = seq_length_with_past + past_key_values_length
193
- if attention_mask is None:
194
- attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
195
- else:
196
- attention_mask = attention_mask.to(hidden_states.device)
197
- alibi = self._build_alibi_tensor(batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device)
198
- causal_mask = self._prepare_attn_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length)
199
- for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)):
200
- if output_hidden_states:
201
- hst = (hidden_states,)
202
- all_hidden_states = all_hidden_states + hst
203
- if self.gradient_checkpointing and self.training:
204
- if use_cache:
205
- logger.warning('`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
206
- use_cache = False
207
-
208
- def create_custom_forward(module: torch.nn.Module):
209
-
210
- def custom_forward(*inputs: Any):
211
- return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
212
- return custom_forward
213
- outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i])
214
- else:
215
- outputs = block(hidden_states, layer_past=layer_past, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi)
216
- hidden_states = outputs[0]
217
- if use_cache is True:
218
- presents = presents + (outputs[1],)
219
- if output_attentions:
220
- oa = (outputs[2 if use_cache else 1],)
221
- all_self_attentions = all_self_attentions + oa
222
- hidden_states = self.ln_f(hidden_states)
223
- if output_hidden_states:
224
- hst = (hidden_states,)
225
- all_hidden_states = all_hidden_states + hst
226
- if not return_dict:
227
- return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None))
228
- return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions)
229
- setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer))
230
- setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer))
231
- setattr(model.transformer, 'forward', MethodType(transformer_forward, model.transformer))
232
- KeyValueT = Tuple[torch.Tensor, torch.Tensor]
233
-
234
- def forward(self: BloomForCausalLM, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.Tensor]=None, inputs_embeds: Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments: Any) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
235
- """Replacement forward method for BloomCausalLM."""
236
- if deprecated_arguments.pop('position_ids', False) is not False:
237
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.', FutureWarning)
238
- if len(deprecated_arguments) > 0:
239
- raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
240
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
241
- transformer_outputs = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, bidirectional_mask=bidirectional_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
242
- hidden_states = transformer_outputs[0]
243
- lm_logits = self.lm_head(hidden_states)
244
- loss = None
245
- if labels is not None:
246
- shift_logits = lm_logits[..., :-1, :].contiguous()
247
- shift_labels = labels[..., 1:].contiguous()
248
- (batch_size, seq_length, vocab_size) = shift_logits.shape
249
- loss_fct = CrossEntropyLoss()
250
- loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))
251
- if not return_dict:
252
- output = (lm_logits,) + transformer_outputs[1:]
253
- return (loss,) + output if loss is not None else output
254
- return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions)
255
-
256
- def prepare_inputs_for_generation(self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, **kwargs: Any) -> dict:
257
- del kwargs
258
- if past:
259
- input_ids = input_ids[:, -1].unsqueeze(-1)
260
- bidirectional_mask = None
261
- if past[0][0].shape[0] == input_ids.shape[0]:
262
- past = self._convert_to_bloom_cache(past)
263
- else:
264
- bidirectional_mask = torch.ones_like(input_ids)
265
- return {'input_ids': input_ids, 'past_key_values': past, 'use_cache': True, 'attention_mask': attention_mask, 'bidirectional_mask': bidirectional_mask}
266
- setattr(model, 'forward', MethodType(forward, model))
267
- setattr(model, 'prepare_inputs_for_generation', MethodType(prepare_inputs_for_generation, model))
268
- setattr(model, '_prefix_lm_converted', True)
269
  return model
270
 
271
- def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
272
- """Converts an OPT Causal LM to a Prefix LM.
273
 
274
- Supported HuggingFace model classes:
275
- - `OPTForCausalLM`
 
 
276
 
277
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
278
- """
279
- if hasattr(model, '_prefix_lm_converted'):
280
- return model
281
- assert isinstance(model, OPTForCausalLM)
282
- assert model.config.add_cross_attention == False, 'Only supports OPT decoder-only models'
283
- setattr(model, '_original_forward', getattr(model, 'forward'))
284
- setattr(model, '_original_generate', getattr(model, 'generate'))
285
- model.model.decoder.bidirectional_mask = None
286
-
287
- def _prepare_decoder_attention_mask(self: torch.nn.Module, attention_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], inputs_embeds: Optional[torch.Tensor], past_key_values_length: int):
288
- combined_attention_mask = None
289
- if input_shape[-1] > 1:
290
- assert inputs_embeds is not None
291
- if self.bidirectional_mask == 'g':
292
- (bsz, src_length) = input_shape
293
- combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
294
- else:
295
- combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device)
296
- if self.bidirectional_mask is not None:
297
- assert attention_mask is not None
298
- assert attention_mask.shape == self.bidirectional_mask.shape
299
- expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
300
- combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
301
- if attention_mask is not None:
302
- assert inputs_embeds is not None
303
- expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
304
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
305
- return combined_attention_mask
306
- setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder))
307
-
308
- def forward(self: OPTForCausalLM, input_ids: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.ByteTensor]=None, head_mask: Optional[torch.Tensor]=None, past_key_values: Optional[List[torch.FloatTensor]]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
309
-
310
- def call_og_forward():
311
- return self._original_forward(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
312
- if bidirectional_mask is None:
313
- return call_og_forward()
314
- self.model.decoder.bidirectional_mask = bidirectional_mask
315
- try:
316
- outputs = call_og_forward()
317
- except:
318
- self.model.decoder.bidirectional_mask = None
319
- raise
320
- self.model.decoder.bidirectional_mask = None
321
- return outputs
322
-
323
- def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Any):
324
- """Wraps original generate to enable PrefixLM-style attention."""
325
- self.model.decoder.bidirectional_mask = 'g'
326
- try:
327
- output = self._original_generate(*args, **kwargs)
328
- except:
329
- self.model.decoder.bidirectional_mask = None
330
- raise
331
- self.model.decoder.bidirectional_mask = None
332
- return output
333
- setattr(model, 'forward', MethodType(forward, model))
334
- setattr(model, 'generate', MethodType(generate, model))
335
- setattr(model, '_prefix_lm_converted', True)
336
- return model
337
- _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
338
- CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM]
339
 
340
  def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
341
  """Converts a HuggingFace Causal LM to a Prefix LM.
@@ -345,8 +178,6 @@ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES
345
  - `GPTNeoForCausalLM`
346
  - `GPTNeoXForCausalLM`
347
  - `GPTJForCausalLM`
348
- - `BloomForCausalLM`
349
- - `OPTForCausalLM`
350
 
351
  Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
352
  `generate` method and/or select underlying methods depending on the model class.
@@ -396,12 +227,13 @@ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES
396
  """
397
  if isinstance(model, _SUPPORTED_GPT_MODELS):
398
  return _convert_gpt_causal_lm_to_prefix_lm(model)
399
- elif isinstance(model, BloomForCausalLM):
400
- return _convert_bloom_causal_lm_to_prefix_lm(model)
401
- elif isinstance(model, OPTForCausalLM):
402
- return _convert_opt_causal_lm_to_prefix_lm(model)
403
  else:
404
- raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\n{_SUPPORTED_HF_MODELS}')
 
 
 
 
 
405
 
406
  def add_bidirectional_mask_if_missing(batch: MutableMapping):
407
  """Attempts to add bidirectional_mask to batch if missing.
@@ -409,12 +241,16 @@ def add_bidirectional_mask_if_missing(batch: MutableMapping):
409
  Raises:
410
  KeyError if bidirectional_mask is missing and can't be inferred
411
  """
412
- if 'bidirectional_mask' not in batch:
413
- if batch.get('mode', None) == 'icl_task':
414
- batch['bidirectional_mask'] = batch['attention_mask'].clone()
415
- for (i, continuation_indices) in enumerate(batch['continuation_indices']):
416
- batch['bidirectional_mask'][i, continuation_indices] = 0
417
- elif 'labels' in batch and 'attention_mask' in batch:
418
- batch['bidirectional_mask'] = torch.logical_and(torch.eq(batch['attention_mask'], 1), torch.eq(batch['labels'], -100)).type_as(batch['attention_mask'])
 
 
419
  else:
420
- raise KeyError('No bidirectional_mask in batch and not sure how to construct one.')
 
 
 
6
  Prefix LMs accepts a `bidirectional_mask` input in `forward`
7
  and treat the input prompt as the prefix in `generate`.
8
  """
 
 
9
  from types import MethodType
10
  from typing import Any, List, MutableMapping, Optional, Tuple, Union
11
  import torch
 
 
 
 
12
  from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
13
  from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
14
  from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
15
  from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
16
+
17
+ _SUPPORTED_GPT_MODELS = (
18
+ GPT2LMHeadModel,
19
+ GPTJForCausalLM,
20
+ GPTNeoForCausalLM,
21
+ GPTNeoXForCausalLM,
22
+ )
23
+ CAUSAL_GPT_TYPES = Union[
24
+ GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM
25
+ ]
26
+
27
 
28
  def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
29
  """Converts a GPT-style Causal LM to a Prefix LM.
 
36
 
37
  See `convert_hf_causal_lm_to_prefix_lm` for more details.
38
  """
39
+ if hasattr(model, "_prefix_lm_converted"):
40
  return model
41
  assert isinstance(model, _SUPPORTED_GPT_MODELS)
42
+ assert (
43
+ model.config.add_cross_attention == False
44
+ ), "Only supports GPT-style decoder-only models"
45
 
46
  def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
47
  """Helper that gets a list of the model's attention modules.
 
57
  blocks = model.transformer.h
58
  for block in blocks:
59
  if isinstance(model, GPTNeoForCausalLM):
60
+ if block.attn.attention_type != "global":
61
  continue
62
  attn_module = block.attn.attention
63
  elif isinstance(model, GPTNeoXForCausalLM):
 
66
  attn_module = block.attn
67
  attn_modules.append(attn_module)
68
  return attn_modules
 
 
69
 
70
+ setattr(model, "_original_forward", getattr(model, "forward"))
71
+ setattr(model, "_original_generate", getattr(model, "generate"))
72
+
73
+ def forward(
74
+ self: CAUSAL_GPT_TYPES,
75
+ input_ids: Optional[torch.LongTensor] = None,
76
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
77
+ attention_mask: Optional[torch.FloatTensor] = None,
78
+ bidirectional_mask: Optional[torch.Tensor] = None,
79
+ token_type_ids: Optional[torch.LongTensor] = None,
80
+ position_ids: Optional[torch.LongTensor] = None,
81
+ head_mask: Optional[torch.FloatTensor] = None,
82
+ inputs_embeds: Optional[torch.FloatTensor] = None,
83
+ labels: Optional[torch.LongTensor] = None,
84
+ use_cache: Optional[bool] = None,
85
+ output_attentions: Optional[bool] = None,
86
+ output_hidden_states: Optional[bool] = None,
87
+ return_dict: Optional[bool] = None,
88
+ ):
89
  """Wraps original forward to enable PrefixLM attention."""
90
 
91
  def call_og_forward():
92
  if isinstance(self, GPTNeoXForCausalLM):
93
+ return self._original_forward(
94
+ input_ids=input_ids,
95
+ past_key_values=past_key_values,
96
+ attention_mask=attention_mask,
97
+ head_mask=head_mask,
98
+ inputs_embeds=inputs_embeds,
99
+ labels=labels,
100
+ use_cache=use_cache,
101
+ output_attentions=output_attentions,
102
+ output_hidden_states=output_hidden_states,
103
+ return_dict=return_dict,
104
+ )
105
  else:
106
+ return self._original_forward(
107
+ input_ids=input_ids,
108
+ past_key_values=past_key_values,
109
+ attention_mask=attention_mask,
110
+ token_type_ids=token_type_ids,
111
+ position_ids=position_ids,
112
+ head_mask=head_mask,
113
+ inputs_embeds=inputs_embeds,
114
+ labels=labels,
115
+ use_cache=use_cache,
116
+ output_attentions=output_attentions,
117
+ output_hidden_states=output_hidden_states,
118
+ return_dict=return_dict,
119
+ )
120
+
121
  if bidirectional_mask is None:
122
  return call_og_forward()
123
  assert isinstance(bidirectional_mask, torch.Tensor)
 
125
  (b, s) = bidirectional_mask.shape
126
  max_length = attn_modules[0].bias.shape[-1]
127
  if s > max_length:
128
+ raise ValueError(
129
+ f"bidirectional_mask sequence length (={s}) exceeds the "
130
+ + f"max length allowed by the model ({max_length})."
131
+ )
132
  assert s <= max_length
133
  if s < max_length:
134
+ pad = torch.zeros(
135
+ (int(b), int(max_length - s)),
136
+ dtype=bidirectional_mask.dtype,
137
+ device=bidirectional_mask.device,
138
+ )
139
  bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
140
  bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
141
  for attn_module in attn_modules:
142
  assert isinstance(attn_module.bias, torch.Tensor)
143
+ attn_module.bias.data = torch.logical_or(
144
+ attn_module.bias.data, bidirectional
145
+ )
146
  output = call_og_forward()
147
  for attn_module in attn_modules:
148
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
 
157
  for attn_module in attn_modules:
158
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
159
  return output
 
 
 
 
 
 
 
 
 
 
160
 
161
+ setattr(model, "forward", MethodType(forward, model))
162
+ setattr(model, "generate", MethodType(generate, model))
163
+ setattr(model, "_prefix_lm_converted", True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  return model
165
 
 
 
166
 
167
+ _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS
168
+ CAUSAL_LM_TYPES = Union[
169
+ GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM
170
+ ]
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
174
  """Converts a HuggingFace Causal LM to a Prefix LM.
 
178
  - `GPTNeoForCausalLM`
179
  - `GPTNeoXForCausalLM`
180
  - `GPTJForCausalLM`
 
 
181
 
182
  Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
183
  `generate` method and/or select underlying methods depending on the model class.
 
227
  """
228
  if isinstance(model, _SUPPORTED_GPT_MODELS):
229
  return _convert_gpt_causal_lm_to_prefix_lm(model)
 
 
 
 
230
  else:
231
+ raise TypeError(
232
+ f"Cannot convert model to Prefix LM. "
233
+ + f"Model does not belong to set of supported HF models:"
234
+ + f"\n{_SUPPORTED_HF_MODELS}"
235
+ )
236
+
237
 
238
  def add_bidirectional_mask_if_missing(batch: MutableMapping):
239
  """Attempts to add bidirectional_mask to batch if missing.
 
241
  Raises:
242
  KeyError if bidirectional_mask is missing and can't be inferred
243
  """
244
+ if "bidirectional_mask" not in batch:
245
+ if batch.get("mode", None) == "icl_task":
246
+ batch["bidirectional_mask"] = batch["attention_mask"].clone()
247
+ for i, continuation_indices in enumerate(batch["continuation_indices"]):
248
+ batch["bidirectional_mask"][i, continuation_indices] = 0
249
+ elif "labels" in batch and "attention_mask" in batch:
250
+ batch["bidirectional_mask"] = torch.logical_and(
251
+ torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100)
252
+ ).type_as(batch["attention_mask"])
253
  else:
254
+ raise KeyError(
255
+ "No bidirectional_mask in batch and not sure how to construct one."
256
+ )