efederici committed
Commit
31eb985
1 Parent(s): f974567

Update modeling_mpt.py

Files changed (1)
  modeling_mpt.py +43 -11
modeling_mpt.py CHANGED
@@ -12,17 +12,23 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokeniz
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from .attention import attn_bias_shape, build_attn_bias
 from .blocks import MPTBlock
+from .custom_embedding import SharedEmbedding
 from .norm import NORM_CLASS_REGISTRY
 from .configuration_mpt import MPTConfig
 from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
 from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
 from .meta_init_context import init_empty_weights
 from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
+try:
+    from .flash_attn_triton import flash_attn_func
+except:
+    pass
 Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

 class MPTPreTrainedModel(PreTrainedModel):
     config_class = MPTConfig
     base_model_prefix = 'model'
+    _no_split_modules = ['MPTBlock']

 class MPTModel(MPTPreTrainedModel):

@@ -34,14 +40,19 @@ class MPTModel(MPTPreTrainedModel):
         self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
         self.alibi = config.attn_config['alibi']
         self.alibi_bias_max = config.attn_config['alibi_bias_max']
+        if config.init_device == 'mixed':
+            if dist.get_local_rank() == 0:
+                config.init_device = 'cpu'
+            else:
+                config.init_device = 'meta'
         if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
             norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
             raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
         norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
         self.embedding_fraction = config.embedding_fraction
-        self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device)
+        self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
         if not self.alibi:
-            self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
+            self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
         self.emb_drop = nn.Dropout(config.emb_pdrop)
         self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
         self.norm_f = norm_class(config.d_model, device=config.init_device)
@@ -96,7 +107,8 @@ class MPTModel(MPTPreTrainedModel):
             if attn_bias is None:
                 attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
             else:
-                attn_bias = attn_bias[:, :, :, -s_k:]
+                _s_k = max(0, attn_bias.size(-1) - s_k)
+                attn_bias = attn_bias[:, :, :, _s_k:]
             if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
                 raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
             min_val = torch.finfo(attn_bias.dtype).min
@@ -128,7 +140,7 @@ class MPTModel(MPTPreTrainedModel):
         attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
         return attn_bias

-    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
+    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         if attention_mask is not None:
@@ -138,11 +150,14 @@ class MPTModel(MPTPreTrainedModel):
         if not return_dict:
             raise NotImplementedError('return_dict False is not implemented yet for MPT')
         if output_attentions:
-            raise NotImplementedError('output_attentions is not implemented yet for MPT')
+            if self.attn_impl != 'torch':
+                raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
         if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
             raise NotImplementedError('MPT does not support training with left padding.')
         if self.prefix_lm and prefix_mask is None:
             raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
+        if inputs_embeds is not None:
+            raise NotImplementedError('inputs_embeds is not implemented for MPT.')
         if self.training:
             if self.attn_uses_sequence_id and sequence_id is None:
                 raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
@@ -159,6 +174,8 @@ class MPTModel(MPTPreTrainedModel):
             if len(past_key_values) != self.config.n_layers:
                 raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
             past_position = past_key_values[0][0].size(1)
+            if self.attn_impl == 'torch':
+                past_position = past_key_values[0][0].size(3)
             if S + past_position > self.config.max_seq_len:
                 raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
             pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
@@ -172,20 +189,27 @@ class MPTModel(MPTPreTrainedModel):
             x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
             assert isinstance(self.emb_drop, nn.Module)
             x = self.emb_drop(x_shrunk)
-        (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
+        (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
         if use_cache and past_key_values is None:
             past_key_values = [() for _ in range(self.config.n_layers)]
         all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
         for (b_idx, block) in enumerate(self.blocks):
             if output_hidden_states:
                 assert all_hidden_states is not None
                 all_hidden_states = all_hidden_states + (x,)
             past_key_value = past_key_values[b_idx] if past_key_values is not None else None
-            (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
+            (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
             if past_key_values is not None:
                 past_key_values[b_idx] = past_key_value
+            if output_attentions:
+                assert all_self_attns is not None
+                all_self_attns = all_self_attns + (attn_weights,)
         x = self.norm_f(x)
-        return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
+        if output_hidden_states:
+            assert all_hidden_states is not None
+            all_hidden_states = all_hidden_states + (x,)
+        return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns)

     def param_init_fn(self, module):
         init_fn_name = self.config.init_config['name']
@@ -203,7 +227,13 @@ class MPTForCausalLM(MPTPreTrainedModel):
         super().__init__(config)
         if not config.tie_word_embeddings:
             raise ValueError('MPTForCausalLM only supports tied word embeddings')
+        print(f'Instantiating an MPTForCausalLM model from {__file__}')
         self.transformer = MPTModel(config)
+        for child in self.transformer.children():
+            if isinstance(child, torch.nn.ModuleList):
+                continue
+            if isinstance(child, torch.nn.Module):
+                child._fsdp_wrap = True
         self.logit_scale = None
         if config.logit_scale is not None:
             logit_scale = config.logit_scale
@@ -232,11 +262,13 @@ class MPTForCausalLM(MPTPreTrainedModel):
     def get_decoder(self):
         return self.transformer

-    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
+    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
+        if inputs_embeds is not None:
+            raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')
         outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
-        logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
+        logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
         if self.logit_scale is not None:
             if self.logit_scale == 0:
                 warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
@@ -246,7 +278,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
             labels = torch.roll(labels, shifts=-1)
             labels[:, -1] = -100
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
-        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
+        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

     def param_init_fn(self, module):
         init_fn_name = self.config.init_config['name']