efederici commited on
Commit
b1dbc68
1 Parent(s): d863a84

Update hf_prefixlm_converter.py

Browse files
Files changed (1) hide show
  1. hf_prefixlm_converter.py +6 -241
hf_prefixlm_converter.py CHANGED
@@ -6,23 +6,13 @@ Causal LM to convert it to a Prefix LM.
6
  Prefix LMs accepts a `bidirectional_mask` input in `forward`
7
  and treat the input prompt as the prefix in `generate`.
8
  """
9
- import math
10
- import warnings
11
  from types import MethodType
12
- from typing import Any, Dict, List, Optional, Tuple, Union
13
  import torch
14
- from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss
15
- from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
16
- from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom
17
- from transformers.models.bloom.modeling_bloom import logging
18
  from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
19
  from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
20
  from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
21
  from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
22
- from transformers.models.opt.modeling_opt import OPTForCausalLM
23
- from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
24
- from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
25
- logger = logging.get_logger(__name__)
26
  _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
27
  CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
28
 
@@ -90,13 +80,14 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
90
  bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
91
  bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
92
  for attn_module in attn_modules:
 
93
  attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
94
  output = call_og_forward()
95
  for attn_module in attn_modules:
96
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
97
  return output
98
 
99
- def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, Any]):
100
  """Wraps original generate to enable PrefixLM attention."""
101
  attn_modules = _get_attn_modules(model)
102
  for attn_module in attn_modules:
@@ -109,228 +100,8 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
109
  setattr(model, 'generate', MethodType(generate, model))
110
  setattr(model, '_prefix_lm_converted', True)
111
  return model
112
-
113
- def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
114
- """Converts a BLOOM Causal LM to a Prefix LM.
115
-
116
- Supported HuggingFace model classes:
117
- - `BloomForCausalLM`
118
-
119
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
120
- """
121
- if hasattr(model, '_prefix_lm_converted'):
122
- return model
123
- assert isinstance(model, BloomForCausalLM)
124
- assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models'
125
-
126
- def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor:
127
- combined_attention_mask = None
128
- device = attention_mask.device
129
- (_, src_length) = input_shape
130
- if src_length > 1:
131
- combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length)
132
- if bidirectional_mask is not None:
133
- assert attention_mask.shape == bidirectional_mask.shape
134
- expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
135
- combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)
136
- expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
137
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
138
- return combined_attention_mask
139
-
140
- def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
141
- num_heads = self.config.n_head
142
- closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
143
- base = torch.tensor(2 ** (-2 ** (-(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
144
- powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
145
- slopes = torch.pow(base, powers)
146
- if closest_power_of_2 != num_heads:
147
- extra_base = torch.tensor(2 ** (-2 ** (-(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32)
148
- num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
149
- extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
150
- slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
151
- qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
152
- ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
153
- diffs = qa - ka + key_length - query_length
154
- diffs = -diffs.abs()
155
- alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length)
156
- alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)
157
- return alibi.to(dtype)
158
- KeyValueT = Tuple[torch.Tensor, torch.Tensor]
159
-
160
- def forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
161
- if deprecated_arguments.pop('position_ids', False) is not False:
162
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. ' + 'You can safely ignore passing `position_ids`.', FutureWarning)
163
- if len(deprecated_arguments) > 0:
164
- raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
165
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
166
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
167
- use_cache = use_cache if use_cache is not None else self.config.use_cache
168
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
169
- if input_ids is not None and inputs_embeds is not None:
170
- raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
171
- elif input_ids is not None:
172
- (batch_size, seq_length) = input_ids.shape
173
- elif inputs_embeds is not None:
174
- (batch_size, seq_length, _) = inputs_embeds.shape
175
- else:
176
- raise ValueError('You have to specify either input_ids or inputs_embeds')
177
- if past_key_values is None:
178
- past_key_values = tuple([None] * len(self.h))
179
- head_mask = self.get_head_mask(head_mask, self.config.n_layer)
180
- if inputs_embeds is None:
181
- inputs_embeds = self.word_embeddings(input_ids)
182
- hidden_states = self.word_embeddings_layernorm(inputs_embeds)
183
- presents = () if use_cache else None
184
- all_self_attentions = () if output_attentions else None
185
- all_hidden_states = () if output_hidden_states else None
186
- seq_length_with_past = seq_length
187
- past_key_values_length = 0
188
- if past_key_values[0] is not None:
189
- tmp = past_key_values[0][0]
190
- past_key_values_length = tmp.shape[2]
191
- seq_length_with_past = seq_length_with_past + past_key_values_length
192
- if attention_mask is None:
193
- attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
194
- else:
195
- attention_mask = attention_mask.to(hidden_states.device)
196
- alibi = self._build_alibi_tensor(batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device)
197
- causal_mask = self._prepare_attn_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length)
198
- for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)):
199
- if output_hidden_states:
200
- hst = (hidden_states,)
201
- all_hidden_states = all_hidden_states + hst
202
- if self.gradient_checkpointing and self.training:
203
- if use_cache:
204
- logger.warning('`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
205
- use_cache = False
206
-
207
- def create_custom_forward(module):
208
-
209
- def custom_forward(*inputs):
210
- return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
211
- return custom_forward
212
- outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i])
213
- else:
214
- outputs = block(hidden_states, layer_past=layer_past, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi)
215
- hidden_states = outputs[0]
216
- if use_cache is True:
217
- presents = presents + (outputs[1],)
218
- if output_attentions:
219
- oa = (outputs[2 if use_cache else 1],)
220
- all_self_attentions = all_self_attentions + oa
221
- hidden_states = self.ln_f(hidden_states)
222
- if output_hidden_states:
223
- hst = (hidden_states,)
224
- all_hidden_states = all_hidden_states + hst
225
- if not return_dict:
226
- return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None))
227
- return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions)
228
- setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer))
229
- setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer))
230
- setattr(model.transformer, 'forward', MethodType(forward, model.transformer))
231
- KeyValueT = Tuple[torch.Tensor, torch.Tensor]
232
-
233
- def forward(self: BloomForCausalLM, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.Tensor]=None, inputs_embeds: Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
234
- """Replacement forward method for BloomCausalLM."""
235
- if deprecated_arguments.pop('position_ids', False) is not False:
236
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.', FutureWarning)
237
- if len(deprecated_arguments) > 0:
238
- raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
239
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
240
- transformer_outputs = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, bidirectional_mask=bidirectional_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
241
- hidden_states = transformer_outputs[0]
242
- lm_logits = self.lm_head(hidden_states)
243
- loss = None
244
- if labels is not None:
245
- shift_logits = lm_logits[..., :-1, :].contiguous()
246
- shift_labels = labels[..., 1:].contiguous()
247
- (batch_size, seq_length, vocab_size) = shift_logits.shape
248
- loss_fct = CrossEntropyLoss()
249
- loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))
250
- if not return_dict:
251
- output = (lm_logits,) + transformer_outputs[1:]
252
- return (loss,) + output if loss is not None else output
253
- return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions)
254
-
255
- def prepare_inputs_for_generation(self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, **kwargs) -> dict:
256
- if past:
257
- input_ids = input_ids[:, -1].unsqueeze(-1)
258
- bidirectional_mask = None
259
- if past[0][0].shape[0] == input_ids.shape[0]:
260
- past = self._convert_to_bloom_cache(past)
261
- else:
262
- bidirectional_mask = torch.ones_like(input_ids)
263
- return {'input_ids': input_ids, 'past_key_values': past, 'use_cache': True, 'attention_mask': attention_mask, 'bidirectional_mask': bidirectional_mask}
264
- setattr(model, 'forward', MethodType(forward, model))
265
- setattr(model, 'prepare_inputs_for_generation', MethodType(prepare_inputs_for_generation, model))
266
- setattr(model, '_prefix_lm_converted', True)
267
- return model
268
-
269
- def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
270
- """Converts an OPT Causal LM to a Prefix LM.
271
-
272
- Supported HuggingFace model classes:
273
- - `OPTForCausalLM`
274
-
275
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
276
- """
277
- if hasattr(model, '_prefix_lm_converted'):
278
- return model
279
- assert isinstance(model, OPTForCausalLM)
280
- assert model.config.add_cross_attention == False, 'Only supports OPT decoder-only models'
281
- setattr(model, '_original_forward', getattr(model, 'forward'))
282
- setattr(model, '_original_generate', getattr(model, 'generate'))
283
- model.model.decoder.bidirectional_mask = None
284
-
285
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
286
- combined_attention_mask = None
287
- if input_shape[-1] > 1:
288
- if self.bidirectional_mask == 'g':
289
- (bsz, src_length) = input_shape
290
- combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
291
- else:
292
- combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device)
293
- if self.bidirectional_mask is not None:
294
- assert attention_mask.shape == self.bidirectional_mask.shape
295
- expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
296
- combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
297
- if attention_mask is not None:
298
- expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
299
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
300
- return combined_attention_mask
301
- setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder))
302
-
303
- def forward(self: OPTForCausalLM, input_ids: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.ByteTensor]=None, head_mask: Optional[torch.Tensor]=None, past_key_values: Optional[List[torch.FloatTensor]]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
304
-
305
- def call_og_forward():
306
- return self._original_forward(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
307
- if bidirectional_mask is None:
308
- return call_og_forward()
309
- self.model.decoder.bidirectional_mask = bidirectional_mask
310
- try:
311
- outputs = call_og_forward()
312
- except:
313
- self.model.decoder.bidirectional_mask = None
314
- raise
315
- self.model.decoder.bidirectional_mask = None
316
- return outputs
317
-
318
- def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
319
- """Wraps original generate to enable PrefixLM-style attention."""
320
- self.model.decoder.bidirectional_mask = 'g'
321
- try:
322
- output = self._original_generate(*args, **kwargs)
323
- except:
324
- self.model.decoder.bidirectional_mask = None
325
- raise
326
- self.model.decoder.bidirectional_mask = None
327
- return output
328
- setattr(model, 'forward', MethodType(forward, model))
329
- setattr(model, 'generate', MethodType(generate, model))
330
- setattr(model, '_prefix_lm_converted', True)
331
- return model
332
- _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
333
- CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM]
334
 
335
  def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
336
  """Converts a HuggingFace Causal LM to a Prefix LM.
@@ -340,8 +111,6 @@ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES
340
  - `GPTNeoForCausalLM`
341
  - `GPTNeoXForCausalLM`
342
  - `GPTJForCausalLM`
343
- - `BloomForCausalLM`
344
- - `OPTForCausalLM`
345
 
346
  Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
347
  `generate` method and/or select underlying methods depending on the model class.
@@ -391,14 +160,10 @@ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES
391
  """
392
  if isinstance(model, _SUPPORTED_GPT_MODELS):
393
  return _convert_gpt_causal_lm_to_prefix_lm(model)
394
- elif isinstance(model, BloomForCausalLM):
395
- return _convert_bloom_causal_lm_to_prefix_lm(model)
396
- elif isinstance(model, OPTForCausalLM):
397
- return _convert_opt_causal_lm_to_prefix_lm(model)
398
  else:
399
  raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\n{_SUPPORTED_HF_MODELS}')
400
 
401
- def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
402
  """Attempts to add bidirectional_mask to batch if missing.
403
 
404
  Raises:
 
6
  Prefix LMs accepts a `bidirectional_mask` input in `forward`
7
  and treat the input prompt as the prefix in `generate`.
8
  """
 
 
9
  from types import MethodType
10
+ from typing import Any, List, MutableMapping, Optional, Tuple, Union
11
  import torch
 
 
 
 
12
  from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
13
  from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
14
  from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
15
  from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
 
 
 
 
16
  _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
17
  CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
18
 
 
80
  bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
81
  bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
82
  for attn_module in attn_modules:
83
+ assert isinstance(attn_module.bias, torch.Tensor)
84
  attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
85
  output = call_og_forward()
86
  for attn_module in attn_modules:
87
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
88
  return output
89
 
90
+ def generate(self: CAUSAL_GPT_TYPES, *args: Any, **kwargs: Any):
91
  """Wraps original generate to enable PrefixLM attention."""
92
  attn_modules = _get_attn_modules(model)
93
  for attn_module in attn_modules:
 
100
  setattr(model, 'generate', MethodType(generate, model))
101
  setattr(model, '_prefix_lm_converted', True)
102
  return model
103
+ _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS
104
+ CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
107
  """Converts a HuggingFace Causal LM to a Prefix LM.
 
111
  - `GPTNeoForCausalLM`
112
  - `GPTNeoXForCausalLM`
113
  - `GPTJForCausalLM`
 
 
114
 
115
  Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
116
  `generate` method and/or select underlying methods depending on the model class.
 
160
  """
161
  if isinstance(model, _SUPPORTED_GPT_MODELS):
162
  return _convert_gpt_causal_lm_to_prefix_lm(model)
 
 
 
 
163
  else:
164
  raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\n{_SUPPORTED_HF_MODELS}')
165
 
166
+ def add_bidirectional_mask_if_missing(batch: MutableMapping):
167
  """Attempts to add bidirectional_mask to batch if missing.
168
 
169
  Raises: