xianbin commited on
Commit
9f4ac0d
1 Parent(s): 019608c

Update codes to be in line with LLM-foundry update on October 30, 2023 (#1)

Browse files

- Update codes to be in line with LLM-foundry update on October 30, 2023 (301683c3e2870b42d28b399033e5cbb9bb6475dc)

Files changed (1) hide show
  1. hf_prefixlm_converter.py +99 -263
hf_prefixlm_converter.py CHANGED
@@ -6,25 +6,24 @@ Causal LM to convert it to a Prefix LM.
6
  Prefix LMs accepts a `bidirectional_mask` input in `forward`
7
  and treat the input prompt as the prefix in `generate`.
8
  """
9
- import math
10
- import warnings
11
  from types import MethodType
12
  from typing import Any, List, MutableMapping, Optional, Tuple, Union
13
  import torch
14
- from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss
15
- from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
16
- from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom
17
- from transformers.models.bloom.modeling_bloom import logging
18
  from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
19
  from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
20
  from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
21
  from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
22
- from transformers.models.opt.modeling_opt import OPTForCausalLM
23
- from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
24
- from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
25
- logger = logging.get_logger(__name__)
26
- _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
27
- CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
 
 
 
 
 
28
 
29
  def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
30
  """Converts a GPT-style Causal LM to a Prefix LM.
@@ -37,10 +36,12 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
37
 
38
  See `convert_hf_causal_lm_to_prefix_lm` for more details.
39
  """
40
- if hasattr(model, '_prefix_lm_converted'):
41
  return model
42
  assert isinstance(model, _SUPPORTED_GPT_MODELS)
43
- assert model.config.add_cross_attention == False, 'Only supports GPT-style decoder-only models'
 
 
44
 
45
  def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
46
  """Helper that gets a list of the model's attention modules.
@@ -56,7 +57,7 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
56
  blocks = model.transformer.h
57
  for block in blocks:
58
  if isinstance(model, GPTNeoForCausalLM):
59
- if block.attn.attention_type != 'global':
60
  continue
61
  attn_module = block.attn.attention
62
  elif isinstance(model, GPTNeoXForCausalLM):
@@ -65,17 +66,58 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
65
  attn_module = block.attn
66
  attn_modules.append(attn_module)
67
  return attn_modules
68
- setattr(model, '_original_forward', getattr(model, 'forward'))
69
- setattr(model, '_original_generate', getattr(model, 'generate'))
70
 
71
- def forward(self: CAUSAL_GPT_TYPES, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]=None, attention_mask: Optional[torch.FloatTensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, token_type_ids: Optional[torch.LongTensor]=None, position_ids: Optional[torch.LongTensor]=None, head_mask: Optional[torch.FloatTensor]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  """Wraps original forward to enable PrefixLM attention."""
73
 
74
  def call_og_forward():
75
  if isinstance(self, GPTNeoXForCausalLM):
76
- return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
 
 
 
 
 
 
 
 
 
 
 
77
  else:
78
- return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  if bidirectional_mask is None:
80
  return call_og_forward()
81
  assert isinstance(bidirectional_mask, torch.Tensor)
@@ -83,15 +125,24 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
83
  (b, s) = bidirectional_mask.shape
84
  max_length = attn_modules[0].bias.shape[-1]
85
  if s > max_length:
86
- raise ValueError(f'bidirectional_mask sequence length (={s}) exceeds the ' + f'max length allowed by the model ({max_length}).')
 
 
 
87
  assert s <= max_length
88
  if s < max_length:
89
- pad = torch.zeros((int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device)
 
 
 
 
90
  bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
91
  bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
92
  for attn_module in attn_modules:
93
  assert isinstance(attn_module.bias, torch.Tensor)
94
- attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
 
 
95
  output = call_og_forward()
96
  for attn_module in attn_modules:
97
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
@@ -106,236 +157,18 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
106
  for attn_module in attn_modules:
107
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
108
  return output
109
- setattr(model, 'forward', MethodType(forward, model))
110
- setattr(model, 'generate', MethodType(generate, model))
111
- setattr(model, '_prefix_lm_converted', True)
112
- return model
113
-
114
- def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
115
- """Converts a BLOOM Causal LM to a Prefix LM.
116
-
117
- Supported HuggingFace model classes:
118
- - `BloomForCausalLM`
119
 
120
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
121
- """
122
- if hasattr(model, '_prefix_lm_converted'):
123
- return model
124
- assert isinstance(model, BloomForCausalLM)
125
- assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models'
126
-
127
- def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor:
128
- combined_attention_mask = None
129
- device = attention_mask.device
130
- (_, src_length) = input_shape
131
- if src_length > 1:
132
- combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length)
133
- if bidirectional_mask is not None:
134
- assert attention_mask.shape == bidirectional_mask.shape
135
- expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
136
- combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)
137
- expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
138
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
139
- return combined_attention_mask
140
-
141
- def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
142
- num_heads = self.config.n_head
143
- closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
144
- base = torch.tensor(2 ** (-2 ** (-(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
145
- powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
146
- slopes = torch.pow(base, powers)
147
- if closest_power_of_2 != num_heads:
148
- extra_base = torch.tensor(2 ** (-2 ** (-(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32)
149
- num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
150
- extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
151
- slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
152
- qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
153
- ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
154
- diffs = qa - ka + key_length - query_length
155
- diffs = -diffs.abs()
156
- alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length)
157
- alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)
158
- return alibi.to(dtype)
159
- KeyValueT = Tuple[torch.Tensor, torch.Tensor]
160
-
161
- def transformer_forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments: Any) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
162
- if deprecated_arguments.pop('position_ids', False) is not False:
163
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. ' + 'You can safely ignore passing `position_ids`.', FutureWarning)
164
- if len(deprecated_arguments) > 0:
165
- raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
166
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
167
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
168
- use_cache = use_cache if use_cache is not None else self.config.use_cache
169
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
170
- if input_ids is not None and inputs_embeds is not None:
171
- raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
172
- elif input_ids is not None:
173
- (batch_size, seq_length) = input_ids.shape
174
- elif inputs_embeds is not None:
175
- (batch_size, seq_length, _) = inputs_embeds.shape
176
- else:
177
- raise ValueError('You have to specify either input_ids or inputs_embeds')
178
- if past_key_values is None:
179
- past_key_values = tuple([None] * len(self.h))
180
- head_mask = self.get_head_mask(head_mask, self.config.n_layer)
181
- if inputs_embeds is None:
182
- inputs_embeds = self.word_embeddings(input_ids)
183
- hidden_states = self.word_embeddings_layernorm(inputs_embeds)
184
- presents = () if use_cache else None
185
- all_self_attentions = () if output_attentions else None
186
- all_hidden_states = () if output_hidden_states else None
187
- seq_length_with_past = seq_length
188
- past_key_values_length = 0
189
- if past_key_values[0] is not None:
190
- tmp = past_key_values[0][0]
191
- past_key_values_length = tmp.shape[2]
192
- seq_length_with_past = seq_length_with_past + past_key_values_length
193
- if attention_mask is None:
194
- attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
195
- else:
196
- attention_mask = attention_mask.to(hidden_states.device)
197
- alibi = self._build_alibi_tensor(batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device)
198
- causal_mask = self._prepare_attn_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length)
199
- for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)):
200
- if output_hidden_states:
201
- hst = (hidden_states,)
202
- all_hidden_states = all_hidden_states + hst
203
- if self.gradient_checkpointing and self.training:
204
- if use_cache:
205
- logger.warning('`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
206
- use_cache = False
207
-
208
- def create_custom_forward(module: torch.nn.Module):
209
-
210
- def custom_forward(*inputs: Any):
211
- return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
212
- return custom_forward
213
- outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i])
214
- else:
215
- outputs = block(hidden_states, layer_past=layer_past, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi)
216
- hidden_states = outputs[0]
217
- if use_cache is True:
218
- presents = presents + (outputs[1],)
219
- if output_attentions:
220
- oa = (outputs[2 if use_cache else 1],)
221
- all_self_attentions = all_self_attentions + oa
222
- hidden_states = self.ln_f(hidden_states)
223
- if output_hidden_states:
224
- hst = (hidden_states,)
225
- all_hidden_states = all_hidden_states + hst
226
- if not return_dict:
227
- return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None))
228
- return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions)
229
- setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer))
230
- setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer))
231
- setattr(model.transformer, 'forward', MethodType(transformer_forward, model.transformer))
232
- KeyValueT = Tuple[torch.Tensor, torch.Tensor]
233
-
234
- def forward(self: BloomForCausalLM, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.Tensor]=None, inputs_embeds: Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments: Any) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
235
- """Replacement forward method for BloomCausalLM."""
236
- if deprecated_arguments.pop('position_ids', False) is not False:
237
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.', FutureWarning)
238
- if len(deprecated_arguments) > 0:
239
- raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
240
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
241
- transformer_outputs = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, bidirectional_mask=bidirectional_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
242
- hidden_states = transformer_outputs[0]
243
- lm_logits = self.lm_head(hidden_states)
244
- loss = None
245
- if labels is not None:
246
- shift_logits = lm_logits[..., :-1, :].contiguous()
247
- shift_labels = labels[..., 1:].contiguous()
248
- (batch_size, seq_length, vocab_size) = shift_logits.shape
249
- loss_fct = CrossEntropyLoss()
250
- loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))
251
- if not return_dict:
252
- output = (lm_logits,) + transformer_outputs[1:]
253
- return (loss,) + output if loss is not None else output
254
- return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions)
255
-
256
- def prepare_inputs_for_generation(self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, **kwargs: Any) -> dict:
257
- del kwargs
258
- if past:
259
- input_ids = input_ids[:, -1].unsqueeze(-1)
260
- bidirectional_mask = None
261
- if past[0][0].shape[0] == input_ids.shape[0]:
262
- past = self._convert_to_bloom_cache(past)
263
- else:
264
- bidirectional_mask = torch.ones_like(input_ids)
265
- return {'input_ids': input_ids, 'past_key_values': past, 'use_cache': True, 'attention_mask': attention_mask, 'bidirectional_mask': bidirectional_mask}
266
- setattr(model, 'forward', MethodType(forward, model))
267
- setattr(model, 'prepare_inputs_for_generation', MethodType(prepare_inputs_for_generation, model))
268
- setattr(model, '_prefix_lm_converted', True)
269
  return model
270
 
271
- def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
272
- """Converts an OPT Causal LM to a Prefix LM.
273
 
274
- Supported HuggingFace model classes:
275
- - `OPTForCausalLM`
 
 
276
 
277
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
278
- """
279
- if hasattr(model, '_prefix_lm_converted'):
280
- return model
281
- assert isinstance(model, OPTForCausalLM)
282
- assert model.config.add_cross_attention == False, 'Only supports OPT decoder-only models'
283
- setattr(model, '_original_forward', getattr(model, 'forward'))
284
- setattr(model, '_original_generate', getattr(model, 'generate'))
285
- model.model.decoder.bidirectional_mask = None
286
-
287
- def _prepare_decoder_attention_mask(self: torch.nn.Module, attention_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], inputs_embeds: Optional[torch.Tensor], past_key_values_length: int):
288
- combined_attention_mask = None
289
- if input_shape[-1] > 1:
290
- assert inputs_embeds is not None
291
- if self.bidirectional_mask == 'g':
292
- (bsz, src_length) = input_shape
293
- combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
294
- else:
295
- combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device)
296
- if self.bidirectional_mask is not None:
297
- assert attention_mask is not None
298
- assert attention_mask.shape == self.bidirectional_mask.shape
299
- expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
300
- combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
301
- if attention_mask is not None:
302
- assert inputs_embeds is not None
303
- expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
304
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
305
- return combined_attention_mask
306
- setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder))
307
-
308
- def forward(self: OPTForCausalLM, input_ids: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.ByteTensor]=None, head_mask: Optional[torch.Tensor]=None, past_key_values: Optional[List[torch.FloatTensor]]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
309
-
310
- def call_og_forward():
311
- return self._original_forward(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
312
- if bidirectional_mask is None:
313
- return call_og_forward()
314
- self.model.decoder.bidirectional_mask = bidirectional_mask
315
- try:
316
- outputs = call_og_forward()
317
- except:
318
- self.model.decoder.bidirectional_mask = None
319
- raise
320
- self.model.decoder.bidirectional_mask = None
321
- return outputs
322
-
323
- def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Any):
324
- """Wraps original generate to enable PrefixLM-style attention."""
325
- self.model.decoder.bidirectional_mask = 'g'
326
- try:
327
- output = self._original_generate(*args, **kwargs)
328
- except:
329
- self.model.decoder.bidirectional_mask = None
330
- raise
331
- self.model.decoder.bidirectional_mask = None
332
- return output
333
- setattr(model, 'forward', MethodType(forward, model))
334
- setattr(model, 'generate', MethodType(generate, model))
335
- setattr(model, '_prefix_lm_converted', True)
336
- return model
337
- _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
338
- CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM]
339
 
340
  def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
341
  """Converts a HuggingFace Causal LM to a Prefix LM.
@@ -345,8 +178,6 @@ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES
345
  - `GPTNeoForCausalLM`
346
  - `GPTNeoXForCausalLM`
347
  - `GPTJForCausalLM`
348
- - `BloomForCausalLM`
349
- - `OPTForCausalLM`
350
 
351
  Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
352
  `generate` method and/or select underlying methods depending on the model class.
@@ -396,12 +227,13 @@ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES
396
  """
397
  if isinstance(model, _SUPPORTED_GPT_MODELS):
398
  return _convert_gpt_causal_lm_to_prefix_lm(model)
399
- elif isinstance(model, BloomForCausalLM):
400
- return _convert_bloom_causal_lm_to_prefix_lm(model)
401
- elif isinstance(model, OPTForCausalLM):
402
- return _convert_opt_causal_lm_to_prefix_lm(model)
403
  else:
404
- raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\n{_SUPPORTED_HF_MODELS}')
 
 
 
 
 
405
 
406
  def add_bidirectional_mask_if_missing(batch: MutableMapping):
407
  """Attempts to add bidirectional_mask to batch if missing.
@@ -409,12 +241,16 @@ def add_bidirectional_mask_if_missing(batch: MutableMapping):
409
  Raises:
410
  KeyError if bidirectional_mask is missing and can't be inferred
411
  """
412
- if 'bidirectional_mask' not in batch:
413
- if batch.get('mode', None) == 'icl_task':
414
- batch['bidirectional_mask'] = batch['attention_mask'].clone()
415
- for (i, continuation_indices) in enumerate(batch['continuation_indices']):
416
- batch['bidirectional_mask'][i, continuation_indices] = 0
417
- elif 'labels' in batch and 'attention_mask' in batch:
418
- batch['bidirectional_mask'] = torch.logical_and(torch.eq(batch['attention_mask'], 1), torch.eq(batch['labels'], -100)).type_as(batch['attention_mask'])
 
 
419
  else:
420
- raise KeyError('No bidirectional_mask in batch and not sure how to construct one.')
 
 
 
6
  Prefix LMs accepts a `bidirectional_mask` input in `forward`
7
  and treat the input prompt as the prefix in `generate`.
8
  """
 
 
9
  from types import MethodType
10
  from typing import Any, List, MutableMapping, Optional, Tuple, Union
11
  import torch
 
 
 
 
12
  from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
13
  from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
14
  from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
15
  from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
16
+
17
+ _SUPPORTED_GPT_MODELS = (
18
+ GPT2LMHeadModel,
19
+ GPTJForCausalLM,
20
+ GPTNeoForCausalLM,
21
+ GPTNeoXForCausalLM,
22
+ )
23
+ CAUSAL_GPT_TYPES = Union[
24
+ GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM
25
+ ]
26
+
27
 
28
  def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
29
  """Converts a GPT-style Causal LM to a Prefix LM.
 
36
 
37
  See `convert_hf_causal_lm_to_prefix_lm` for more details.
38
  """
39
+ if hasattr(model, "_prefix_lm_converted"):
40
  return model
41
  assert isinstance(model, _SUPPORTED_GPT_MODELS)
42
+ assert (
43
+ model.config.add_cross_attention == False
44
+ ), "Only supports GPT-style decoder-only models"
45
 
46
  def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
47
  """Helper that gets a list of the model's attention modules.
 
57
  blocks = model.transformer.h
58
  for block in blocks:
59
  if isinstance(model, GPTNeoForCausalLM):
60
+ if block.attn.attention_type != "global":
61
  continue
62
  attn_module = block.attn.attention
63
  elif isinstance(model, GPTNeoXForCausalLM):
 
66
  attn_module = block.attn
67
  attn_modules.append(attn_module)
68
  return attn_modules
 
 
69
 
70
+ setattr(model, "_original_forward", getattr(model, "forward"))
71
+ setattr(model, "_original_generate", getattr(model, "generate"))
72
+
73
+ def forward(
74
+ self: CAUSAL_GPT_TYPES,
75
+ input_ids: Optional[torch.LongTensor] = None,
76
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
77
+ attention_mask: Optional[torch.FloatTensor] = None,
78
+ bidirectional_mask: Optional[torch.Tensor] = None,
79
+ token_type_ids: Optional[torch.LongTensor] = None,
80
+ position_ids: Optional[torch.LongTensor] = None,
81
+ head_mask: Optional[torch.FloatTensor] = None,
82
+ inputs_embeds: Optional[torch.FloatTensor] = None,
83
+ labels: Optional[torch.LongTensor] = None,
84
+ use_cache: Optional[bool] = None,
85
+ output_attentions: Optional[bool] = None,
86
+ output_hidden_states: Optional[bool] = None,
87
+ return_dict: Optional[bool] = None,
88
+ ):
89
  """Wraps original forward to enable PrefixLM attention."""
90
 
91
  def call_og_forward():
92
  if isinstance(self, GPTNeoXForCausalLM):
93
+ return self._original_forward(
94
+ input_ids=input_ids,
95
+ past_key_values=past_key_values,
96
+ attention_mask=attention_mask,
97
+ head_mask=head_mask,
98
+ inputs_embeds=inputs_embeds,
99
+ labels=labels,
100
+ use_cache=use_cache,
101
+ output_attentions=output_attentions,
102
+ output_hidden_states=output_hidden_states,
103
+ return_dict=return_dict,
104
+ )
105
  else:
106
+ return self._original_forward(
107
+ input_ids=input_ids,
108
+ past_key_values=past_key_values,
109
+ attention_mask=attention_mask,
110
+ token_type_ids=token_type_ids,
111
+ position_ids=position_ids,
112
+ head_mask=head_mask,
113
+ inputs_embeds=inputs_embeds,
114
+ labels=labels,
115
+ use_cache=use_cache,
116
+ output_attentions=output_attentions,
117
+ output_hidden_states=output_hidden_states,
118
+ return_dict=return_dict,
119
+ )
120
+
121
  if bidirectional_mask is None:
122
  return call_og_forward()
123
  assert isinstance(bidirectional_mask, torch.Tensor)
 
125
  (b, s) = bidirectional_mask.shape
126
  max_length = attn_modules[0].bias.shape[-1]
127
  if s > max_length:
128
+ raise ValueError(
129
+ f"bidirectional_mask sequence length (={s}) exceeds the "
130
+ + f"max length allowed by the model ({max_length})."
131
+ )
132
  assert s <= max_length
133
  if s < max_length:
134
+ pad = torch.zeros(
135
+ (int(b), int(max_length - s)),
136
+ dtype=bidirectional_mask.dtype,
137
+ device=bidirectional_mask.device,
138
+ )
139
  bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
140
  bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
141
  for attn_module in attn_modules:
142
  assert isinstance(attn_module.bias, torch.Tensor)
143
+ attn_module.bias.data = torch.logical_or(
144
+ attn_module.bias.data, bidirectional
145
+ )
146
  output = call_og_forward()
147
  for attn_module in attn_modules:
148
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
 
157
  for attn_module in attn_modules:
158
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
159
  return output
 
 
 
 
 
 
 
 
 
 
160
 
161
+ setattr(model, "forward", MethodType(forward, model))
162
+ setattr(model, "generate", MethodType(generate, model))
163
+ setattr(model, "_prefix_lm_converted", True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  return model
165
 
 
 
166
 
167
+ _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS
168
+ CAUSAL_LM_TYPES = Union[
169
+ GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM
170
+ ]
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
174
  """Converts a HuggingFace Causal LM to a Prefix LM.
 
178
  - `GPTNeoForCausalLM`
179
  - `GPTNeoXForCausalLM`
180
  - `GPTJForCausalLM`
 
 
181
 
182
  Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
183
  `generate` method and/or select underlying methods depending on the model class.
 
227
  """
228
  if isinstance(model, _SUPPORTED_GPT_MODELS):
229
  return _convert_gpt_causal_lm_to_prefix_lm(model)
 
 
 
 
230
  else:
231
+ raise TypeError(
232
+ f"Cannot convert model to Prefix LM. "
233
+ + f"Model does not belong to set of supported HF models:"
234
+ + f"\n{_SUPPORTED_HF_MODELS}"
235
+ )
236
+
237
 
238
  def add_bidirectional_mask_if_missing(batch: MutableMapping):
239
  """Attempts to add bidirectional_mask to batch if missing.
 
241
  Raises:
242
  KeyError if bidirectional_mask is missing and can't be inferred
243
  """
244
+ if "bidirectional_mask" not in batch:
245
+ if batch.get("mode", None) == "icl_task":
246
+ batch["bidirectional_mask"] = batch["attention_mask"].clone()
247
+ for i, continuation_indices in enumerate(batch["continuation_indices"]):
248
+ batch["bidirectional_mask"][i, continuation_indices] = 0
249
+ elif "labels" in batch and "attention_mask" in batch:
250
+ batch["bidirectional_mask"] = torch.logical_and(
251
+ torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100)
252
+ ).type_as(batch["attention_mask"])
253
  else:
254
+ raise KeyError(
255
+ "No bidirectional_mask in batch and not sure how to construct one."
256
+ )