Upload altered non-LFS files from GitHub branch
Browse files

- attention.py +12 -12
- blocks.py +5 -5
- config.json +1 -1
- configuration_mpt.py +7 -7
- ffn.py +6 -6
- modeling_mpt.py +8 -8
- param_init_fns.py +20 -20
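All seven files make the same change: the configuration field previously named d_model (and the matching 'inv_sqrt_d_model' logit-scale option) is renamed to hidden_size ('inv_sqrt_hidden_size') everywhere it appears. As a rough illustration, a dict-based config written against the old name could be migrated as below; the values are assumptions for the example, not taken from this repo.

# Hypothetical migration sketch: rename the old 'd_model' key to the new 'hidden_size' key.
old_cfg = {'d_model': 3072, 'n_heads': 24, 'n_layers': 32, 'expansion_ratio': 4}
new_cfg = {('hidden_size' if k == 'd_model' else k): v for (k, v) in old_cfg.items()}
assert 'd_model' not in new_cfg and new_cfg['hidden_size'] == 3072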
attention.py
CHANGED
@@ -188,15 +188,15 @@ class GroupedQueryAttention(nn.Module):
     implementation enables user to also use additive bias.
     """
 
-    def __init__(self, d_model: int, n_heads: int, kv_n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None):
+    def __init__(self, hidden_size: int, n_heads: int, kv_n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None):
         super().__init__()
         self.attn_impl = attn_impl
         self.clip_qkv = clip_qkv
         self.qk_ln = qk_ln
-        self.d_model = d_model
+        self.hidden_size = hidden_size
         self.n_heads = n_heads
         self.kv_n_heads = kv_n_heads
-        self.head_dim = d_model // n_heads
+        self.head_dim = hidden_size // n_heads
         if self.kv_n_heads <= 0:
             raise ValueError('kv_n_heads should be greater than zero.')
         if self.kv_n_heads > self.n_heads:
@@ -205,17 +205,17 @@ class GroupedQueryAttention(nn.Module):
             raise ValueError('Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads.')
         self.softmax_scale = softmax_scale
         if self.softmax_scale is None:
-            self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
+            self.softmax_scale = 1 / math.sqrt(self.hidden_size / self.n_heads)
         self.attn_dropout_p = attn_pdrop
         fc_kwargs = {}
         if fc_type != 'te':
             fc_kwargs['device'] = device
-        self.Wqkv = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model + 2 * self.kv_n_heads * self.head_dim, **fc_kwargs)
+        self.Wqkv = FC_CLASS_REGISTRY[fc_type](self.hidden_size, self.hidden_size + 2 * self.kv_n_heads * self.head_dim, **fc_kwargs)
         fuse_splits = [i * self.head_dim for i in range(1, self.n_heads + 2 * self.kv_n_heads)]
         self.Wqkv._fused = (0, fuse_splits)
         if self.qk_ln:
             norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
-            self.q_ln = norm_class(self.d_model, device=device)
+            self.q_ln = norm_class(self.hidden_size, device=device)
             self.k_ln = norm_class(self.kv_n_heads * self.head_dim, device=device)
         if self.attn_impl == 'flash':
             self.attn_fn = flash_attn_fn
@@ -225,14 +225,14 @@ class GroupedQueryAttention(nn.Module):
             self.attn_fn = scaled_multihead_dot_product_attention
         else:
             raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
-        self.out_proj = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model, **fc_kwargs)
+        self.out_proj = FC_CLASS_REGISTRY[fc_type](self.hidden_size, self.hidden_size, **fc_kwargs)
         self.out_proj._is_residual = True
 
     def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, is_causal: bool=True, needs_weights: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
         qkv = self.Wqkv(x)
         if self.clip_qkv:
             qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
-        (query, key, value) = qkv.split([self.d_model, self.kv_n_heads * self.head_dim, self.kv_n_heads * self.head_dim], dim=2)
+        (query, key, value) = qkv.split([self.hidden_size, self.kv_n_heads * self.head_dim, self.kv_n_heads * self.head_dim], dim=2)
         key_padding_mask = attention_mask
         if self.qk_ln:
             dtype = query.dtype
@@ -248,8 +248,8 @@ class MultiheadAttention(GroupedQueryAttention):
     additive bias.
     """
 
-    def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None):
-        super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=n_heads, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device)
+    def __init__(self, hidden_size: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None):
+        super().__init__(hidden_size=hidden_size, n_heads=n_heads, kv_n_heads=n_heads, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device)
 
 class MultiQueryAttention(GroupedQueryAttention):
     """Multi-Query self attention.
@@ -258,8 +258,8 @@ class MultiQueryAttention(GroupedQueryAttention):
     additive bias.
     """
 
-    def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None):
-        super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=1, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device)
+    def __init__(self, hidden_size: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None):
+        super().__init__(hidden_size=hidden_size, n_heads=n_heads, kv_n_heads=1, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device)
 
 def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, prefix_lm: bool, causal: bool, use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]:
     if attn_impl == 'flash':
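For reference, a minimal sketch of the dimension bookkeeping that the renamed hidden_size argument drives in GroupedQueryAttention above; the concrete sizes are illustrative assumptions, not this checkpoint's values.

# head_dim, the fused Wqkv width, and the qkv.split sizes all derive from hidden_size.
hidden_size, n_heads, kv_n_heads = 3072, 24, 8
head_dim = hidden_size // n_heads                    # 128
wqkv_out = hidden_size + 2 * kv_n_heads * head_dim   # 5120, width of the fused Q/K/V projection
q_dim = hidden_size
k_dim = v_dim = kv_n_heads * head_dim
assert q_dim + k_dim + v_dim == wqkv_out             # matches the qkv.split([...], dim=2) call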
blocks.py
CHANGED
@@ -8,7 +8,7 @@ from .norm import NORM_CLASS_REGISTRY
 
 class MPTBlock(nn.Module):
 
-    def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Optional[Dict]=None, ffn_config: Optional[Dict]=None, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, **kwargs: Any):
+    def __init__(self, hidden_size: int, n_heads: int, expansion_ratio: int, attn_config: Optional[Dict]=None, ffn_config: Optional[Dict]=None, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, **kwargs: Any):
         if attn_config is None:
             attn_config = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
         if ffn_config is None:
@@ -20,12 +20,12 @@ class MPTBlock(nn.Module):
         attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
         args_to_exclude_in_attn_class = {'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max'}
         attn_config_subset_for_attn_class = {k: v for (k, v) in attn_config.items() if k not in args_to_exclude_in_attn_class}
-        self.norm_1 = norm_class(d_model, device=device)
-        self.attn = attn_class(d_model=d_model, n_heads=n_heads, fc_type=fc_type, device=device, **attn_config_subset_for_attn_class)
+        self.norm_1 = norm_class(hidden_size, device=device)
+        self.attn = attn_class(hidden_size=hidden_size, n_heads=n_heads, fc_type=fc_type, device=device, **attn_config_subset_for_attn_class)
         self.norm_2 = None
         if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', False):
-            self.norm_2 = norm_class(d_model, device=device)
-        self.ffn = build_ffn(d_model=d_model, expansion_ratio=expansion_ratio, device=device, **ffn_config)
+            self.norm_2 = norm_class(hidden_size, device=device)
+        self.ffn = build_ffn(hidden_size=hidden_size, expansion_ratio=expansion_ratio, device=device, **ffn_config)
         self.resid_attn_dropout = nn.Dropout(resid_pdrop)
         self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
 
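The context lines above also show how MPTBlock filters the attention config before forwarding it together with hidden_size and n_heads. A small standalone sketch of that filtering, using the default attn_config from the diff:

attn_config = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
args_to_exclude_in_attn_class = {'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max'}
# Keys the attention classes do not accept are dropped; the rest are passed alongside
# hidden_size=..., n_heads=..., fc_type=..., device=...
attn_kwargs = {k: v for (k, v) in attn_config.items() if k not in args_to_exclude_in_attn_class}
print(sorted(attn_kwargs))  # ['attn_impl', 'attn_pdrop', 'clip_qkv', 'qk_ln', 'softmax_scale']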
config.json
CHANGED
@@ -19,7 +19,7 @@
     "AutoConfig": "configuration_mpt.MPTConfig",
     "AutoModelForCausalLM": "modeling_mpt.MPTForCausalLM"
   },
-  "d_model": 3072,
+  "hidden_size": 3072,
   "emb_pdrop": 0.0,
   "embedding_fraction": 1.0,
   "expansion_ratio": 4,
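After this rename, downstream code should read hidden_size from the loaded config rather than d_model. A hypothetical usage sketch; the repository id below is a placeholder, not this model's:

from transformers import AutoConfig

# trust_remote_code is needed because auto_map points at the custom configuration_mpt.MPTConfig.
cfg = AutoConfig.from_pretrained('your-org/your-mpt-model', trust_remote_code=True)
print(cfg.hidden_size)  # 3072 for the config.json shown above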
configuration_mpt.py
CHANGED
@@ -9,11 +9,11 @@ init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', '
 class MPTConfig(PretrainedConfig):
     model_type = 'mpt'
 
-    def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, ffn_config: Dict=ffn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, fc_type: str='torch', verbose: Optional[int]=None, **kwargs: Any):
+    def __init__(self, hidden_size: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, ffn_config: Dict=ffn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, fc_type: str='torch', verbose: Optional[int]=None, **kwargs: Any):
         """The MPT configuration class.
 
         Args:
-            d_model (int): The size of the embedding dimension of the model.
+            hidden_size (int): The size of the embedding dimension of the model.
             n_heads (int): The number of attention heads.
             n_layers (int): The number of layers in the model.
             expansion_ratio (int): The ratio of the up/down scale in the ffn.
@@ -67,7 +67,7 @@ class MPTConfig(PretrainedConfig):
                 See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
             fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
         """
-        self.d_model = d_model
+        self.hidden_size = hidden_size
         self.n_heads = n_heads
         self.n_layers = n_layers
         self.expansion_ratio = expansion_ratio
@@ -108,8 +108,8 @@ class MPTConfig(PretrainedConfig):
         self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
         self.ffn_config = self._set_config_defaults(self.ffn_config, ffn_config_defaults)
         self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
-        if self.d_model % self.n_heads != 0:
-            raise ValueError('d_model must be divisible by n_heads')
+        if self.hidden_size % self.n_heads != 0:
+            raise ValueError('hidden_size must be divisible by n_heads')
         if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
             raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
         if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
@@ -122,8 +122,8 @@ class MPTConfig(PretrainedConfig):
             raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.')
         if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
             raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')
-        if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model':
-            raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
+        if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_hidden_size':
+            raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_hidden_size'.")
         if self.init_config.get('name', None) is None:
             raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
         if not self.learned_pos_emb and (not self.attn_config['alibi']):
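A standalone sketch of the two renamed validation checks in MPTConfig above (not the class itself, just the logic they implement):

from typing import Optional, Union

def validate(hidden_size: int, n_heads: int, logit_scale: Optional[Union[float, str]]=None) -> None:
    # Mirrors the renamed checks: divisibility of hidden_size and the new logit_scale string option.
    if hidden_size % n_heads != 0:
        raise ValueError('hidden_size must be divisible by n_heads')
    if isinstance(logit_scale, str) and logit_scale != 'inv_sqrt_hidden_size':
        raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_hidden_size'.")

validate(3072, 24, 'inv_sqrt_hidden_size')  # passes
# validate(3072, 7) would raise, since 3072 % 7 != 0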
ffn.py
CHANGED
@@ -10,14 +10,14 @@ except:
 
 class MPTMLP(nn.Module):
 
-    def __init__(self, d_model: int, expansion_ratio: int, fc_type: str='torch', device: Optional[str]=None):
+    def __init__(self, hidden_size: int, expansion_ratio: int, fc_type: str='torch', device: Optional[str]=None):
         super().__init__()
         fc_kwargs = {}
         if fc_type != 'te':
             fc_kwargs['device'] = device
-        self.up_proj = FC_CLASS_REGISTRY[fc_type](d_model, expansion_ratio * d_model, **fc_kwargs)
+        self.up_proj = FC_CLASS_REGISTRY[fc_type](hidden_size, expansion_ratio * hidden_size, **fc_kwargs)
         self.act = nn.GELU(approximate='none')
-        self.down_proj = FC_CLASS_REGISTRY[fc_type](expansion_ratio * d_model, d_model, **fc_kwargs)
+        self.down_proj = FC_CLASS_REGISTRY[fc_type](expansion_ratio * hidden_size, hidden_size, **fc_kwargs)
         self.down_proj._is_residual = True
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -27,13 +27,13 @@ if te is not None:
     te.LayerNormMLP._has_norm = True
     FFN_CLASS_REGISTRY['te_ln_mlp'] = te.LayerNormMLP
 
-def build_ffn(d_model: int, expansion_ratio: int, fc_type: str='torch', device: Optional[str]=None, **kwargs: Any) -> nn.Module:
+def build_ffn(hidden_size: int, expansion_ratio: int, fc_type: str='torch', device: Optional[str]=None, **kwargs: Any) -> nn.Module:
     ffn_type = kwargs.pop('ffn_type')
     if ffn_type == 'mptmlp':
         if len(kwargs) > 0:
             raise ValueError(f'MPTMLP got an unexpected keyword argument: {kwargs}')
-        return MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, device=device)
+        return MPTMLP(hidden_size=hidden_size, expansion_ratio=expansion_ratio, fc_type=fc_type, device=device)
     elif ffn_type == 'te_ln_mlp':
         assert te is not None
-        return te.LayerNormMLP(hidden_size=d_model, ffn_hidden_size=d_model * expansion_ratio, **kwargs)
+        return te.LayerNormMLP(hidden_size=hidden_size, ffn_hidden_size=hidden_size * expansion_ratio, **kwargs)
     raise ValueError(f'ffn_type={ffn_type!r} not recognized.')
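A minimal standalone sketch of the MPTMLP projection shapes after the rename, with plain nn.Linear standing in for the FC_CLASS_REGISTRY['torch'] layer; the sizes are illustrative assumptions:

import torch
import torch.nn as nn

hidden_size, expansion_ratio = 3072, 4
up_proj = nn.Linear(hidden_size, expansion_ratio * hidden_size)
act = nn.GELU(approximate='none')
down_proj = nn.Linear(expansion_ratio * hidden_size, hidden_size)

x = torch.randn(2, 8, hidden_size)   # (batch, seq, hidden_size)
y = down_proj(act(up_proj(x)))
assert y.shape == x.shape            # the FFN expands by expansion_ratio, then projects back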
modeling_mpt.py
CHANGED
@@ -56,12 +56,12 @@ class MPTModel(MPTPreTrainedModel):
             raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
         norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
         self.embedding_fraction = config.embedding_fraction
-        self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
+        self.wte = SharedEmbedding(config.vocab_size, config.hidden_size, device=config.init_device)
         if self.learned_pos_emb:
-            self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
+            self.wpe = torch.nn.Embedding(config.max_seq_len, config.hidden_size, device=config.init_device)
         self.emb_drop = nn.Dropout(config.emb_pdrop)
         self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
-        self.norm_f = norm_class(config.d_model, device=config.init_device)
+        self.norm_f = norm_class(config.hidden_size, device=config.init_device)
         if config.init_device != 'meta':
             log.info(f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.')
         self.apply(self.param_init_fn)
@@ -213,7 +213,7 @@ class MPTModel(MPTPreTrainedModel):
 
     def param_init_fn(self, module: nn.Module) -> None:
         init_fn_name = self.config.init_config['name']
-        MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
+        MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, hidden_size=self.config.hidden_size, **self.config.init_config)
 
     def fsdp_wrap_fn(self, module: nn.Module) -> bool:
         return isinstance(module, MPTBlock)
@@ -238,10 +238,10 @@ class MPTForCausalLM(MPTPreTrainedModel):
         if config.logit_scale is not None:
             logit_scale = config.logit_scale
             if isinstance(logit_scale, str):
-                if logit_scale == 'inv_sqrt_d_model':
-                    logit_scale = 1 / math.sqrt(config.d_model)
+                if logit_scale == 'inv_sqrt_hidden_size':
+                    logit_scale = 1 / math.sqrt(config.hidden_size)
                 else:
-                    raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
+                    raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_hidden_size'.")
             self.logit_scale = logit_scale
 
     def get_input_embeddings(self) -> nn.Embedding:
@@ -282,7 +282,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
 
     def param_init_fn(self, module: nn.Module) -> None:
         init_fn_name = self.config.init_config['name']
-        MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
+        MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, hidden_size=self.config.hidden_size, **self.config.init_config)
 
     def fsdp_wrap_fn(self, module: nn.Module) -> bool:
         return isinstance(module, MPTBlock)
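A small sketch of the renamed logit-scale handling in MPTForCausalLM above: the string option is now 'inv_sqrt_hidden_size' and resolves to 1 / sqrt(hidden_size). This is just the resolution logic, pulled out as a standalone function:

import math
from typing import Union

def resolve_logit_scale(logit_scale: Union[float, str], hidden_size: int) -> float:
    # The string form selects inverse-sqrt scaling; numeric values pass through unchanged.
    if isinstance(logit_scale, str):
        if logit_scale == 'inv_sqrt_hidden_size':
            return 1 / math.sqrt(hidden_size)
        raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_hidden_size'.")
    return logit_scale

print(resolve_logit_scale('inv_sqrt_hidden_size', 3072))  # ~0.0180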
param_init_fns.py
CHANGED
@@ -29,7 +29,7 @@ def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None:
         slice_indices[dim] = slice(s, e)
         init_fn_(module.weight[slice_indices])
 
-def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
+def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int, hidden_size: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
     del kwargs
     init_div_is_residual = init_div_is_residual
     if init_div_is_residual is False:
@@ -85,8 +85,8 @@ def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int,
         if module._qkv_same_embed_dim:
             assert module.in_proj_weight is not None
             assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
-            assert d_model is not None
-            _d = d_model
+            assert hidden_size is not None
+            _d = hidden_size
             splits = (0, _d, 2 * _d, 3 * _d)
             for (s, e) in zip(splits[:-1], splits[1:]):
                 init_fn_(module.in_proj_weight[s:e])
@@ -130,23 +130,23 @@ def generic_param_init_fn_(module: nn.Module, init_fn_: Callable, n_layers: int,
 def _normal_init_(std: float, mean: float=0.0) -> Callable:
     return partial(torch.nn.init.normal_, mean=mean, std=std)
 
-def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
+def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, hidden_size: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
     del kwargs
     init_fn_ = _normal_init_(std=std)
-    generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
+    generic_param_init_fn_(module=module, init_fn_=init_fn_, hidden_size=hidden_size, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def baseline_param_init_fn_(module: nn.Module, init_std: Optional[float], n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
+def baseline_param_init_fn_(module: nn.Module, init_std: Optional[float], n_layers: int, hidden_size: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
     del kwargs
     if init_std is None:
         raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
-    _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
+    _normal_param_init_fn_(module=module, std=init_std, hidden_size=hidden_size, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
+def small_param_init_fn_(module: nn.Module, n_layers: int, hidden_size: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
     del kwargs
-    std = math.sqrt(2 / (5 * d_model))
-    _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
+    std = math.sqrt(2 / (5 * hidden_size))
+    _normal_param_init_fn_(module=module, std=std, hidden_size=hidden_size, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
+def neox_param_init_fn_(module: nn.Module, n_layers: int, hidden_size: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, **kwargs: Any) -> None:
     """From section 2.3.1 of GPT-NeoX-20B:
 
     An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)
@@ -155,25 +155,25 @@ def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init
     """
     del kwargs
     residual_div = n_layers / math.sqrt(10)
-    small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
+    small_param_init_fn_(module=module, hidden_size=hidden_size, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
+def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, hidden_size: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
     del kwargs
     kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
-    generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
+    generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, hidden_size=hidden_size, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
+def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, hidden_size: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', **kwargs: Any) -> None:
     del kwargs
     kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
-    generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
+    generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, hidden_size=hidden_size, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
+def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, hidden_size: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
     del kwargs
     xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
-    generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
+    generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, hidden_size=hidden_size, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 
-def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
+def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, hidden_size: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, **kwargs: Any) -> None:
     del kwargs
     xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
-    generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
+    generic_param_init_fn_(module=module, init_fn_=xavier_normal_, hidden_size=hidden_size, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim)
 MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}
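The init functions above now take hidden_size instead of d_model; for instance, the 'small init' standard deviation is sqrt(2 / (5 * hidden_size)). A short sketch of that computation applied to a single linear layer (an illustration, not the registry functions themselves):

import math
import torch
import torch.nn as nn

hidden_size = 3072
std = math.sqrt(2 / (5 * hidden_size))    # ~0.0114 for hidden_size=3072

linear = nn.Linear(hidden_size, hidden_size)
torch.nn.init.normal_(linear.weight, mean=0.0, std=std)
print(round(std, 4), round(linear.weight.std().item(), 4))  # sampled std should land close to the target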