Text Generation · Transformers · PyTorch · mpt · Composer · MosaicML · llm-foundry · custom_code · text-generation-inference

sam-mosaic committed commit e913229 (1 parent: 996ffc5)

Upload folder using huggingface_hub

attention.py CHANGED
@@ -46,7 +46,7 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_
         attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
     if is_causal and (not q.size(2) == 1):
         s = max(s_q, s_k)
-        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
+        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
         causal_mask = causal_mask.tril()
         causal_mask = causal_mask.to(torch.bool)
         causal_mask = ~causal_mask
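Note on this hunk: the mask is cast to bool before use, so switching the scratch buffer from float16 to float32 does not change the resulting mask; it plausibly just avoids an fp16 intermediate (for example on backends where half-precision tril is unsupported). Below is a minimal, self-contained sketch of the construction with illustrative shapes; the final slice and masked_fill are an assumption about how the mask is applied later in the function, included only to make the sketch runnable end to end.

import torch

# Illustrative sizes; in the real function these come from the query/key tensors.
s_q, s_k = 4, 6
attn_weight = torch.zeros(1, 1, s_q, s_k)                       # stand-in for attention scores
min_val = torch.finfo(attn_weight.dtype).min

s = max(s_q, s_k)
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)   # scratch buffer (dtype changed in this commit)
causal_mask = causal_mask.tril()                                # ones on and below the diagonal
causal_mask = causal_mask.to(torch.bool)
causal_mask = ~causal_mask                                      # True marks positions to block
# Assumed application step (not shown in the hunk): align to (s_q, s_k) and mask the scores.
causal_mask = causal_mask[-s_q:, -s_k:]
attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
print(attn_weight[0, 0])                                        # upper-triangle entries filled with min_val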
config.json CHANGED
@@ -27,9 +27,9 @@
     "emb_init_uniform_lim": null,
     "fan_mode": "fan_in",
     "init_div_is_residual": true,
-    "init_gain": 0,
+    "init_gain": 0.0,
     "init_nonlinearity": "relu",
-    "init_std": 0.02,
+    "init_std": null,
     "name": "kaiming_normal_",
     "verbose": 0
   },
@@ -45,7 +45,7 @@
   "resid_pdrop": 0,
   "tokenizer_name": "EleutherAI/gpt-neox-20b",
   "torch_dtype": "bfloat16",
-  "transformers_version": "4.28.1",
+  "transformers_version": "4.30.2",
   "use_cache": false,
   "verbose": 0,
   "vocab_size": 50432
generation_config.json CHANGED
@@ -1,5 +1,5 @@
 {
   "_from_model_config": true,
-  "transformers_version": "4.28.1",
+  "transformers_version": "4.30.2",
   "use_cache": false
 }
modeling_mpt.py CHANGED
@@ -18,7 +18,7 @@ from .configuration_mpt import MPTConfig
 from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
 from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
 from .meta_init_context import init_empty_weights
-from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
+from .param_init_fns import generic_param_init_fn_, MODEL_INIT_REGISTRY
 try:
     from .flash_attn_triton import flash_attn_func
 except:
@@ -80,7 +80,7 @@ class MPTModel(MPTPreTrainedModel):
     def get_input_embeddings(self):
         return self.wte
 
-    def set_input_embeddings(self, value):
+    def set_input_embeddings(self, value: nn.Embedding):
         self.wte = value
 
     @torch.no_grad()
@@ -140,7 +140,7 @@ class MPTModel(MPTPreTrainedModel):
         attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
         return attn_bias
 
-    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
+    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         if attention_mask is not None:
@@ -156,6 +156,8 @@ class MPTModel(MPTPreTrainedModel):
             raise NotImplementedError('MPT does not support training with left padding.')
         if self.prefix_lm and prefix_mask is None:
             raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
+        if inputs_embeds is not None:
+            raise NotImplementedError('inputs_embeds is not implemented for MPT.')
         if self.training:
             if self.attn_uses_sequence_id and sequence_id is None:
                 raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
@@ -225,7 +227,8 @@ class MPTForCausalLM(MPTPreTrainedModel):
         super().__init__(config)
         if not config.tie_word_embeddings:
             raise ValueError('MPTForCausalLM only supports tied word embeddings')
-        self.transformer = MPTModel(config)
+        print(f'Instantiating an MPTForCausalLM model from {__file__}')
+        self.transformer: MPTModel = MPTModel(config)
         for child in self.transformer.children():
             if isinstance(child, torch.nn.ModuleList):
                 continue
@@ -259,9 +262,11 @@ class MPTForCausalLM(MPTPreTrainedModel):
     def get_decoder(self):
         return self.transformer
 
-    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
+    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
+        if inputs_embeds is not None:
+            raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')
         outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
         logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
         if self.logit_scale is not None:
@@ -270,9 +275,9 @@ class MPTForCausalLM(MPTPreTrainedModel):
             logits *= self.logit_scale
         loss = None
         if labels is not None:
-            labels = torch.roll(labels, shifts=-1)
-            labels[:, -1] = -100
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
+            _labels = torch.roll(labels, shifts=-1)
+            _labels[:, -1] = -100
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), _labels.to(logits.device).view(-1))
         return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
 
     def param_init_fn(self, module):
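Notes on these hunks: both forward signatures gain an inputs_embeds argument that is explicitly rejected, so the method signature matches what Hugging Face/PEFT wrappers pass even though embedding inputs remain unsupported. In the loss hunk, the shifted targets now go into a fresh _labels tensor instead of mutating the caller's labels in place, and the wrapped-around last position is set to -100, which F.cross_entropy ignores by default. A minimal sketch of that loss computation with illustrative sizes (the real logits come from the tied wte head):

import torch
import torch.nn.functional as F

batch, seq, vocab = 2, 5, 11                         # illustrative sizes
logits = torch.randn(batch, seq, vocab)              # stand-in for the model's output logits
labels = torch.randint(0, vocab, (batch, seq))

_labels = torch.roll(labels, shifts=-1)              # position t is supervised by token t+1
_labels[:, -1] = -100                                # wrapped-around position is ignored
loss = F.cross_entropy(logits.view(-1, vocab), _labels.view(-1))   # ignore_index defaults to -100
print(loss.item())
print(labels[:, -1])                                 # caller's labels are left untouched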
norm.py CHANGED
@@ -1,3 +1,4 @@
+from typing import Dict, Type
 import torch
 
 def _cast_if_autocast_enabled(tensor):
@@ -25,7 +26,7 @@ class LPLayerNorm(torch.nn.LayerNorm):
             return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
 
 def rms_norm(x, weight=None, eps=1e-05):
-    output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+    output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
     if weight is not None:
         return output * weight
     return output
@@ -53,4 +54,4 @@ class LPRMSNorm(RMSNorm):
         downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
         with torch.autocast(enabled=False, device_type=x.device.type):
             return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
-NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
+NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
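On the rms_norm hunk: x * torch.rsqrt(mean(x^2) + eps) divides by the root-mean-square, whereas the old x / torch.rsqrt(...) multiplied by it. A small self-contained check of the corrected formula against an explicit reference:

import torch

def rms_norm(x, weight=None, eps=1e-05):
    # Corrected formula from the hunk: multiply by the reciprocal RMS.
    output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    if weight is not None:
        return output * weight
    return output

x = torch.randn(2, 8)
eps = 1e-05
reference = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
print(torch.allclose(rms_norm(x, eps=eps), reference))   # True: x * rsqrt(.) == x / sqrt(.)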
pytorch_model-00001-of-00002.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4e96cf543cf2dbb5579abe2ca1f69e75ed159ff5d3cbad4b5fd406617d80ef44
+oid sha256:6003cd1c33b5a661320c11225b54fb0cdfd931f73241ed810c57dc9e32163146
 size 9943040275
pytorch_model-00002-of-00002.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3f5fb9462a15a43819a0e2dd63faef50cea728f78d3de37721bcd2efe0d43439
+oid sha256:234b5d739ed88a00dcf1e28932158157418d386837d2345f0ec8a0b218e7d823
 size 3355599187