oldjar07 commited on
Commit
233dd28
·
verified ·
1 Parent(s): ccbe911

Upload altered non lfs files from Github branch

Browse files
Files changed (7) hide show
  1. attention.py +12 -12
  2. blocks.py +5 -5
  3. config.json +1 -1
  4. configuration_mpt.py +7 -7
  5. ffn.py +6 -6
  6. modeling_mpt.py +8 -8
  7. param_init_fns.py +20 -20
attention.py CHANGED
@@ -188,15 +188,15 @@ class GroupedQueryAttention(nn.Module):
188
  implementation enables user to also use additive bias.
189
  """
190
 
191
- def __init__(self, d_model: int, n_heads: int, kv_n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None):
192
  super().__init__()
193
  self.attn_impl = attn_impl
194
  self.clip_qkv = clip_qkv
195
  self.qk_ln = qk_ln
196
- self.d_model = d_model
197
  self.n_heads = n_heads
198
  self.kv_n_heads = kv_n_heads
199
- self.head_dim = d_model // n_heads
200
  if self.kv_n_heads <= 0:
201
  raise ValueError('kv_n_heads should be greater than zero.')
202
  if self.kv_n_heads > self.n_heads:
@@ -205,17 +205,17 @@ class GroupedQueryAttention(nn.Module):
205
  raise ValueError('Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads.')
206
  self.softmax_scale = softmax_scale
207
  if self.softmax_scale is None:
208
- self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
209
  self.attn_dropout_p = attn_pdrop
210
  fc_kwargs = {}
211
  if fc_type != 'te':
212
  fc_kwargs['device'] = device
213
- self.Wqkv = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model + 2 * self.kv_n_heads * self.head_dim, **fc_kwargs)
214
  fuse_splits = [i * self.head_dim for i in range(1, self.n_heads + 2 * self.kv_n_heads)]
215
  self.Wqkv._fused = (0, fuse_splits)
216
  if self.qk_ln:
217
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
218
- self.q_ln = norm_class(self.d_model, device=device)
219
  self.k_ln = norm_class(self.kv_n_heads * self.head_dim, device=device)
220
  if self.attn_impl == 'flash':
221
  self.attn_fn = flash_attn_fn
@@ -225,14 +225,14 @@ class GroupedQueryAttention(nn.Module):
225
  self.attn_fn = scaled_multihead_dot_product_attention
226
  else:
227
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
228
- self.out_proj = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model, **fc_kwargs)
229
  self.out_proj._is_residual = True
230
 
231
  def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, is_causal: bool=True, needs_weights: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
232
  qkv = self.Wqkv(x)
233
  if self.clip_qkv:
234
  qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
235
- (query, key, value) = qkv.split([self.d_model, self.kv_n_heads * self.head_dim, self.kv_n_heads * self.head_dim], dim=2)
236
  key_padding_mask = attention_mask
237
  if self.qk_ln:
238
  dtype = query.dtype
@@ -248,8 +248,8 @@ class MultiheadAttention(GroupedQueryAttention):
248
  additive bias.
249
  """
250
 
251
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None):
252
- super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=n_heads, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device)
253
 
254
  class MultiQueryAttention(GroupedQueryAttention):
255
  """Multi-Query self attention.
@@ -258,8 +258,8 @@ class MultiQueryAttention(GroupedQueryAttention):
258
  additive bias.
259
  """
260
 
261
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None):
262
- super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=1, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device)
263
 
264
  def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, prefix_lm: bool, causal: bool, use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]:
265
  if attn_impl == 'flash':
 
188
  implementation enables user to also use additive bias.
189
  """
190
 
191
+ def __init__(self, hidden_size: int, n_heads: int, kv_n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None):
192
  super().__init__()
193
  self.attn_impl = attn_impl
194
  self.clip_qkv = clip_qkv
195
  self.qk_ln = qk_ln
196
+ self.hidden_size = hidden_size
197
  self.n_heads = n_heads
198
  self.kv_n_heads = kv_n_heads
199
+ self.head_dim = hidden_size // n_heads
200
  if self.kv_n_heads <= 0:
201
  raise ValueError('kv_n_heads should be greater than zero.')
202
  if self.kv_n_heads > self.n_heads:
 
205
  raise ValueError('Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads.')
206
  self.softmax_scale = softmax_scale
207
  if self.softmax_scale is None:
208
+ self.softmax_scale = 1 / math.sqrt(self.hidden_size / self.n_heads)
209
  self.attn_dropout_p = attn_pdrop
210
  fc_kwargs = {}
211
  if fc_type != 'te':
212
  fc_kwargs['device'] = device
213
+ self.Wqkv = FC_CLASS_REGISTRY[fc_type](self.hidden_size, self.hidden_size + 2 * self.kv_n_heads * self.head_dim, **fc_kwargs)
214
  fuse_splits = [i * self.head_dim for i in range(1, self.n_heads + 2 * self.kv_n_heads)]
215
  self.Wqkv._fused = (0, fuse_splits)
216
  if self.qk_ln:
217
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
218
+ self.q_ln = norm_class(self.hidden_size, device=device)
219
  self.k_ln = norm_class(self.kv_n_heads * self.head_dim, device=device)
220
  if self.attn_impl == 'flash':
221
  self.attn_fn = flash_attn_fn
 
225
  self.attn_fn = scaled_multihead_dot_product_attention
226
  else:
227
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
228
+ self.out_proj = FC_CLASS_REGISTRY[fc_type](self.hidden_size, self.hidden_size, **fc_kwargs)
229
  self.out_proj._is_residual = True
230
 
231
  def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, is_causal: bool=True, needs_weights: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
232
  qkv = self.Wqkv(x)
233
  if self.clip_qkv:
234
  qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
235
+ (query, key, value) = qkv.split([self.hidden_size, self.kv_n_heads * self.head_dim, self.kv_n_heads * self.head_dim], dim=2)
236
  key_padding_mask = attention_mask
237
  if self.qk_ln:
238
  dtype = query.dtype
 
248
  additive bias.
249
  """
250
 
251
+ def __init__(self, hidden_size: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None):
252
+ super().__init__(hidden_size=hidden_size, n_heads=n_heads, kv_n_heads=n_heads, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device)
253
 
254
  class MultiQueryAttention(GroupedQueryAttention):
255
  """Multi-Query self attention.
 
258
  additive bias.
259
  """
260
 
261
+ def __init__(self, hidden_size: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None):
262
+ super().__init__(hidden_size=hidden_size, n_heads=n_heads, kv_n_heads=1, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device)
263
 
264
  def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, prefix_lm: bool, causal: bool, use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]:
265
  if attn_impl == 'flash':
blocks.py CHANGED
@@ -8,7 +8,7 @@ from .norm import NORM_CLASS_REGISTRY
8
 
9
  class MPTBlock(nn.Module):
10
 
11
- def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Optional[Dict]=None, ffn_config: Optional[Dict]=None, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, **kwargs: Any):
12
  if attn_config is None:
13
  attn_config = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
14
  if ffn_config is None:
@@ -20,12 +20,12 @@ class MPTBlock(nn.Module):
20
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
21
  args_to_exclude_in_attn_class = {'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max'}
22
  attn_config_subset_for_attn_class = {k: v for (k, v) in attn_config.items() if k not in args_to_exclude_in_attn_class}
23
- self.norm_1 = norm_class(d_model, device=device)
24
- self.attn = attn_class(d_model=d_model, n_heads=n_heads, fc_type=fc_type, device=device, **attn_config_subset_for_attn_class)
25
  self.norm_2 = None
26
  if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', False):
27
- self.norm_2 = norm_class(d_model, device=device)
28
- self.ffn = build_ffn(d_model=d_model, expansion_ratio=expansion_ratio, device=device, **ffn_config)
29
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
30
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
31
 
 
8
 
9
  class MPTBlock(nn.Module):
10
 
11
+ def __init__(self, hidden_size: int, n_heads: int, expansion_ratio: int, attn_config: Optional[Dict]=None, ffn_config: Optional[Dict]=None, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, **kwargs: Any):
12
  if attn_config is None:
13
  attn_config = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
14
  if ffn_config is None:
 
20
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
21
  args_to_exclude_in_attn_class = {'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max'}
22
  attn_config_subset_for_attn_class = {k: v for (k, v) in attn_config.items() if k not in args_to_exclude_in_attn_class}
23
+ self.norm_1 = norm_class(hidden_size, device=device)
24
+ self.attn = attn_class(hidden_size=hidden_size, n_heads=n_heads, fc_type=fc_type, device=device, **attn_config_subset_for_attn_class)
25
  self.norm_2 = None
26
  if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', False):
27
+ self.norm_2 = norm_class(hidden_size, device=device)
28
+ self.ffn = build_ffn(hidden_size=hidden_size, expansion_ratio=expansion_ratio, device=device, **ffn_config)
29
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
30
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
31
 
config.json CHANGED
@@ -19,7 +19,7 @@
19
  "AutoConfig": "configuration_mpt.MPTConfig",
20
  "AutoModelForCausalLM": "modeling_mpt.MPTForCausalLM"
21
  },
22
- "d_model": 3072,
23
  "emb_pdrop": 0.0,
24
  "embedding_fraction": 1.0,
25
  "expansion_ratio": 4,
 
19
  "AutoConfig": "configuration_mpt.MPTConfig",
20
  "AutoModelForCausalLM": "modeling_mpt.MPTForCausalLM"
21
  },
22
+ "hidden_size": 3072,
23
  "emb_pdrop": 0.0,
24
  "embedding_fraction": 1.0,
25
  "expansion_ratio": 4,
configuration_mpt.py CHANGED
@@ -9,11 +9,11 @@ init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', '
9
  class MPTConfig(PretrainedConfig):
10
  model_type = 'mpt'
11
 
12
- def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, ffn_config: Dict=ffn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, fc_type: str='torch', verbose: Optional[int]=None, **kwargs: Any):
13
  """The MPT configuration class.
14
 
15
  Args:
16
- d_model (int): The size of the embedding dimension of the model.
17
  n_heads (int): The number of attention heads.
18
  n_layers (int): The number of layers in the model.
19
  expansion_ratio (int): The ratio of the up/down scale in the ffn.
@@ -67,7 +67,7 @@ class MPTConfig(PretrainedConfig):
67
  See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
68
  fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
69
  """
70
- self.d_model = d_model
71
  self.n_heads = n_heads
72
  self.n_layers = n_layers
73
  self.expansion_ratio = expansion_ratio
@@ -108,8 +108,8 @@ class MPTConfig(PretrainedConfig):
108
  self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
109
  self.ffn_config = self._set_config_defaults(self.ffn_config, ffn_config_defaults)
110
  self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
111
- if self.d_model % self.n_heads != 0:
112
- raise ValueError('d_model must be divisible by n_heads')
113
  if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
114
  raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
115
  if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
@@ -122,8 +122,8 @@ class MPTConfig(PretrainedConfig):
122
  raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.')
123
  if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
124
  raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')
125
- if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model':
126
- raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
127
  if self.init_config.get('name', None) is None:
128
  raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
129
  if not self.learned_pos_emb and (not self.attn_config['alibi']):
 
9
  class MPTConfig(PretrainedConfig):
10
  model_type = 'mpt'
11
 
12
+ def __init__(self, hidden_size: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, ffn_config: Dict=ffn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, fc_type: str='torch', verbose: Optional[int]=None, **kwargs: Any):
13
  """The MPT configuration class.
14
 
15
  Args:
16
+ hidden_size (int): The size of the embedding dimension of the model.
17
  n_heads (int): The number of attention heads.
18
  n_layers (int): The number of layers in the model.
19
  expansion_ratio (int): The ratio of the up/down scale in the ffn.
 
67
  See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
68
  fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
69
  """
70
+ self.hidden_size = hidden_size
71
  self.n_heads = n_heads
72
  self.n_layers = n_layers
73
  self.expansion_ratio = expansion_ratio
 
108
  self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
109
  self.ffn_config = self._set_config_defaults(self.ffn_config, ffn_config_defaults)
110
  self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
111
+ if self.hidden_size % self.n_heads != 0:
112
+ raise ValueError('hidden_size must be divisible by n_heads')
113
  if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
114
  raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
115
  if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
 
122
  raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.')
123
  if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
124
  raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')
125
+ if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_hidden_size':
126
+ raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_hidden_size'.")
127
  if self.init_config.get('name', None) is None:
128
  raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
129
  if not self.learned_pos_emb and (not self.attn_config['alibi']):
ffn.py CHANGED
@@ -10,14 +10,14 @@ except:
10
 
11
  class MPTMLP(nn.Module):
12
 
13
- def __init__(self, d_model: int, expansion_ratio: int, fc_type: str='torch', device: Optional[str]=None):
14
  super().__init__()
15
  fc_kwargs = {}
16
  if fc_type != 'te':
17
  fc_kwargs['device'] = device
18
- self.up_proj = FC_CLASS_REGISTRY[fc_type](d_model, expansion_ratio * d_model, **fc_kwargs)
19
  self.act = nn.GELU(approximate='none')
20
- self.down_proj = FC_CLASS_REGISTRY[fc_type](expansion_ratio * d_model, d_model, **fc_kwargs)
21
  self.down_proj._is_residual = True
22
 
23
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -27,13 +27,13 @@ if te is not None:
27
  te.LayerNormMLP._has_norm = True
28
  FFN_CLASS_REGISTRY['te_ln_mlp'] = te.LayerNormMLP
29
 
30
- def build_ffn(d_model: int, expansion_ratio: int, fc_type: str='torch', device: Optional[str]=None, **kwargs: Any) -> nn.Module:
31
  ffn_type = kwargs.pop('ffn_type')
32
  if ffn_type == 'mptmlp':
33
  if len(kwargs) > 0:
34
  raise ValueError(f'MPTMLP got an unexpected keyword argument: {kwargs}')
35
- return MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, device=device)
36
  elif ffn_type == 'te_ln_mlp':
37
  assert te is not None
38
- return te.LayerNormMLP(hidden_size=d_model, ffn_hidden_size=d_model * expansion_ratio, **kwargs)
39
  raise ValueError(f'ffn_type={ffn_type!r} not recognized.')
 
10
 
11
  class MPTMLP(nn.Module):
12
 
13
+ def __init__(self, hidden_size: int, expansion_ratio: int, fc_type: str='torch', device: Optional[str]=None):
14
  super().__init__()
15
  fc_kwargs = {}
16
  if fc_type != 'te':
17
  fc_kwargs['device'] = device
18
+ self.up_proj = FC_CLASS_REGISTRY[fc_type](hidden_size, expansion_ratio * hidden_size, **fc_kwargs)
19
  self.act = nn.GELU(approximate='none')
20
+ self.down_proj = FC_CLASS_REGISTRY[fc_type](expansion_ratio * hidden_size, hidden_size, **fc_kwargs)
21
  self.down_proj._is_residual = True
22
 
23
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
27
  te.LayerNormMLP._has_norm = True
28
  FFN_CLASS_REGISTRY['te_ln_mlp'] = te.LayerNormMLP
29
 
30
+ def build_ffn(hidden_size: int, expansion_ratio: int, fc_type: str='torch', device: Optional[str]=None, **kwargs: Any) -> nn.Module:
31
  ffn_type = kwargs.pop('ffn_type')
32
  if ffn_type == 'mptmlp':
33
  if len(kwargs) > 0:
34
  raise ValueError(f'MPTMLP got an unexpected keyword argument: {kwargs}')
35
+ return MPTMLP(hidden_size=hidden_size, expansion_ratio=expansion_ratio, fc_type=fc_type, device=device)
36
  elif ffn_type == 'te_ln_mlp':
37
  assert te is not None
38
+ return te.LayerNormMLP(hidden_size=hidden_size, ffn_hidden_size=hidden_size * expansion_ratio, **kwargs)
39
  raise ValueError(f'ffn_type={ffn_type!r} not recognized.')
modeling_mpt.py CHANGED
@@ -56,12 +56,12 @@ class MPTModel(MPTPreTrainedModel):
56
  raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
57
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
58
  self.embedding_fraction = config.embedding_fraction
59
- self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
60
  if self.learned_pos_emb:
61
- self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
62
  self.emb_drop = nn.Dropout(config.emb_pdrop)
63
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
64
- self.norm_f = norm_class(config.d_model, device=config.init_device)
65
  if config.init_device != 'meta':
66
  log.info(f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.')
67
  self.apply(self.param_init_fn)
@@ -213,7 +213,7 @@ class MPTModel(MPTPreTrainedModel):
213
 
214
  def param_init_fn(self, module: nn.Module) -> None:
215
  init_fn_name = self.config.init_config['name']
216
- MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
217
 
218
  def fsdp_wrap_fn(self, module: nn.Module) -> bool:
219
  return isinstance(module, MPTBlock)
@@ -238,10 +238,10 @@ class MPTForCausalLM(MPTPreTrainedModel):
238
  if config.logit_scale is not None:
239
  logit_scale = config.logit_scale
240
  if isinstance(logit_scale, str):
241
- if logit_scale == 'inv_sqrt_d_model':
242
- logit_scale = 1 / math.sqrt(config.d_model)
243
  else:
244
- raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
245
  self.logit_scale = logit_scale
246
 
247
  def get_input_embeddings(self) -> nn.Embedding:
@@ -282,7 +282,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
282
 
283
  def param_init_fn(self, module: nn.Module) -> None:
284
  init_fn_name = self.config.init_config['name']
285
- MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
286
 
287
  def fsdp_wrap_fn(self, module: nn.Module) -> bool:
288
  return isinstance(module, MPTBlock)
 
56
  raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
57
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
58
  self.embedding_fraction = config.embedding_fraction
59
+ self.wte = SharedEmbedding(config.vocab_size, config.hidden_size, device=config.init_device)
60
  if self.learned_pos_emb:
61
+ self.wpe = torch.nn.Embedding(config.max_seq_len, config.hidden_size, device=config.init_device)
62
  self.emb_drop = nn.Dropout(config.emb_pdrop)
63
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
64
+ self.norm_f = norm_class(config.hidden_size, device=config.init_device)
65
  if config.init_device != 'meta':
66
  log.info(f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.')
67
  self.apply(self.param_init_fn)
 
213
 
214
  def param_init_fn(self, module: nn.Module) -> None:
215
  init_fn_name = self.config.init_config['name']
216
+ MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, hidden_size=self.config.hidden_size, **self.config.init_config)
217
 
218
  def fsdp_wrap_fn(self, module: nn.Module) -> bool:
219
  return isinstance(module, MPTBlock)
 
238
  if config.logit_scale is not None:
239
  logit_scale = config.logit_scale
240
  if isinstance(logit_scale, str):
241
+ if logit_scale == 'inv_sqrt_hidden_size':
242
+ logit_scale = 1 / math.sqrt(config.hidden_size)
243
  else:
244
+ raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_hidden_size'.")
245
  self.logit_scale = logit_scale
246
 
247
  def get_input_embeddings(self) -> nn.Embedding:
 
282
 
283
  def param_init_fn(self, module: nn.Module) -> None:
284
  init_fn_name = self.config.init_config['name']
285
+ MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, hidden_size=self.config.hidden_size, **self.config.init_config)
286
 
287
  def fsdp_wrap_fn(self, module: nn.Module) -> bool:
288
  return isinstance(module, MPTBlock)
param_init_fns.py CHANGED
@@ -29,7 +29,7 @@ def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None:
29
  slice_indices[dim] = slice(s, e)
30
  init_fn_(module.weight[slice_indices])
31
 
32
- def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
33
  del kwargs
34
  init_div_is_residual = init_div_is_residual
35
  if init_div_is_residual is False:
@@ -85,8 +85,8 @@ def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int,
85
  if module._qkv_same_embed_dim:
86
  assert module.in_proj_weight is not None
87
  assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
88
- assert d_model is not None
89
- _d = d_model
90
  splits = (0, _d, 2 * _d, 3 * _d)
91
  for (s, e) in zip(splits[:-1], splits[1:]):
92
  init_fn_(module.in_proj_weight[s:e])
@@ -130,23 +130,23 @@ def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int,
130
  def _normal_init_(std: float, mean: float=0.0) -> Callable:
131
  return partial(torch.nn.init.normal_, mean=mean, std=std)
132
 
133
- def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
134
  del kwargs
135
  init_fn_ = _normal_init_(std=std)
136
- generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
137
 
138
- def baseline_param_init_fn_(module: nn.Module, init_std: Optional[float], n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
139
  del kwargs
140
  if init_std is None:
141
  raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
142
- _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
143
 
144
- def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
145
  del kwargs
146
- std = math.sqrt(2 / (5 * d_model))
147
- _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
148
 
149
- def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
150
  """From section 2.3.1 of GPT-NeoX-20B:
151
 
152
  An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)
@@ -155,25 +155,25 @@ def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init
155
  """
156
  del kwargs
157
  residual_div = n_layers / math.sqrt(10)
158
- small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
159
 
160
- def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
161
  del kwargs
162
  kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
163
- generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
164
 
165
- def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
166
  del kwargs
167
  kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
168
- generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
169
 
170
- def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
171
  del kwargs
172
  xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
173
- generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
174
 
175
- def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
176
  del kwargs
177
  xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
178
- generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
179
  MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}
 
29
  slice_indices[dim] = slice(s, e)
30
  init_fn_(module.weight[slice_indices])
31
 
32
+ def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int, hidden_size: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
33
  del kwargs
34
  init_div_is_residual = init_div_is_residual
35
  if init_div_is_residual is False:
 
85
  if module._qkv_same_embed_dim:
86
  assert module.in_proj_weight is not None
87
  assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
88
+ assert hidden_size is not None
89
+ _d = hidden_size
90
  splits = (0, _d, 2 * _d, 3 * _d)
91
  for (s, e) in zip(splits[:-1], splits[1:]):
92
  init_fn_(module.in_proj_weight[s:e])
 
130
  def _normal_init_(std: float, mean: float=0.0) -> Callable:
131
  return partial(torch.nn.init.normal_, mean=mean, std=std)
132
 
133
+ def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, hidden_size: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
134
  del kwargs
135
  init_fn_ = _normal_init_(std=std)
136
+ generic_param_init_fn_(module=module, init_fn_=init_fn_, hidden_size=hidden_size, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
137
 
138
+ def baseline_param_init_fn_(module: nn.Module, init_std: Optional[float], n_layers: int, hidden_size: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
139
  del kwargs
140
  if init_std is None:
141
  raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
142
+ _normal_param_init_fn_(module=module, std=init_std, hidden_size=hidden_size, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
143
 
144
+ def small_param_init_fn_(module: nn.Module, n_layers: int, hidden_size: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
145
  del kwargs
146
+ std = math.sqrt(2 / (5 * hidden_size))
147
+ _normal_param_init_fn_(module=module, std=std, hidden_size=hidden_size, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
148
 
149
+ def neox_param_init_fn_(module: nn.Module, n_layers: int, hidden_size: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
150
  """From section 2.3.1 of GPT-NeoX-20B:
151
 
152
  An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)
 
155
  """
156
  del kwargs
157
  residual_div = n_layers / math.sqrt(10)
158
+ small_param_init_fn_(module=module, hidden_size=hidden_size, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
159
 
160
+ def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, hidden_size: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
161
  del kwargs
162
  kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
163
+ generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, hidden_size=hidden_size, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
164
 
165
+ def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, hidden_size: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
166
  del kwargs
167
  kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
168
+ generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, hidden_size=hidden_size, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
169
 
170
+ def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, hidden_size: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
171
  del kwargs
172
  xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
173
+ generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, hidden_size=hidden_size, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
174
 
175
+ def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, hidden_size: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
176
  del kwargs
177
  xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
178
+ generic_param_init_fn_(module=module, init_fn_=xavier_normal_, hidden_size=hidden_size, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
179
  MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}