pszemraj commited on
Commit
21986ed
1 Parent(s): 0688e28

🎨 format for readability

Browse files

Signed-off-by: peter szemraj <peterszemraj@gmail.com>

Files changed (8) hide show
  1. adapt_tokenizer.py +8 -5
  2. attention.py +287 -70
  3. blocks.py +58 -11
  4. configuration_mpt.py +103 -28
  5. hf_prefixlm_converter.py +440 -102
  6. meta_init_context.py +26 -10
  7. norm.py +67 -17
  8. param_init_fns.py +288 -52
adapt_tokenizer.py CHANGED
@@ -1,8 +1,10 @@
1
  from typing import Union
2
  from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
 
3
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
4
  NUM_SENTINEL_TOKENS: int = 100
5
 
 
6
  def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
7
  """Adds sentinel tokens and padding token (if missing).
8
 
@@ -12,16 +14,17 @@ def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
12
  All added tokens are added as special tokens. No tokens are
13
  added if sentinel tokens and padding token already exist.
14
  """
15
- sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]
16
  tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
17
  if tokenizer.pad_token is None:
18
- tokenizer.add_tokens('<pad>', special_tokens=True)
19
- tokenizer.pad_token = '<pad>'
20
  assert tokenizer.pad_token_id is not None
21
- sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)])
22
  _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
23
  tokenizer.sentinel_token_ids = _sentinel_token_ids
24
 
 
25
  class AutoTokenizerForMOD(AutoTokenizer):
26
  """AutoTokenizer + Adaptation for MOD.
27
 
@@ -38,4 +41,4 @@ class AutoTokenizerForMOD(AutoTokenizer):
38
  """See `AutoTokenizer.from_pretrained` docstring."""
39
  tokenizer = super().from_pretrained(*args, **kwargs)
40
  adapt_tokenizer_for_denoising(tokenizer)
41
- return tokenizer
 
1
  from typing import Union
2
  from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
3
+
4
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
5
  NUM_SENTINEL_TOKENS: int = 100
6
 
7
+
8
  def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
9
  """Adds sentinel tokens and padding token (if missing).
10
 
 
14
  All added tokens are added as special tokens. No tokens are
15
  added if sentinel tokens and padding token already exist.
16
  """
17
+ sentinels_to_add = [f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)]
18
  tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
19
  if tokenizer.pad_token is None:
20
+ tokenizer.add_tokens("<pad>", special_tokens=True)
21
+ tokenizer.pad_token = "<pad>"
22
  assert tokenizer.pad_token_id is not None
23
+ sentinels = "".join([f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)])
24
  _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
25
  tokenizer.sentinel_token_ids = _sentinel_token_ids
26
 
27
+
28
  class AutoTokenizerForMOD(AutoTokenizer):
29
  """AutoTokenizer + Adaptation for MOD.
30
 
 
41
  """See `AutoTokenizer.from_pretrained` docstring."""
42
  tokenizer = super().from_pretrained(*args, **kwargs)
43
  adapt_tokenizer_for_denoising(tokenizer)
44
+ return tokenizer
attention.py CHANGED
@@ -8,18 +8,37 @@ from einops import rearrange
8
  from torch import nn
9
  from .norm import LPLayerNorm
10
 
11
- def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
 
 
 
12
  if original_is_causal and num_query_tokens != num_key_tokens:
13
  if num_query_tokens != 1:
14
- raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.')
 
 
15
  else:
16
  return False
17
  return original_is_causal
18
 
19
- def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
20
- q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
21
- k = rearrange(key, 'b s (h d) -> b h d s', h=1 if multiquery else n_heads)
22
- v = rearrange(value, 'b s (h d) -> b h s d', h=1 if multiquery else n_heads)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  min_val = torch.finfo(q.dtype).min
24
  (b, _, s_q, d) = q.shape
25
  s_k = k.size(-1)
@@ -27,13 +46,27 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s
27
  softmax_scale = 1 / math.sqrt(d)
28
  attn_weight = q.matmul(k) * softmax_scale
29
  if attn_bias is not None:
30
- if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
31
- raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
 
 
 
 
 
 
32
  attn_weight = attn_weight + attn_bias
33
  if key_padding_mask is not None:
34
  if attn_bias is not None:
35
- warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
36
- attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
 
 
 
 
 
 
 
 
37
  if is_causal:
38
  s = max(s_q, s_k)
39
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
@@ -44,74 +77,146 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s
44
  attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
45
  attn_weight = torch.softmax(attn_weight, dim=-1)
46
  if dropout_p:
47
- attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
 
 
48
  out = attn_weight.matmul(v)
49
- out = rearrange(out, 'b h s d -> b s (h d)')
50
  if needs_weights:
51
  return (out, attn_weight)
52
  return (out, None)
53
 
 
54
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
55
  for tensor in tensors:
56
  if tensor.dtype not in valid_dtypes:
57
- raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.')
 
 
58
  if not tensor.is_cuda:
59
- raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
 
 
 
60
 
61
- def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  try:
63
  from flash_attn import bert_padding, flash_attn_interface
64
  except:
65
- raise RuntimeError('Please install flash-attn==1.0.3.post0')
66
  check_valid_inputs(query, key, value)
67
  if attn_bias is not None:
68
- raise NotImplementedError(f'attn_bias not implemented for flash attn.')
69
  (batch_size, seqlen) = query.shape[:2]
70
  if key_padding_mask is None:
71
  key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
72
- query_padding_mask = key_padding_mask[:, -query.size(1):]
73
- (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
74
- query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
75
- (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
76
- key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
 
 
 
 
 
 
77
  (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
78
- value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
 
 
79
  if multiquery:
80
  key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
81
- value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
 
 
82
  dropout_p = dropout_p if training else 0.0
83
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
84
- output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
85
- output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  return (output, None)
87
 
88
- def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  try:
90
  from flash_attn import flash_attn_triton
91
  except:
92
- raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
 
 
93
  check_valid_inputs(query, key, value)
94
  if dropout_p:
95
- raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
96
  if needs_weights:
97
- raise NotImplementedError(f'attn_impl: triton cannot return attn weights.')
98
  if key_padding_mask is not None:
99
- warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
 
 
 
 
 
 
100
  (b_size, s_k) = key_padding_mask.shape[:2]
101
  if attn_bias is None:
102
  attn_bias = query.new_zeros(b_size, 1, 1, s_k)
103
- attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
104
- query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
105
- key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
106
- value = rearrange(value, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
 
 
107
  if multiquery:
108
  key = key.expand(*key.shape[:2], n_heads, key.size(-1))
109
  value = value.expand(*value.shape[:2], n_heads, value.size(-1))
110
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
111
- attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
 
 
112
  output = attn_output.view(*attn_output.shape[:2], -1)
113
  return (output, None)
114
 
 
115
  class MultiheadAttention(nn.Module):
116
  """Multi-head self attention.
117
 
@@ -119,7 +224,18 @@ class MultiheadAttention(nn.Module):
119
  additive bias.
120
  """
121
 
122
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
 
 
 
 
 
 
 
 
 
 
 
123
  super().__init__()
124
  self.attn_impl = attn_impl
125
  self.clip_qkv = clip_qkv
@@ -137,21 +253,38 @@ class MultiheadAttention(nn.Module):
137
  layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
138
  self.q_ln = layernorm_class(self.d_model, device=device)
139
  self.k_ln = layernorm_class(self.d_model, device=device)
140
- if self.attn_impl == 'flash':
141
  self.attn_fn = flash_attn_fn
142
- elif self.attn_impl == 'triton':
143
  self.attn_fn = triton_flash_attn_fn
144
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
145
- elif self.attn_impl == 'torch':
 
 
 
 
 
146
  self.attn_fn = scaled_multihead_dot_product_attention
147
  if torch.cuda.is_available():
148
- warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
 
 
 
 
149
  else:
150
- raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
151
  self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
152
  self.out_proj._is_residual = True
153
 
154
- def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
 
 
 
 
 
 
 
 
155
  qkv = self.Wqkv(x)
156
  if self.clip_qkv:
157
  qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
@@ -167,10 +300,23 @@ class MultiheadAttention(nn.Module):
167
  value = torch.cat([past_key_value[1], value], dim=1)
168
  past_key_value = (key, value)
169
  if attn_bias is not None:
170
- attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
171
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
 
 
 
 
 
 
 
 
 
 
 
 
172
  return (self.out_proj(context), attn_weights, past_key_value)
173
 
 
174
  class MultiQueryAttention(nn.Module):
175
  """Multi-Query self attention.
176
 
@@ -178,7 +324,18 @@ class MultiQueryAttention(nn.Module):
178
  additive bias.
179
  """
180
 
181
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
 
 
 
 
 
 
 
 
 
 
 
182
  super().__init__()
183
  self.attn_impl = attn_impl
184
  self.clip_qkv = clip_qkv
@@ -197,25 +354,44 @@ class MultiQueryAttention(nn.Module):
197
  layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
198
  self.q_ln = layernorm_class(d_model, device=device)
199
  self.k_ln = layernorm_class(self.head_dim, device=device)
200
- if self.attn_impl == 'flash':
201
  self.attn_fn = flash_attn_fn
202
- elif self.attn_impl == 'triton':
203
  self.attn_fn = triton_flash_attn_fn
204
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
205
- elif self.attn_impl == 'torch':
 
 
 
 
 
206
  self.attn_fn = scaled_multihead_dot_product_attention
207
  if torch.cuda.is_available():
208
- warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
 
 
 
 
209
  else:
210
- raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
211
  self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
212
  self.out_proj._is_residual = True
213
 
214
- def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
 
 
 
 
 
 
 
 
215
  qkv = self.Wqkv(x)
216
  if self.clip_qkv:
217
  qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
218
- (query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)
 
 
219
  key_padding_mask = attention_mask
220
  if self.qk_ln:
221
  dtype = query.dtype
@@ -227,14 +403,30 @@ class MultiQueryAttention(nn.Module):
227
  value = torch.cat([past_key_value[1], value], dim=1)
228
  past_key_value = (key, value)
229
  if attn_bias is not None:
230
- attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
231
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  return (self.out_proj(context), attn_weights, past_key_value)
233
 
234
- def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
235
- if attn_impl == 'flash':
 
 
 
236
  return None
237
- elif attn_impl in ['torch', 'triton']:
238
  if alibi:
239
  if (prefix_lm or not causal) or use_sequence_id:
240
  return (1, n_heads, seq_len, seq_len)
@@ -243,18 +435,31 @@ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_s
243
  return (1, 1, seq_len, seq_len)
244
  return None
245
  else:
246
- raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
 
247
 
248
- def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
249
- if attn_impl == 'flash':
 
 
250
  return None
251
- elif attn_impl in ['torch', 'triton']:
252
  if alibi:
253
  (device, dtype) = (attn_bias.device, attn_bias.dtype)
254
- attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
 
 
 
 
 
 
 
 
 
255
  return attn_bias
256
  else:
257
- raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
 
258
 
259
  def gen_slopes(n_heads, alibi_bias_max=8, device=None):
260
  _n_heads = 2 ** math.ceil(math.log2(n_heads))
@@ -265,12 +470,24 @@ def gen_slopes(n_heads, alibi_bias_max=8, device=None):
265
  slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
266
  return slopes.view(1, n_heads, 1, 1)
267
 
268
- def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None):
269
- alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
 
 
 
 
 
270
  if full:
271
- alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)
 
 
272
  alibi_bias = alibi_bias.abs().mul(-1)
273
  slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
274
  alibi_bias = alibi_bias * slopes
275
  return alibi_bias.to(dtype=dtype)
276
- ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
 
 
 
 
 
 
8
  from torch import nn
9
  from .norm import LPLayerNorm
10
 
11
+
12
+ def _reset_is_causal(
13
+ num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
14
+ ):
15
  if original_is_causal and num_query_tokens != num_key_tokens:
16
  if num_query_tokens != 1:
17
+ raise NotImplementedError(
18
+ "MPT does not support query and key with different number of tokens, unless number of query tokens is 1."
19
+ )
20
  else:
21
  return False
22
  return original_is_causal
23
 
24
+
25
+ def scaled_multihead_dot_product_attention(
26
+ query,
27
+ key,
28
+ value,
29
+ n_heads,
30
+ softmax_scale=None,
31
+ attn_bias=None,
32
+ key_padding_mask=None,
33
+ is_causal=False,
34
+ dropout_p=0.0,
35
+ training=False,
36
+ needs_weights=False,
37
+ multiquery=False,
38
+ ):
39
+ q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
40
+ k = rearrange(key, "b s (h d) -> b h d s", h=1 if multiquery else n_heads)
41
+ v = rearrange(value, "b s (h d) -> b h s d", h=1 if multiquery else n_heads)
42
  min_val = torch.finfo(q.dtype).min
43
  (b, _, s_q, d) = q.shape
44
  s_k = k.size(-1)
 
46
  softmax_scale = 1 / math.sqrt(d)
47
  attn_weight = q.matmul(k) * softmax_scale
48
  if attn_bias is not None:
49
+ if (
50
+ attn_bias.size(-1) != 1
51
+ and attn_bias.size(-1) != s_k
52
+ or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q)
53
+ ):
54
+ raise RuntimeError(
55
+ f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}."
56
+ )
57
  attn_weight = attn_weight + attn_bias
58
  if key_padding_mask is not None:
59
  if attn_bias is not None:
60
+ warnings.warn(
61
+ "Propogating key_padding_mask to the attention module "
62
+ + "and applying it within the attention module can cause "
63
+ + "unneccessary computation/memory usage. Consider integrating "
64
+ + "into attn_bias once and passing that to each attention "
65
+ + "module instead."
66
+ )
67
+ attn_weight = attn_weight.masked_fill(
68
+ ~key_padding_mask.view((b, 1, 1, s_k)), min_val
69
+ )
70
  if is_causal:
71
  s = max(s_q, s_k)
72
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
 
77
  attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
78
  attn_weight = torch.softmax(attn_weight, dim=-1)
79
  if dropout_p:
80
+ attn_weight = torch.nn.functional.dropout(
81
+ attn_weight, p=dropout_p, training=training, inplace=True
82
+ )
83
  out = attn_weight.matmul(v)
84
+ out = rearrange(out, "b h s d -> b s (h d)")
85
  if needs_weights:
86
  return (out, attn_weight)
87
  return (out, None)
88
 
89
+
90
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
91
  for tensor in tensors:
92
  if tensor.dtype not in valid_dtypes:
93
+ raise TypeError(
94
+ f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}."
95
+ )
96
  if not tensor.is_cuda:
97
+ raise TypeError(
98
+ f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})."
99
+ )
100
+
101
 
102
+ def flash_attn_fn(
103
+ query,
104
+ key,
105
+ value,
106
+ n_heads,
107
+ softmax_scale=None,
108
+ attn_bias=None,
109
+ key_padding_mask=None,
110
+ is_causal=False,
111
+ dropout_p=0.0,
112
+ training=False,
113
+ needs_weights=False,
114
+ multiquery=False,
115
+ ):
116
  try:
117
  from flash_attn import bert_padding, flash_attn_interface
118
  except:
119
+ raise RuntimeError("Please install flash-attn==1.0.3.post0")
120
  check_valid_inputs(query, key, value)
121
  if attn_bias is not None:
122
+ raise NotImplementedError(f"attn_bias not implemented for flash attn.")
123
  (batch_size, seqlen) = query.shape[:2]
124
  if key_padding_mask is None:
125
  key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
126
+ query_padding_mask = key_padding_mask[:, -query.size(1) :]
127
+ (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
128
+ query, query_padding_mask
129
+ )
130
+ query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
131
+ (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
132
+ key, key_padding_mask
133
+ )
134
+ key_unpad = rearrange(
135
+ key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
136
+ )
137
  (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
138
+ value_unpad = rearrange(
139
+ value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
140
+ )
141
  if multiquery:
142
  key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
143
+ value_unpad = value_unpad.expand(
144
+ value_unpad.size(0), n_heads, value_unpad.size(-1)
145
+ )
146
  dropout_p = dropout_p if training else 0.0
147
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
148
+ output_unpad = flash_attn_interface.flash_attn_unpadded_func(
149
+ query_unpad,
150
+ key_unpad,
151
+ value_unpad,
152
+ cu_seqlens_q,
153
+ cu_seqlens_k,
154
+ max_seqlen_q,
155
+ max_seqlen_k,
156
+ dropout_p,
157
+ softmax_scale=softmax_scale,
158
+ causal=reset_is_causal,
159
+ return_attn_probs=needs_weights,
160
+ )
161
+ output = bert_padding.pad_input(
162
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen
163
+ )
164
  return (output, None)
165
 
166
+
167
+ def triton_flash_attn_fn(
168
+ query,
169
+ key,
170
+ value,
171
+ n_heads,
172
+ softmax_scale=None,
173
+ attn_bias=None,
174
+ key_padding_mask=None,
175
+ is_causal=False,
176
+ dropout_p=0.0,
177
+ training=False,
178
+ needs_weights=False,
179
+ multiquery=False,
180
+ ):
181
  try:
182
  from flash_attn import flash_attn_triton
183
  except:
184
+ raise RuntimeError(
185
+ "Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202"
186
+ )
187
  check_valid_inputs(query, key, value)
188
  if dropout_p:
189
+ raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
190
  if needs_weights:
191
+ raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
192
  if key_padding_mask is not None:
193
+ warnings.warn(
194
+ "Propagating key_padding_mask to the attention module "
195
+ + "and applying it within the attention module can cause "
196
+ + "unnecessary computation/memory usage. Consider integrating "
197
+ + "into attn_bias once and passing that to each attention "
198
+ + "module instead."
199
+ )
200
  (b_size, s_k) = key_padding_mask.shape[:2]
201
  if attn_bias is None:
202
  attn_bias = query.new_zeros(b_size, 1, 1, s_k)
203
+ attn_bias = attn_bias.masked_fill(
204
+ ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min
205
+ )
206
+ query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
207
+ key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
208
+ value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
209
  if multiquery:
210
  key = key.expand(*key.shape[:2], n_heads, key.size(-1))
211
  value = value.expand(*value.shape[:2], n_heads, value.size(-1))
212
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
213
+ attn_output = flash_attn_triton.flash_attn_func(
214
+ query, key, value, attn_bias, reset_is_causal, softmax_scale
215
+ )
216
  output = attn_output.view(*attn_output.shape[:2], -1)
217
  return (output, None)
218
 
219
+
220
  class MultiheadAttention(nn.Module):
221
  """Multi-head self attention.
222
 
 
224
  additive bias.
225
  """
226
 
227
+ def __init__(
228
+ self,
229
+ d_model: int,
230
+ n_heads: int,
231
+ attn_impl: str = "triton",
232
+ clip_qkv: Optional[float] = None,
233
+ qk_ln: bool = False,
234
+ softmax_scale: Optional[float] = None,
235
+ attn_pdrop: float = 0.0,
236
+ low_precision_layernorm: bool = False,
237
+ device: Optional[str] = None,
238
+ ):
239
  super().__init__()
240
  self.attn_impl = attn_impl
241
  self.clip_qkv = clip_qkv
 
253
  layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
254
  self.q_ln = layernorm_class(self.d_model, device=device)
255
  self.k_ln = layernorm_class(self.d_model, device=device)
256
+ if self.attn_impl == "flash":
257
  self.attn_fn = flash_attn_fn
258
+ elif self.attn_impl == "triton":
259
  self.attn_fn = triton_flash_attn_fn
260
+ warnings.warn(
261
+ "While `attn_impl: triton` can be faster than `attn_impl: flash` "
262
+ + "it uses more memory. When training larger models this can trigger "
263
+ + "alloc retries which hurts performance. If encountered, we recommend "
264
+ + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
265
+ )
266
+ elif self.attn_impl == "torch":
267
  self.attn_fn = scaled_multihead_dot_product_attention
268
  if torch.cuda.is_available():
269
+ warnings.warn(
270
+ "Using `attn_impl: torch`. If your model does not use `alibi` or "
271
+ + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
272
+ + "we recommend using `attn_impl: triton`."
273
+ )
274
  else:
275
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
276
  self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
277
  self.out_proj._is_residual = True
278
 
279
+ def forward(
280
+ self,
281
+ x,
282
+ past_key_value=None,
283
+ attn_bias=None,
284
+ attention_mask=None,
285
+ is_causal=True,
286
+ needs_weights=False,
287
+ ):
288
  qkv = self.Wqkv(x)
289
  if self.clip_qkv:
290
  qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
 
300
  value = torch.cat([past_key_value[1], value], dim=1)
301
  past_key_value = (key, value)
302
  if attn_bias is not None:
303
+ attn_bias = attn_bias[:, :, -query.size(1) :, -key.size(1) :]
304
+ (context, attn_weights) = self.attn_fn(
305
+ query,
306
+ key,
307
+ value,
308
+ self.n_heads,
309
+ softmax_scale=self.softmax_scale,
310
+ attn_bias=attn_bias,
311
+ key_padding_mask=key_padding_mask,
312
+ is_causal=is_causal,
313
+ dropout_p=self.attn_dropout_p,
314
+ training=self.training,
315
+ needs_weights=needs_weights,
316
+ )
317
  return (self.out_proj(context), attn_weights, past_key_value)
318
 
319
+
320
  class MultiQueryAttention(nn.Module):
321
  """Multi-Query self attention.
322
 
 
324
  additive bias.
325
  """
326
 
327
+ def __init__(
328
+ self,
329
+ d_model: int,
330
+ n_heads: int,
331
+ attn_impl: str = "triton",
332
+ clip_qkv: Optional[float] = None,
333
+ qk_ln: bool = False,
334
+ softmax_scale: Optional[float] = None,
335
+ attn_pdrop: float = 0.0,
336
+ low_precision_layernorm: bool = False,
337
+ device: Optional[str] = None,
338
+ ):
339
  super().__init__()
340
  self.attn_impl = attn_impl
341
  self.clip_qkv = clip_qkv
 
354
  layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
355
  self.q_ln = layernorm_class(d_model, device=device)
356
  self.k_ln = layernorm_class(self.head_dim, device=device)
357
+ if self.attn_impl == "flash":
358
  self.attn_fn = flash_attn_fn
359
+ elif self.attn_impl == "triton":
360
  self.attn_fn = triton_flash_attn_fn
361
+ warnings.warn(
362
+ "While `attn_impl: triton` can be faster than `attn_impl: flash` "
363
+ + "it uses more memory. When training larger models this can trigger "
364
+ + "alloc retries which hurts performance. If encountered, we recommend "
365
+ + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
366
+ )
367
+ elif self.attn_impl == "torch":
368
  self.attn_fn = scaled_multihead_dot_product_attention
369
  if torch.cuda.is_available():
370
+ warnings.warn(
371
+ "Using `attn_impl: torch`. If your model does not use `alibi` or "
372
+ + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
373
+ + "we recommend using `attn_impl: triton`."
374
+ )
375
  else:
376
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
377
  self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
378
  self.out_proj._is_residual = True
379
 
380
+ def forward(
381
+ self,
382
+ x,
383
+ past_key_value=None,
384
+ attn_bias=None,
385
+ attention_mask=None,
386
+ is_causal=True,
387
+ needs_weights=False,
388
+ ):
389
  qkv = self.Wqkv(x)
390
  if self.clip_qkv:
391
  qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
392
+ (query, key, value) = qkv.split(
393
+ [self.d_model, self.head_dim, self.head_dim], dim=2
394
+ )
395
  key_padding_mask = attention_mask
396
  if self.qk_ln:
397
  dtype = query.dtype
 
403
  value = torch.cat([past_key_value[1], value], dim=1)
404
  past_key_value = (key, value)
405
  if attn_bias is not None:
406
+ attn_bias = attn_bias[:, :, -query.size(1) :, -key.size(1) :]
407
+ (context, attn_weights) = self.attn_fn(
408
+ query,
409
+ key,
410
+ value,
411
+ self.n_heads,
412
+ softmax_scale=self.softmax_scale,
413
+ attn_bias=attn_bias,
414
+ key_padding_mask=key_padding_mask,
415
+ is_causal=is_causal,
416
+ dropout_p=self.attn_dropout_p,
417
+ training=self.training,
418
+ needs_weights=needs_weights,
419
+ multiquery=True,
420
+ )
421
  return (self.out_proj(context), attn_weights, past_key_value)
422
 
423
+
424
+ def attn_bias_shape(
425
+ attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id
426
+ ):
427
+ if attn_impl == "flash":
428
  return None
429
+ elif attn_impl in ["torch", "triton"]:
430
  if alibi:
431
  if (prefix_lm or not causal) or use_sequence_id:
432
  return (1, n_heads, seq_len, seq_len)
 
435
  return (1, 1, seq_len, seq_len)
436
  return None
437
  else:
438
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
439
+
440
 
441
+ def build_attn_bias(
442
+ attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8
443
+ ):
444
+ if attn_impl == "flash":
445
  return None
446
+ elif attn_impl in ["torch", "triton"]:
447
  if alibi:
448
  (device, dtype) = (attn_bias.device, attn_bias.dtype)
449
+ attn_bias = attn_bias.add(
450
+ build_alibi_bias(
451
+ n_heads,
452
+ seq_len,
453
+ full=not causal,
454
+ alibi_bias_max=alibi_bias_max,
455
+ device=device,
456
+ dtype=dtype,
457
+ )
458
+ )
459
  return attn_bias
460
  else:
461
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
462
+
463
 
464
  def gen_slopes(n_heads, alibi_bias_max=8, device=None):
465
  _n_heads = 2 ** math.ceil(math.log2(n_heads))
 
470
  slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
471
  return slopes.view(1, n_heads, 1, 1)
472
 
473
+
474
+ def build_alibi_bias(
475
+ n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None
476
+ ):
477
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(
478
+ 1, 1, 1, seq_len
479
+ )
480
  if full:
481
+ alibi_bias = alibi_bias - torch.arange(
482
+ 1 - seq_len, 1, dtype=torch.int32, device=device
483
+ ).view(1, 1, seq_len, 1)
484
  alibi_bias = alibi_bias.abs().mul(-1)
485
  slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
486
  alibi_bias = alibi_bias * slopes
487
  return alibi_bias.to(dtype=dtype)
488
+
489
+
490
+ ATTN_CLASS_REGISTRY = {
491
+ "multihead_attention": MultiheadAttention,
492
+ "multiquery_attention": MultiQueryAttention,
493
+ }
blocks.py CHANGED
@@ -5,37 +5,84 @@ import torch.nn as nn
5
  from .attention import ATTN_CLASS_REGISTRY
6
  from .norm import NORM_CLASS_REGISTRY
7
 
8
- class MPTMLP(nn.Module):
9
 
10
- def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
 
 
 
11
  super().__init__()
12
  self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
13
- self.act = nn.GELU(approximate='none')
14
  self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
15
  self.down_proj._is_residual = True
16
 
17
  def forward(self, x):
18
  return self.down_proj(self.act(self.up_proj(x)))
19
 
20
- class MPTBlock(nn.Module):
21
 
22
- def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  del kwargs
24
  super().__init__()
25
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26
- attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27
  self.norm_1 = norm_class(d_model, device=device)
28
- self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
 
 
 
 
 
 
 
 
 
29
  self.norm_2 = norm_class(d_model, device=device)
30
- self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
 
 
31
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
32
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
33
 
34
- def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
 
 
 
 
 
 
 
35
  a = self.norm_1(x)
36
- (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
 
 
 
 
 
 
37
  x = x + self.resid_attn_dropout(b)
38
  m = self.norm_2(x)
39
  n = self.ffn(m)
40
  x = x + self.resid_ffn_dropout(n)
41
- return (x, past_key_value)
 
5
  from .attention import ATTN_CLASS_REGISTRY
6
  from .norm import NORM_CLASS_REGISTRY
7
 
 
8
 
9
+ class MPTMLP(nn.Module):
10
+ def __init__(
11
+ self, d_model: int, expansion_ratio: int, device: Optional[str] = None
12
+ ):
13
  super().__init__()
14
  self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
15
+ self.act = nn.GELU(approximate="none")
16
  self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
17
  self.down_proj._is_residual = True
18
 
19
  def forward(self, x):
20
  return self.down_proj(self.act(self.up_proj(x)))
21
 
 
22
 
23
+ class MPTBlock(nn.Module):
24
+ def __init__(
25
+ self,
26
+ d_model: int,
27
+ n_heads: int,
28
+ expansion_ratio: int,
29
+ attn_config: Dict = {
30
+ "attn_type": "multihead_attention",
31
+ "attn_pdrop": 0.0,
32
+ "attn_impl": "triton",
33
+ "qk_ln": False,
34
+ "clip_qkv": None,
35
+ "softmax_scale": None,
36
+ "prefix_lm": False,
37
+ "attn_uses_sequence_id": False,
38
+ "alibi": False,
39
+ "alibi_bias_max": 8,
40
+ },
41
+ resid_pdrop: float = 0.0,
42
+ norm_type: str = "low_precision_layernorm",
43
+ device: Optional[str] = None,
44
+ **kwargs
45
+ ):
46
  del kwargs
47
  super().__init__()
48
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
49
+ attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
50
  self.norm_1 = norm_class(d_model, device=device)
51
+ self.attn = attn_class(
52
+ attn_impl=attn_config["attn_impl"],
53
+ clip_qkv=attn_config["clip_qkv"],
54
+ qk_ln=attn_config["qk_ln"],
55
+ softmax_scale=attn_config["softmax_scale"],
56
+ attn_pdrop=attn_config["attn_pdrop"],
57
+ d_model=d_model,
58
+ n_heads=n_heads,
59
+ device=device,
60
+ )
61
  self.norm_2 = norm_class(d_model, device=device)
62
+ self.ffn = MPTMLP(
63
+ d_model=d_model, expansion_ratio=expansion_ratio, device=device
64
+ )
65
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
66
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
67
 
68
+ def forward(
69
+ self,
70
+ x: torch.Tensor,
71
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
72
+ attn_bias: Optional[torch.Tensor] = None,
73
+ attention_mask: Optional[torch.ByteTensor] = None,
74
+ is_causal: bool = True,
75
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
76
  a = self.norm_1(x)
77
+ (b, _, past_key_value) = self.attn(
78
+ a,
79
+ past_key_value=past_key_value,
80
+ attn_bias=attn_bias,
81
+ attention_mask=attention_mask,
82
+ is_causal=is_causal,
83
+ )
84
  x = x + self.resid_attn_dropout(b)
85
  m = self.norm_2(x)
86
  n = self.ffn(m)
87
  x = x + self.resid_ffn_dropout(n)
88
+ return (x, past_key_value)
configuration_mpt.py CHANGED
@@ -1,13 +1,51 @@
1
  """A HuggingFace-style model configuration."""
2
  from typing import Dict, Optional, Union
3
  from transformers import PretrainedConfig
4
- attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
5
- init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  class MPTConfig(PretrainedConfig):
8
- model_type = 'mpt'
9
 
10
- def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  """The MPT configuration class.
12
 
13
  Args:
@@ -80,39 +118,76 @@ class MPTConfig(PretrainedConfig):
80
  self.norm_type = norm_type
81
  self.use_cache = use_cache
82
  self.init_config = init_config
83
- if 'name' in kwargs:
84
- del kwargs['name']
85
- if 'loss_fn' in kwargs:
86
- del kwargs['loss_fn']
87
  super().__init__(**kwargs)
88
  self._validate_config()
89
 
90
  def _set_config_defaults(self, config, config_defaults):
91
- for (k, v) in config_defaults.items():
92
  if k not in config:
93
  config[k] = v
94
  return config
95
 
96
  def _validate_config(self):
97
- self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
98
- self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
 
 
 
 
99
  if self.d_model % self.n_heads != 0:
100
- raise ValueError('d_model must be divisible by n_heads')
101
- if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
102
- raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
103
- if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
 
 
 
 
 
 
 
 
 
 
 
104
  raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
105
- if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
106
- raise NotImplementedError('prefix_lm only implemented with torch and triton attention.')
107
- if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
108
- raise NotImplementedError('alibi only implemented with torch and triton attention.')
109
- if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
110
- raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
112
- raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')
113
- if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model':
114
- raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
115
- if self.init_config.get('name', None) is None:
116
- raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
117
- if not self.learned_pos_emb and (not self.attn_config['alibi']):
118
- raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.')
 
 
 
 
 
 
 
 
 
1
  """A HuggingFace-style model configuration."""
2
  from typing import Dict, Optional, Union
3
  from transformers import PretrainedConfig
4
+
5
+ attn_config_defaults: Dict = {
6
+ "attn_type": "multihead_attention",
7
+ "attn_pdrop": 0.0,
8
+ "attn_impl": "triton",
9
+ "qk_ln": False,
10
+ "clip_qkv": None,
11
+ "softmax_scale": None,
12
+ "prefix_lm": False,
13
+ "attn_uses_sequence_id": False,
14
+ "alibi": False,
15
+ "alibi_bias_max": 8,
16
+ }
17
+ init_config_defaults: Dict = {
18
+ "name": "kaiming_normal_",
19
+ "fan_mode": "fan_in",
20
+ "init_nonlinearity": "relu",
21
+ }
22
+
23
 
24
  class MPTConfig(PretrainedConfig):
25
+ model_type = "mpt"
26
 
27
+ def __init__(
28
+ self,
29
+ d_model: int = 2048,
30
+ n_heads: int = 16,
31
+ n_layers: int = 24,
32
+ expansion_ratio: int = 4,
33
+ max_seq_len: int = 2048,
34
+ vocab_size: int = 50368,
35
+ resid_pdrop: float = 0.0,
36
+ emb_pdrop: float = 0.0,
37
+ learned_pos_emb: bool = True,
38
+ attn_config: Dict = attn_config_defaults,
39
+ init_device: str = "cpu",
40
+ logit_scale: Optional[Union[float, str]] = None,
41
+ no_bias: bool = False,
42
+ verbose: int = 0,
43
+ embedding_fraction: float = 1.0,
44
+ norm_type: str = "low_precision_layernorm",
45
+ use_cache: bool = False,
46
+ init_config: Dict = init_config_defaults,
47
+ **kwargs,
48
+ ):
49
  """The MPT configuration class.
50
 
51
  Args:
 
118
  self.norm_type = norm_type
119
  self.use_cache = use_cache
120
  self.init_config = init_config
121
+ if "name" in kwargs:
122
+ del kwargs["name"]
123
+ if "loss_fn" in kwargs:
124
+ del kwargs["loss_fn"]
125
  super().__init__(**kwargs)
126
  self._validate_config()
127
 
128
  def _set_config_defaults(self, config, config_defaults):
129
+ for k, v in config_defaults.items():
130
  if k not in config:
131
  config[k] = v
132
  return config
133
 
134
  def _validate_config(self):
135
+ self.attn_config = self._set_config_defaults(
136
+ self.attn_config, attn_config_defaults
137
+ )
138
+ self.init_config = self._set_config_defaults(
139
+ self.init_config, init_config_defaults
140
+ )
141
  if self.d_model % self.n_heads != 0:
142
+ raise ValueError("d_model must be divisible by n_heads")
143
+ if any(
144
+ (
145
+ prob < 0 or prob > 1
146
+ for prob in [
147
+ self.attn_config["attn_pdrop"],
148
+ self.resid_pdrop,
149
+ self.emb_pdrop,
150
+ ]
151
+ )
152
+ ):
153
+ raise ValueError(
154
+ "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1"
155
+ )
156
+ if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]:
157
  raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
158
+ if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in [
159
+ "torch",
160
+ "triton",
161
+ ]:
162
+ raise NotImplementedError(
163
+ "prefix_lm only implemented with torch and triton attention."
164
+ )
165
+ if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in [
166
+ "torch",
167
+ "triton",
168
+ ]:
169
+ raise NotImplementedError(
170
+ "alibi only implemented with torch and triton attention."
171
+ )
172
+ if self.attn_config["attn_uses_sequence_id"] and self.attn_config[
173
+ "attn_impl"
174
+ ] not in ["torch", "triton"]:
175
+ raise NotImplementedError(
176
+ "attn_uses_sequence_id only implemented with torch and triton attention."
177
+ )
178
  if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
179
+ raise ValueError(
180
+ "model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!"
181
+ )
182
+ if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model":
183
+ raise ValueError(
184
+ f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
185
+ )
186
+ if self.init_config.get("name", None) is None:
187
+ raise ValueError(
188
+ f"self.init_config={self.init_config!r} 'name' needs to be set."
189
+ )
190
+ if not self.learned_pos_emb and (not self.attn_config["alibi"]):
191
+ raise ValueError(
192
+ f"Positional information must be provided to the model using either learned_pos_emb or alibi."
193
+ )
hf_prefixlm_converter.py CHANGED
@@ -11,9 +11,17 @@ import warnings
11
  from types import MethodType
12
  from typing import Any, Dict, List, Optional, Tuple, Union
13
  import torch
14
- from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss
 
 
 
 
 
 
15
  from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
16
- from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom
 
 
17
  from transformers.models.bloom.modeling_bloom import logging
18
  from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
19
  from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
@@ -21,10 +29,21 @@ from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
21
  from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
22
  from transformers.models.opt.modeling_opt import OPTForCausalLM
23
  from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
24
- from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
 
 
 
25
  logger = logging.get_logger(__name__)
26
- _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
27
- CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
 
 
 
 
 
 
 
 
28
 
29
  def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
30
  """Converts a GPT-style Causal LM to a Prefix LM.
@@ -37,10 +56,12 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
37
 
38
  See `convert_hf_causal_lm_to_prefix_lm` for more details.
39
  """
40
- if hasattr(model, '_prefix_lm_converted'):
41
  return model
42
  assert isinstance(model, _SUPPORTED_GPT_MODELS)
43
- assert model.config.add_cross_attention == False, 'Only supports GPT-style decoder-only models'
 
 
44
 
45
  def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
46
  """Helper that gets a list of the model's attention modules.
@@ -56,7 +77,7 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
56
  blocks = model.transformer.h
57
  for block in blocks:
58
  if isinstance(model, GPTNeoForCausalLM):
59
- if block.attn.attention_type != 'global':
60
  continue
61
  attn_module = block.attn.attention
62
  elif isinstance(model, GPTNeoXForCausalLM):
@@ -65,17 +86,58 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
65
  attn_module = block.attn
66
  attn_modules.append(attn_module)
67
  return attn_modules
68
- setattr(model, '_original_forward', getattr(model, 'forward'))
69
- setattr(model, '_original_generate', getattr(model, 'generate'))
70
 
71
- def forward(self: CAUSAL_GPT_TYPES, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]=None, attention_mask: Optional[torch.FloatTensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, token_type_ids: Optional[torch.LongTensor]=None, position_ids: Optional[torch.LongTensor]=None, head_mask: Optional[torch.FloatTensor]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  """Wraps original forward to enable PrefixLM attention."""
73
 
74
  def call_og_forward():
75
  if isinstance(self, GPTNeoXForCausalLM):
76
- return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
 
 
 
 
 
 
 
 
 
 
 
77
  else:
78
- return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  if bidirectional_mask is None:
80
  return call_og_forward()
81
  assert isinstance(bidirectional_mask, torch.Tensor)
@@ -83,14 +145,23 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
83
  (b, s) = bidirectional_mask.shape
84
  max_length = attn_modules[0].bias.shape[-1]
85
  if s > max_length:
86
- raise ValueError(f'bidirectional_mask sequence length (={s}) exceeds the ' + f'max length allowed by the model ({max_length}).')
 
 
 
87
  assert s <= max_length
88
  if s < max_length:
89
- pad = torch.zeros((int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device)
 
 
 
 
90
  bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
91
  bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
92
  for attn_module in attn_modules:
93
- attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
 
 
94
  output = call_og_forward()
95
  for attn_module in attn_modules:
96
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
@@ -105,11 +176,13 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
105
  for attn_module in attn_modules:
106
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
107
  return output
108
- setattr(model, 'forward', MethodType(forward, model))
109
- setattr(model, 'generate', MethodType(generate, model))
110
- setattr(model, '_prefix_lm_converted', True)
 
111
  return model
112
 
 
113
  def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
114
  """Converts a BLOOM Causal LM to a Prefix LM.
115
 
@@ -118,62 +191,137 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
118
 
119
  See `convert_hf_causal_lm_to_prefix_lm` for more details.
120
  """
121
- if hasattr(model, '_prefix_lm_converted'):
122
  return model
123
  assert isinstance(model, BloomForCausalLM)
124
- assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models'
125
-
126
- def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor:
 
 
 
 
 
 
 
 
127
  combined_attention_mask = None
128
  device = attention_mask.device
129
  (_, src_length) = input_shape
130
  if src_length > 1:
131
- combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length)
 
 
 
 
132
  if bidirectional_mask is not None:
133
  assert attention_mask.shape == bidirectional_mask.shape
134
- expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
135
- combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)
 
 
 
 
136
  expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
137
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
 
 
 
 
138
  return combined_attention_mask
139
 
140
- def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
 
 
 
 
 
 
 
141
  num_heads = self.config.n_head
142
  closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
143
- base = torch.tensor(2 ** (-2 ** (-(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
144
- powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
 
 
 
 
 
 
145
  slopes = torch.pow(base, powers)
146
  if closest_power_of_2 != num_heads:
147
- extra_base = torch.tensor(2 ** (-2 ** (-(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32)
148
- num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
149
- extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
 
 
 
 
 
 
 
 
150
  slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
151
  qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
152
  ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
153
  diffs = qa - ka + key_length - query_length
154
  diffs = -diffs.abs()
155
- alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length)
156
- alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)
 
 
 
 
157
  return alibi.to(dtype)
 
158
  KeyValueT = Tuple[torch.Tensor, torch.Tensor]
159
 
160
- def forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
161
- if deprecated_arguments.pop('position_ids', False) is not False:
162
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. ' + 'You can safely ignore passing `position_ids`.', FutureWarning)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  if len(deprecated_arguments) > 0:
164
- raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
165
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
166
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
 
 
 
167
  use_cache = use_cache if use_cache is not None else self.config.use_cache
168
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
169
  if input_ids is not None and inputs_embeds is not None:
170
- raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
 
 
171
  elif input_ids is not None:
172
  (batch_size, seq_length) = input_ids.shape
173
  elif inputs_embeds is not None:
174
  (batch_size, seq_length, _) = inputs_embeds.shape
175
  else:
176
- raise ValueError('You have to specify either input_ids or inputs_embeds')
177
  if past_key_values is None:
178
  past_key_values = tuple([None] * len(self.h))
179
  head_mask = self.get_head_mask(head_mask, self.config.n_layer)
@@ -190,28 +338,62 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
190
  past_key_values_length = tmp.shape[2]
191
  seq_length_with_past = seq_length_with_past + past_key_values_length
192
  if attention_mask is None:
193
- attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
 
 
194
  else:
195
  attention_mask = attention_mask.to(hidden_states.device)
196
- alibi = self._build_alibi_tensor(batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device)
197
- causal_mask = self._prepare_attn_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length)
198
- for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)):
 
 
 
 
 
 
 
 
 
 
 
199
  if output_hidden_states:
200
  hst = (hidden_states,)
201
  all_hidden_states = all_hidden_states + hst
202
  if self.gradient_checkpointing and self.training:
203
  if use_cache:
204
- logger.warning('`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
 
 
205
  use_cache = False
206
 
207
  def create_custom_forward(module):
208
-
209
  def custom_forward(*inputs):
210
- return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
 
 
 
 
 
211
  return custom_forward
212
- outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i])
 
 
 
 
 
 
 
213
  else:
214
- outputs = block(hidden_states, layer_past=layer_past, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi)
 
 
 
 
 
 
 
 
215
  hidden_states = outputs[0]
216
  if use_cache is True:
217
  presents = presents + (outputs[1],)
@@ -223,21 +405,77 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
223
  hst = (hidden_states,)
224
  all_hidden_states = all_hidden_states + hst
225
  if not return_dict:
226
- return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None))
227
- return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions)
228
- setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer))
229
- setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer))
230
- setattr(model.transformer, 'forward', MethodType(forward, model.transformer))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  KeyValueT = Tuple[torch.Tensor, torch.Tensor]
232
 
233
- def forward(self: BloomForCausalLM, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.Tensor]=None, inputs_embeds: Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  """Replacement forward method for BloomCausalLM."""
235
- if deprecated_arguments.pop('position_ids', False) is not False:
236
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.', FutureWarning)
 
 
 
 
237
  if len(deprecated_arguments) > 0:
238
- raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
239
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
240
- transformer_outputs = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, bidirectional_mask=bidirectional_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  hidden_states = transformer_outputs[0]
242
  lm_logits = self.lm_head(hidden_states)
243
  loss = None
@@ -246,13 +484,28 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
246
  shift_labels = labels[..., 1:].contiguous()
247
  (batch_size, seq_length, vocab_size) = shift_logits.shape
248
  loss_fct = CrossEntropyLoss()
249
- loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))
 
 
 
250
  if not return_dict:
251
  output = (lm_logits,) + transformer_outputs[1:]
252
  return (loss,) + output if loss is not None else output
253
- return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions)
254
-
255
- def prepare_inputs_for_generation(self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, **kwargs) -> dict:
 
 
 
 
 
 
 
 
 
 
 
 
256
  if past:
257
  input_ids = input_ids[:, -1].unsqueeze(-1)
258
  bidirectional_mask = None
@@ -260,12 +513,24 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
260
  past = self._convert_to_bloom_cache(past)
261
  else:
262
  bidirectional_mask = torch.ones_like(input_ids)
263
- return {'input_ids': input_ids, 'past_key_values': past, 'use_cache': True, 'attention_mask': attention_mask, 'bidirectional_mask': bidirectional_mask}
264
- setattr(model, 'forward', MethodType(forward, model))
265
- setattr(model, 'prepare_inputs_for_generation', MethodType(prepare_inputs_for_generation, model))
266
- setattr(model, '_prefix_lm_converted', True)
 
 
 
 
 
 
 
 
 
 
 
267
  return model
268
 
 
269
  def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
270
  """Converts an OPT Causal LM to a Prefix LM.
271
 
@@ -274,36 +539,89 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM
274
 
275
  See `convert_hf_causal_lm_to_prefix_lm` for more details.
276
  """
277
- if hasattr(model, '_prefix_lm_converted'):
278
  return model
279
  assert isinstance(model, OPTForCausalLM)
280
- assert model.config.add_cross_attention == False, 'Only supports OPT decoder-only models'
281
- setattr(model, '_original_forward', getattr(model, 'forward'))
282
- setattr(model, '_original_generate', getattr(model, 'generate'))
 
 
283
  model.model.decoder.bidirectional_mask = None
284
 
285
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
 
 
286
  combined_attention_mask = None
287
  if input_shape[-1] > 1:
288
- if self.bidirectional_mask == 'g':
289
  (bsz, src_length) = input_shape
290
- combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
 
 
 
 
291
  else:
292
- combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device)
 
 
 
 
293
  if self.bidirectional_mask is not None:
294
  assert attention_mask.shape == self.bidirectional_mask.shape
295
- expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
296
- combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
 
 
 
 
 
 
297
  if attention_mask is not None:
298
- expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
299
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
 
 
 
 
 
 
300
  return combined_attention_mask
301
- setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder))
302
-
303
- def forward(self: OPTForCausalLM, input_ids: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.ByteTensor]=None, head_mask: Optional[torch.Tensor]=None, past_key_values: Optional[List[torch.FloatTensor]]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
304
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  def call_og_forward():
306
- return self._original_forward(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
 
 
 
 
 
 
 
 
 
 
 
 
307
  if bidirectional_mask is None:
308
  return call_og_forward()
309
  self.model.decoder.bidirectional_mask = bidirectional_mask
@@ -317,7 +635,7 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM
317
 
318
  def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
319
  """Wraps original generate to enable PrefixLM-style attention."""
320
- self.model.decoder.bidirectional_mask = 'g'
321
  try:
322
  output = self._original_generate(*args, **kwargs)
323
  except:
@@ -325,12 +643,23 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM
325
  raise
326
  self.model.decoder.bidirectional_mask = None
327
  return output
328
- setattr(model, 'forward', MethodType(forward, model))
329
- setattr(model, 'generate', MethodType(generate, model))
330
- setattr(model, '_prefix_lm_converted', True)
 
331
  return model
 
 
332
  _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
333
- CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM]
 
 
 
 
 
 
 
 
334
 
335
  def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
336
  """Converts a HuggingFace Causal LM to a Prefix LM.
@@ -396,7 +725,12 @@ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES
396
  elif isinstance(model, OPTForCausalLM):
397
  return _convert_opt_causal_lm_to_prefix_lm(model)
398
  else:
399
- raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\n{_SUPPORTED_HF_MODELS}')
 
 
 
 
 
400
 
401
  def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
402
  """Attempts to add bidirectional_mask to batch if missing.
@@ -404,12 +738,16 @@ def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
404
  Raises:
405
  KeyError if bidirectional_mask is missing and can't be inferred
406
  """
407
- if 'bidirectional_mask' not in batch:
408
- if batch.get('mode', None) == 'icl_task':
409
- batch['bidirectional_mask'] = batch['attention_mask'].clone()
410
- for (i, continuation_indices) in enumerate(batch['continuation_indices']):
411
- batch['bidirectional_mask'][i, continuation_indices] = 0
412
- elif 'labels' in batch and 'attention_mask' in batch:
413
- batch['bidirectional_mask'] = torch.logical_and(torch.eq(batch['attention_mask'], 1), torch.eq(batch['labels'], -100)).type_as(batch['attention_mask'])
 
 
414
  else:
415
- raise KeyError('No bidirectional_mask in batch and not sure how to construct one.')
 
 
 
11
  from types import MethodType
12
  from typing import Any, Dict, List, Optional, Tuple, Union
13
  import torch
14
+ from transformers.models.bloom.modeling_bloom import (
15
+ BaseModelOutputWithPastAndCrossAttentions,
16
+ BloomForCausalLM,
17
+ BloomModel,
18
+ CausalLMOutputWithCrossAttentions,
19
+ CrossEntropyLoss,
20
+ )
21
  from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
22
+ from transformers.models.bloom.modeling_bloom import (
23
+ _make_causal_mask as _make_causal_mask_bloom,
24
+ )
25
  from transformers.models.bloom.modeling_bloom import logging
26
  from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
27
  from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
 
29
  from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
30
  from transformers.models.opt.modeling_opt import OPTForCausalLM
31
  from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
32
+ from transformers.models.opt.modeling_opt import (
33
+ _make_causal_mask as _make_causal_mask_opt,
34
+ )
35
+
36
  logger = logging.get_logger(__name__)
37
+ _SUPPORTED_GPT_MODELS = (
38
+ GPT2LMHeadModel,
39
+ GPTJForCausalLM,
40
+ GPTNeoForCausalLM,
41
+ GPTNeoXForCausalLM,
42
+ )
43
+ CAUSAL_GPT_TYPES = Union[
44
+ GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM
45
+ ]
46
+
47
 
48
  def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
49
  """Converts a GPT-style Causal LM to a Prefix LM.
 
56
 
57
  See `convert_hf_causal_lm_to_prefix_lm` for more details.
58
  """
59
+ if hasattr(model, "_prefix_lm_converted"):
60
  return model
61
  assert isinstance(model, _SUPPORTED_GPT_MODELS)
62
+ assert (
63
+ model.config.add_cross_attention == False
64
+ ), "Only supports GPT-style decoder-only models"
65
 
66
  def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
67
  """Helper that gets a list of the model's attention modules.
 
77
  blocks = model.transformer.h
78
  for block in blocks:
79
  if isinstance(model, GPTNeoForCausalLM):
80
+ if block.attn.attention_type != "global":
81
  continue
82
  attn_module = block.attn.attention
83
  elif isinstance(model, GPTNeoXForCausalLM):
 
86
  attn_module = block.attn
87
  attn_modules.append(attn_module)
88
  return attn_modules
 
 
89
 
90
+ setattr(model, "_original_forward", getattr(model, "forward"))
91
+ setattr(model, "_original_generate", getattr(model, "generate"))
92
+
93
+ def forward(
94
+ self: CAUSAL_GPT_TYPES,
95
+ input_ids: Optional[torch.LongTensor] = None,
96
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
97
+ attention_mask: Optional[torch.FloatTensor] = None,
98
+ bidirectional_mask: Optional[torch.Tensor] = None,
99
+ token_type_ids: Optional[torch.LongTensor] = None,
100
+ position_ids: Optional[torch.LongTensor] = None,
101
+ head_mask: Optional[torch.FloatTensor] = None,
102
+ inputs_embeds: Optional[torch.FloatTensor] = None,
103
+ labels: Optional[torch.LongTensor] = None,
104
+ use_cache: Optional[bool] = None,
105
+ output_attentions: Optional[bool] = None,
106
+ output_hidden_states: Optional[bool] = None,
107
+ return_dict: Optional[bool] = None,
108
+ ):
109
  """Wraps original forward to enable PrefixLM attention."""
110
 
111
  def call_og_forward():
112
  if isinstance(self, GPTNeoXForCausalLM):
113
+ return self._original_forward(
114
+ input_ids=input_ids,
115
+ past_key_values=past_key_values,
116
+ attention_mask=attention_mask,
117
+ head_mask=head_mask,
118
+ inputs_embeds=inputs_embeds,
119
+ labels=labels,
120
+ use_cache=use_cache,
121
+ output_attentions=output_attentions,
122
+ output_hidden_states=output_hidden_states,
123
+ return_dict=return_dict,
124
+ )
125
  else:
126
+ return self._original_forward(
127
+ input_ids=input_ids,
128
+ past_key_values=past_key_values,
129
+ attention_mask=attention_mask,
130
+ token_type_ids=token_type_ids,
131
+ position_ids=position_ids,
132
+ head_mask=head_mask,
133
+ inputs_embeds=inputs_embeds,
134
+ labels=labels,
135
+ use_cache=use_cache,
136
+ output_attentions=output_attentions,
137
+ output_hidden_states=output_hidden_states,
138
+ return_dict=return_dict,
139
+ )
140
+
141
  if bidirectional_mask is None:
142
  return call_og_forward()
143
  assert isinstance(bidirectional_mask, torch.Tensor)
 
145
  (b, s) = bidirectional_mask.shape
146
  max_length = attn_modules[0].bias.shape[-1]
147
  if s > max_length:
148
+ raise ValueError(
149
+ f"bidirectional_mask sequence length (={s}) exceeds the "
150
+ + f"max length allowed by the model ({max_length})."
151
+ )
152
  assert s <= max_length
153
  if s < max_length:
154
+ pad = torch.zeros(
155
+ (int(b), int(max_length - s)),
156
+ dtype=bidirectional_mask.dtype,
157
+ device=bidirectional_mask.device,
158
+ )
159
  bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
160
  bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
161
  for attn_module in attn_modules:
162
+ attn_module.bias.data = torch.logical_or(
163
+ attn_module.bias.data, bidirectional
164
+ )
165
  output = call_og_forward()
166
  for attn_module in attn_modules:
167
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
 
176
  for attn_module in attn_modules:
177
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
178
  return output
179
+
180
+ setattr(model, "forward", MethodType(forward, model))
181
+ setattr(model, "generate", MethodType(generate, model))
182
+ setattr(model, "_prefix_lm_converted", True)
183
  return model
184
 
185
+
186
  def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
187
  """Converts a BLOOM Causal LM to a Prefix LM.
188
 
 
191
 
192
  See `convert_hf_causal_lm_to_prefix_lm` for more details.
193
  """
194
+ if hasattr(model, "_prefix_lm_converted"):
195
  return model
196
  assert isinstance(model, BloomForCausalLM)
197
+ assert (
198
+ model.config.add_cross_attention == False
199
+ ), "Only supports BLOOM decoder-only models"
200
+
201
+ def _prepare_attn_mask(
202
+ self: BloomModel,
203
+ attention_mask: torch.Tensor,
204
+ bidirectional_mask: Optional[torch.Tensor],
205
+ input_shape: Tuple[int, int],
206
+ past_key_values_length: int,
207
+ ) -> torch.BoolTensor:
208
  combined_attention_mask = None
209
  device = attention_mask.device
210
  (_, src_length) = input_shape
211
  if src_length > 1:
212
+ combined_attention_mask = _make_causal_mask_bloom(
213
+ input_shape,
214
+ device=device,
215
+ past_key_values_length=past_key_values_length,
216
+ )
217
  if bidirectional_mask is not None:
218
  assert attention_mask.shape == bidirectional_mask.shape
219
+ expanded_bidirectional_mask = _expand_mask_bloom(
220
+ bidirectional_mask, tgt_length=src_length
221
+ )
222
+ combined_attention_mask = torch.logical_and(
223
+ combined_attention_mask, expanded_bidirectional_mask
224
+ )
225
  expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
226
+ combined_attention_mask = (
227
+ expanded_attn_mask
228
+ if combined_attention_mask is None
229
+ else expanded_attn_mask | combined_attention_mask
230
+ )
231
  return combined_attention_mask
232
 
233
+ def _build_alibi_tensor(
234
+ self: BloomModel,
235
+ batch_size: int,
236
+ query_length: int,
237
+ key_length: int,
238
+ dtype: torch.dtype,
239
+ device: torch.device,
240
+ ) -> torch.Tensor:
241
  num_heads = self.config.n_head
242
  closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
243
+ base = torch.tensor(
244
+ 2 ** (-(2 ** (-(math.log2(closest_power_of_2) - 3)))),
245
+ device=device,
246
+ dtype=torch.float32,
247
+ )
248
+ powers = torch.arange(
249
+ 1, 1 + closest_power_of_2, device=device, dtype=torch.int32
250
+ )
251
  slopes = torch.pow(base, powers)
252
  if closest_power_of_2 != num_heads:
253
+ extra_base = torch.tensor(
254
+ 2 ** (-(2 ** (-(math.log2(2 * closest_power_of_2) - 3)))),
255
+ device=device,
256
+ dtype=torch.float32,
257
+ )
258
+ num_remaining_heads = min(
259
+ closest_power_of_2, num_heads - closest_power_of_2
260
+ )
261
+ extra_powers = torch.arange(
262
+ 1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32
263
+ )
264
  slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
265
  qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
266
  ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
267
  diffs = qa - ka + key_length - query_length
268
  diffs = -diffs.abs()
269
+ alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(
270
+ 1, 1, query_length, key_length
271
+ )
272
+ alibi = alibi.expand(batch_size, -1, -1, -1).reshape(
273
+ -1, query_length, key_length
274
+ )
275
  return alibi.to(dtype)
276
+
277
  KeyValueT = Tuple[torch.Tensor, torch.Tensor]
278
 
279
+ def forward(
280
+ self: BloomModel,
281
+ input_ids: Optional[torch.LongTensor] = None,
282
+ past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
283
+ attention_mask: Optional[torch.Tensor] = None,
284
+ bidirectional_mask: Optional[torch.Tensor] = None,
285
+ head_mask: Optional[torch.LongTensor] = None,
286
+ inputs_embeds: Optional[torch.LongTensor] = None,
287
+ use_cache: Optional[bool] = None,
288
+ output_attentions: Optional[bool] = None,
289
+ output_hidden_states: Optional[bool] = None,
290
+ return_dict: Optional[bool] = None,
291
+ **deprecated_arguments,
292
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
293
+ if deprecated_arguments.pop("position_ids", False) is not False:
294
+ warnings.warn(
295
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. "
296
+ + "You can safely ignore passing `position_ids`.",
297
+ FutureWarning,
298
+ )
299
  if len(deprecated_arguments) > 0:
300
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
301
+ output_attentions = (
302
+ output_attentions
303
+ if output_attentions is not None
304
+ else self.config.output_attentions
305
+ )
306
+ output_hidden_states = (
307
+ output_hidden_states
308
+ if output_hidden_states is not None
309
+ else self.config.output_hidden_states
310
+ )
311
  use_cache = use_cache if use_cache is not None else self.config.use_cache
312
+ return_dict = (
313
+ return_dict if return_dict is not None else self.config.use_return_dict
314
+ )
315
  if input_ids is not None and inputs_embeds is not None:
316
+ raise ValueError(
317
+ "You cannot specify both input_ids and inputs_embeds at the same time"
318
+ )
319
  elif input_ids is not None:
320
  (batch_size, seq_length) = input_ids.shape
321
  elif inputs_embeds is not None:
322
  (batch_size, seq_length, _) = inputs_embeds.shape
323
  else:
324
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
325
  if past_key_values is None:
326
  past_key_values = tuple([None] * len(self.h))
327
  head_mask = self.get_head_mask(head_mask, self.config.n_layer)
 
338
  past_key_values_length = tmp.shape[2]
339
  seq_length_with_past = seq_length_with_past + past_key_values_length
340
  if attention_mask is None:
341
+ attention_mask = torch.ones(
342
+ (batch_size, seq_length_with_past), device=hidden_states.device
343
+ )
344
  else:
345
  attention_mask = attention_mask.to(hidden_states.device)
346
+ alibi = self._build_alibi_tensor(
347
+ batch_size=batch_size,
348
+ query_length=seq_length,
349
+ key_length=seq_length_with_past,
350
+ dtype=hidden_states.dtype,
351
+ device=hidden_states.device,
352
+ )
353
+ causal_mask = self._prepare_attn_mask(
354
+ attention_mask,
355
+ bidirectional_mask,
356
+ input_shape=(batch_size, seq_length),
357
+ past_key_values_length=past_key_values_length,
358
+ )
359
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
360
  if output_hidden_states:
361
  hst = (hidden_states,)
362
  all_hidden_states = all_hidden_states + hst
363
  if self.gradient_checkpointing and self.training:
364
  if use_cache:
365
+ logger.warning(
366
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
367
+ )
368
  use_cache = False
369
 
370
  def create_custom_forward(module):
 
371
  def custom_forward(*inputs):
372
+ return module(
373
+ *inputs,
374
+ use_cache=use_cache,
375
+ output_attentions=output_attentions,
376
+ )
377
+
378
  return custom_forward
379
+
380
+ outputs = torch.utils.checkpoint.checkpoint(
381
+ create_custom_forward(block),
382
+ hidden_states,
383
+ alibi,
384
+ causal_mask,
385
+ head_mask[i],
386
+ )
387
  else:
388
+ outputs = block(
389
+ hidden_states,
390
+ layer_past=layer_past,
391
+ attention_mask=causal_mask,
392
+ head_mask=head_mask[i],
393
+ use_cache=use_cache,
394
+ output_attentions=output_attentions,
395
+ alibi=alibi,
396
+ )
397
  hidden_states = outputs[0]
398
  if use_cache is True:
399
  presents = presents + (outputs[1],)
 
405
  hst = (hidden_states,)
406
  all_hidden_states = all_hidden_states + hst
407
  if not return_dict:
408
+ return tuple(
409
+ (
410
+ v
411
+ for v in [
412
+ hidden_states,
413
+ presents,
414
+ all_hidden_states,
415
+ all_self_attentions,
416
+ ]
417
+ if v is not None
418
+ )
419
+ )
420
+ return BaseModelOutputWithPastAndCrossAttentions(
421
+ last_hidden_state=hidden_states,
422
+ past_key_values=presents,
423
+ hidden_states=all_hidden_states,
424
+ attentions=all_self_attentions,
425
+ )
426
+
427
+ setattr(
428
+ model.transformer,
429
+ "_prepare_attn_mask",
430
+ MethodType(_prepare_attn_mask, model.transformer),
431
+ )
432
+ setattr(
433
+ model.transformer,
434
+ "_build_alibi_tensor",
435
+ MethodType(_build_alibi_tensor, model.transformer),
436
+ )
437
+ setattr(model.transformer, "forward", MethodType(forward, model.transformer))
438
  KeyValueT = Tuple[torch.Tensor, torch.Tensor]
439
 
440
+ def forward(
441
+ self: BloomForCausalLM,
442
+ input_ids: Optional[torch.LongTensor] = None,
443
+ past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
444
+ attention_mask: Optional[torch.Tensor] = None,
445
+ bidirectional_mask: Optional[torch.Tensor] = None,
446
+ head_mask: Optional[torch.Tensor] = None,
447
+ inputs_embeds: Optional[torch.Tensor] = None,
448
+ labels: Optional[torch.Tensor] = None,
449
+ use_cache: Optional[bool] = None,
450
+ output_attentions: Optional[bool] = None,
451
+ output_hidden_states: Optional[bool] = None,
452
+ return_dict: Optional[bool] = None,
453
+ **deprecated_arguments,
454
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
455
  """Replacement forward method for BloomCausalLM."""
456
+ if deprecated_arguments.pop("position_ids", False) is not False:
457
+ warnings.warn(
458
+ "`position_ids` have no functionality in BLOOM and will be removed "
459
+ + "in v5.0.0. You can safely ignore passing `position_ids`.",
460
+ FutureWarning,
461
+ )
462
  if len(deprecated_arguments) > 0:
463
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
464
+ return_dict = (
465
+ return_dict if return_dict is not None else self.config.use_return_dict
466
+ )
467
+ transformer_outputs = self.transformer(
468
+ input_ids,
469
+ past_key_values=past_key_values,
470
+ attention_mask=attention_mask,
471
+ bidirectional_mask=bidirectional_mask,
472
+ head_mask=head_mask,
473
+ inputs_embeds=inputs_embeds,
474
+ use_cache=use_cache,
475
+ output_attentions=output_attentions,
476
+ output_hidden_states=output_hidden_states,
477
+ return_dict=return_dict,
478
+ )
479
  hidden_states = transformer_outputs[0]
480
  lm_logits = self.lm_head(hidden_states)
481
  loss = None
 
484
  shift_labels = labels[..., 1:].contiguous()
485
  (batch_size, seq_length, vocab_size) = shift_logits.shape
486
  loss_fct = CrossEntropyLoss()
487
+ loss = loss_fct(
488
+ shift_logits.view(batch_size * seq_length, vocab_size),
489
+ shift_labels.view(batch_size * seq_length),
490
+ )
491
  if not return_dict:
492
  output = (lm_logits,) + transformer_outputs[1:]
493
  return (loss,) + output if loss is not None else output
494
+ return CausalLMOutputWithCrossAttentions(
495
+ loss=loss,
496
+ logits=lm_logits,
497
+ past_key_values=transformer_outputs.past_key_values,
498
+ hidden_states=transformer_outputs.hidden_states,
499
+ attentions=transformer_outputs.attentions,
500
+ )
501
+
502
+ def prepare_inputs_for_generation(
503
+ self: BloomForCausalLM,
504
+ input_ids: torch.LongTensor,
505
+ past: Optional[torch.Tensor] = None,
506
+ attention_mask: Optional[torch.Tensor] = None,
507
+ **kwargs,
508
+ ) -> dict:
509
  if past:
510
  input_ids = input_ids[:, -1].unsqueeze(-1)
511
  bidirectional_mask = None
 
513
  past = self._convert_to_bloom_cache(past)
514
  else:
515
  bidirectional_mask = torch.ones_like(input_ids)
516
+ return {
517
+ "input_ids": input_ids,
518
+ "past_key_values": past,
519
+ "use_cache": True,
520
+ "attention_mask": attention_mask,
521
+ "bidirectional_mask": bidirectional_mask,
522
+ }
523
+
524
+ setattr(model, "forward", MethodType(forward, model))
525
+ setattr(
526
+ model,
527
+ "prepare_inputs_for_generation",
528
+ MethodType(prepare_inputs_for_generation, model),
529
+ )
530
+ setattr(model, "_prefix_lm_converted", True)
531
  return model
532
 
533
+
534
  def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
535
  """Converts an OPT Causal LM to a Prefix LM.
536
 
 
539
 
540
  See `convert_hf_causal_lm_to_prefix_lm` for more details.
541
  """
542
+ if hasattr(model, "_prefix_lm_converted"):
543
  return model
544
  assert isinstance(model, OPTForCausalLM)
545
+ assert (
546
+ model.config.add_cross_attention == False
547
+ ), "Only supports OPT decoder-only models"
548
+ setattr(model, "_original_forward", getattr(model, "forward"))
549
+ setattr(model, "_original_generate", getattr(model, "generate"))
550
  model.model.decoder.bidirectional_mask = None
551
 
552
+ def _prepare_decoder_attention_mask(
553
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
554
+ ):
555
  combined_attention_mask = None
556
  if input_shape[-1] > 1:
557
+ if self.bidirectional_mask == "g":
558
  (bsz, src_length) = input_shape
559
+ combined_attention_mask = torch.zeros(
560
+ (bsz, 1, src_length, src_length + past_key_values_length),
561
+ dtype=inputs_embeds.dtype,
562
+ device=inputs_embeds.device,
563
+ )
564
  else:
565
+ combined_attention_mask = _make_causal_mask_opt(
566
+ input_shape,
567
+ inputs_embeds.dtype,
568
+ past_key_values_length=past_key_values_length,
569
+ ).to(inputs_embeds.device)
570
  if self.bidirectional_mask is not None:
571
  assert attention_mask.shape == self.bidirectional_mask.shape
572
+ expanded_bidirectional_mask = _expand_mask_opt(
573
+ self.bidirectional_mask,
574
+ inputs_embeds.dtype,
575
+ tgt_len=input_shape[-1],
576
+ ).to(inputs_embeds.device)
577
+ combined_attention_mask = torch.maximum(
578
+ expanded_bidirectional_mask, combined_attention_mask
579
+ )
580
  if attention_mask is not None:
581
+ expanded_attn_mask = _expand_mask_opt(
582
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
583
+ ).to(inputs_embeds.device)
584
+ combined_attention_mask = (
585
+ expanded_attn_mask
586
+ if combined_attention_mask is None
587
+ else expanded_attn_mask + combined_attention_mask
588
+ )
589
  return combined_attention_mask
 
 
 
590
 
591
+ setattr(
592
+ model.model.decoder,
593
+ "_prepare_decoder_attention_mask",
594
+ MethodType(_prepare_decoder_attention_mask, model.model.decoder),
595
+ )
596
+
597
+ def forward(
598
+ self: OPTForCausalLM,
599
+ input_ids: Optional[torch.LongTensor] = None,
600
+ attention_mask: Optional[torch.Tensor] = None,
601
+ bidirectional_mask: Optional[torch.ByteTensor] = None,
602
+ head_mask: Optional[torch.Tensor] = None,
603
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
604
+ inputs_embeds: Optional[torch.FloatTensor] = None,
605
+ labels: Optional[torch.LongTensor] = None,
606
+ use_cache: Optional[bool] = None,
607
+ output_attentions: Optional[bool] = None,
608
+ output_hidden_states: Optional[bool] = None,
609
+ return_dict: Optional[bool] = None,
610
+ ):
611
  def call_og_forward():
612
+ return self._original_forward(
613
+ input_ids=input_ids,
614
+ attention_mask=attention_mask,
615
+ head_mask=head_mask,
616
+ past_key_values=past_key_values,
617
+ inputs_embeds=inputs_embeds,
618
+ labels=labels,
619
+ use_cache=use_cache,
620
+ output_attentions=output_attentions,
621
+ output_hidden_states=output_hidden_states,
622
+ return_dict=return_dict,
623
+ )
624
+
625
  if bidirectional_mask is None:
626
  return call_og_forward()
627
  self.model.decoder.bidirectional_mask = bidirectional_mask
 
635
 
636
  def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
637
  """Wraps original generate to enable PrefixLM-style attention."""
638
+ self.model.decoder.bidirectional_mask = "g"
639
  try:
640
  output = self._original_generate(*args, **kwargs)
641
  except:
 
643
  raise
644
  self.model.decoder.bidirectional_mask = None
645
  return output
646
+
647
+ setattr(model, "forward", MethodType(forward, model))
648
+ setattr(model, "generate", MethodType(generate, model))
649
+ setattr(model, "_prefix_lm_converted", True)
650
  return model
651
+
652
+
653
  _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
654
+ CAUSAL_LM_TYPES = Union[
655
+ GPT2LMHeadModel,
656
+ GPTJForCausalLM,
657
+ GPTNeoForCausalLM,
658
+ GPTNeoXForCausalLM,
659
+ BloomForCausalLM,
660
+ OPTForCausalLM,
661
+ ]
662
+
663
 
664
  def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
665
  """Converts a HuggingFace Causal LM to a Prefix LM.
 
725
  elif isinstance(model, OPTForCausalLM):
726
  return _convert_opt_causal_lm_to_prefix_lm(model)
727
  else:
728
+ raise TypeError(
729
+ f"Cannot convert model to Prefix LM. "
730
+ + f"Model does not belong to set of supported HF models:"
731
+ + f"\n{_SUPPORTED_HF_MODELS}"
732
+ )
733
+
734
 
735
  def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
736
  """Attempts to add bidirectional_mask to batch if missing.
 
738
  Raises:
739
  KeyError if bidirectional_mask is missing and can't be inferred
740
  """
741
+ if "bidirectional_mask" not in batch:
742
+ if batch.get("mode", None) == "icl_task":
743
+ batch["bidirectional_mask"] = batch["attention_mask"].clone()
744
+ for i, continuation_indices in enumerate(batch["continuation_indices"]):
745
+ batch["bidirectional_mask"][i, continuation_indices] = 0
746
+ elif "labels" in batch and "attention_mask" in batch:
747
+ batch["bidirectional_mask"] = torch.logical_and(
748
+ torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100)
749
+ ).type_as(batch["attention_mask"])
750
  else:
751
+ raise KeyError(
752
+ "No bidirectional_mask in batch and not sure how to construct one."
753
+ )
meta_init_context.py CHANGED
@@ -2,8 +2,9 @@ from contextlib import contextmanager
2
  import torch
3
  import torch.nn as nn
4
 
 
5
  @contextmanager
6
- def init_empty_weights(include_buffers: bool=False):
7
  """Meta initialization context manager.
8
 
9
  A context manager under which models are initialized with all parameters
@@ -30,11 +31,12 @@ def init_empty_weights(include_buffers: bool=False):
30
 
31
  </Tip>
32
  """
33
- with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f:
34
  yield f
35
 
 
36
  @contextmanager
37
- def init_on_device(device: torch.device, include_buffers: bool=False):
38
  """Device initialization context manager.
39
 
40
  A context manager under which models are initialized with all parameters
@@ -62,33 +64,47 @@ def init_on_device(device: torch.device, include_buffers: bool=False):
62
  if param is not None:
63
  param_cls = type(module._parameters[name])
64
  kwargs = module._parameters[name].__dict__
65
- module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
 
 
66
 
67
  def register_empty_buffer(module, name, buffer):
68
  old_register_buffer(module, name, buffer)
69
  if buffer is not None:
70
  module._buffers[name] = module._buffers[name].to(device)
 
71
  if include_buffers:
72
- tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}
 
 
 
73
  else:
74
  tensor_constructors_to_patch = {}
75
 
76
  def patch_tensor_constructor(fn):
77
-
78
  def wrapper(*args, **kwargs):
79
- kwargs['device'] = device
80
  return fn(*args, **kwargs)
 
81
  return wrapper
 
82
  try:
83
  nn.Module.register_parameter = register_empty_parameter
84
  if include_buffers:
85
  nn.Module.register_buffer = register_empty_buffer
86
  for torch_function_name in tensor_constructors_to_patch.keys():
87
- setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
 
 
 
 
88
  yield
89
  finally:
90
  nn.Module.register_parameter = old_register_parameter
91
  if include_buffers:
92
  nn.Module.register_buffer = old_register_buffer
93
- for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items():
94
- setattr(torch, torch_function_name, old_torch_function)
 
 
 
 
2
  import torch
3
  import torch.nn as nn
4
 
5
+
6
  @contextmanager
7
+ def init_empty_weights(include_buffers: bool = False):
8
  """Meta initialization context manager.
9
 
10
  A context manager under which models are initialized with all parameters
 
31
 
32
  </Tip>
33
  """
34
+ with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
35
  yield f
36
 
37
+
38
  @contextmanager
39
+ def init_on_device(device: torch.device, include_buffers: bool = False):
40
  """Device initialization context manager.
41
 
42
  A context manager under which models are initialized with all parameters
 
64
  if param is not None:
65
  param_cls = type(module._parameters[name])
66
  kwargs = module._parameters[name].__dict__
67
+ module._parameters[name] = param_cls(
68
+ module._parameters[name].to(device), **kwargs
69
+ )
70
 
71
  def register_empty_buffer(module, name, buffer):
72
  old_register_buffer(module, name, buffer)
73
  if buffer is not None:
74
  module._buffers[name] = module._buffers[name].to(device)
75
+
76
  if include_buffers:
77
+ tensor_constructors_to_patch = {
78
+ torch_function_name: getattr(torch, torch_function_name)
79
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
80
+ }
81
  else:
82
  tensor_constructors_to_patch = {}
83
 
84
  def patch_tensor_constructor(fn):
 
85
  def wrapper(*args, **kwargs):
86
+ kwargs["device"] = device
87
  return fn(*args, **kwargs)
88
+
89
  return wrapper
90
+
91
  try:
92
  nn.Module.register_parameter = register_empty_parameter
93
  if include_buffers:
94
  nn.Module.register_buffer = register_empty_buffer
95
  for torch_function_name in tensor_constructors_to_patch.keys():
96
+ setattr(
97
+ torch,
98
+ torch_function_name,
99
+ patch_tensor_constructor(getattr(torch, torch_function_name)),
100
+ )
101
  yield
102
  finally:
103
  nn.Module.register_parameter = old_register_parameter
104
  if include_buffers:
105
  nn.Module.register_buffer = old_register_buffer
106
+ for (
107
+ torch_function_name,
108
+ old_torch_function,
109
+ ) in tensor_constructors_to_patch.items():
110
+ setattr(torch, torch_function_name, old_torch_function)
norm.py CHANGED
@@ -1,28 +1,55 @@
1
  import torch
2
 
 
3
  def _cast_if_autocast_enabled(tensor):
4
  if torch.is_autocast_enabled():
5
- if tensor.device.type == 'cuda':
6
  dtype = torch.get_autocast_gpu_dtype()
7
- elif tensor.device.type == 'cpu':
8
  dtype = torch.get_autocast_cpu_dtype()
9
  else:
10
  raise NotImplementedError()
11
  return tensor.to(dtype=dtype)
12
  return tensor
13
 
14
- class LPLayerNorm(torch.nn.LayerNorm):
15
 
16
- def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
17
- super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def forward(self, x):
20
  module_device = x.device
21
  downcast_x = _cast_if_autocast_enabled(x)
22
- downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
23
- downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
 
 
 
 
 
 
24
  with torch.autocast(enabled=False, device_type=module_device.type):
25
- return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
 
 
 
 
 
 
 
26
 
27
  def rms_norm(x, weight=None, eps=1e-05):
28
  output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
@@ -30,27 +57,50 @@ def rms_norm(x, weight=None, eps=1e-05):
30
  return output * weight
31
  return output
32
 
33
- class RMSNorm(torch.nn.Module):
34
 
35
- def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
 
 
 
36
  super().__init__()
37
  self.eps = eps
38
  if weight:
39
- self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
 
 
40
  else:
41
- self.register_parameter('weight', None)
42
 
43
  def forward(self, x):
44
  return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
45
 
46
- class LPRMSNorm(RMSNorm):
47
 
48
- def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
49
- super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
 
 
 
 
 
 
 
 
 
50
 
51
  def forward(self, x):
52
  downcast_x = _cast_if_autocast_enabled(x)
53
- downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
 
 
 
 
54
  with torch.autocast(enabled=False, device_type=x.device.type):
55
  return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
56
- NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
 
 
 
 
 
 
 
 
1
  import torch
2
 
3
+
4
  def _cast_if_autocast_enabled(tensor):
5
  if torch.is_autocast_enabled():
6
+ if tensor.device.type == "cuda":
7
  dtype = torch.get_autocast_gpu_dtype()
8
+ elif tensor.device.type == "cpu":
9
  dtype = torch.get_autocast_cpu_dtype()
10
  else:
11
  raise NotImplementedError()
12
  return tensor.to(dtype=dtype)
13
  return tensor
14
 
 
15
 
16
+ class LPLayerNorm(torch.nn.LayerNorm):
17
+ def __init__(
18
+ self,
19
+ normalized_shape,
20
+ eps=1e-05,
21
+ elementwise_affine=True,
22
+ device=None,
23
+ dtype=None,
24
+ ):
25
+ super().__init__(
26
+ normalized_shape=normalized_shape,
27
+ eps=eps,
28
+ elementwise_affine=elementwise_affine,
29
+ device=device,
30
+ dtype=dtype,
31
+ )
32
 
33
  def forward(self, x):
34
  module_device = x.device
35
  downcast_x = _cast_if_autocast_enabled(x)
36
+ downcast_weight = (
37
+ _cast_if_autocast_enabled(self.weight)
38
+ if self.weight is not None
39
+ else self.weight
40
+ )
41
+ downcast_bias = (
42
+ _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
43
+ )
44
  with torch.autocast(enabled=False, device_type=module_device.type):
45
+ return torch.nn.functional.layer_norm(
46
+ downcast_x,
47
+ self.normalized_shape,
48
+ downcast_weight,
49
+ downcast_bias,
50
+ self.eps,
51
+ )
52
+
53
 
54
  def rms_norm(x, weight=None, eps=1e-05):
55
  output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
 
57
  return output * weight
58
  return output
59
 
 
60
 
61
+ class RMSNorm(torch.nn.Module):
62
+ def __init__(
63
+ self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None
64
+ ):
65
  super().__init__()
66
  self.eps = eps
67
  if weight:
68
+ self.weight = torch.nn.Parameter(
69
+ torch.ones(normalized_shape, dtype=dtype, device=device)
70
+ )
71
  else:
72
+ self.register_parameter("weight", None)
73
 
74
  def forward(self, x):
75
  return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
76
 
 
77
 
78
+ class LPRMSNorm(RMSNorm):
79
+ def __init__(
80
+ self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None
81
+ ):
82
+ super().__init__(
83
+ normalized_shape=normalized_shape,
84
+ eps=eps,
85
+ weight=weight,
86
+ dtype=dtype,
87
+ device=device,
88
+ )
89
 
90
  def forward(self, x):
91
  downcast_x = _cast_if_autocast_enabled(x)
92
+ downcast_weight = (
93
+ _cast_if_autocast_enabled(self.weight)
94
+ if self.weight is not None
95
+ else self.weight
96
+ )
97
  with torch.autocast(enabled=False, device_type=x.device.type):
98
  return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
99
+
100
+
101
+ NORM_CLASS_REGISTRY = {
102
+ "layernorm": torch.nn.LayerNorm,
103
+ "low_precision_layernorm": LPLayerNorm,
104
+ "rmsnorm": RMSNorm,
105
+ "low_precision_rmsnorm": LPRMSNorm,
106
+ }
param_init_fns.py CHANGED
@@ -7,97 +7,133 @@ import torch
7
  from torch import nn
8
  from .norm import NORM_CLASS_REGISTRY
9
 
10
- def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs):
 
11
  del kwargs
12
  if verbose > 1:
13
  warnings.warn(f"Initializing network using module's reset_parameters attribute")
14
- if hasattr(module, 'reset_parameters'):
15
  module.reset_parameters()
16
 
 
17
  def fused_init_helper_(module: nn.Module, init_fn_):
18
- _fused = getattr(module, '_fused', None)
19
  if _fused is None:
20
- raise RuntimeError(f'Internal logic error')
21
  (dim, splits) = _fused
22
  splits = (0, *splits, module.weight.size(dim))
23
- for (s, e) in zip(splits[:-1], splits[1:]):
24
  slice_indices = [slice(None)] * module.weight.ndim
25
  slice_indices[dim] = slice(s, e)
26
  init_fn_(module.weight[slice_indices])
27
 
28
- def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
29
  del kwargs
30
  if verbose > 1:
31
- warnings.warn(f'If model has bias parameters they are initialized to 0.')
32
  init_div_is_residual = init_div_is_residual
33
  if init_div_is_residual is False:
34
  div_is_residual = 1.0
35
  elif init_div_is_residual is True:
36
  div_is_residual = math.sqrt(2 * n_layers)
37
- elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
 
 
38
  div_is_residual = init_div_is_residual
39
  elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
40
  div_is_residual = float(init_div_is_residual)
41
  else:
42
  div_is_residual = 1.0
43
- raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
 
 
44
  if init_div_is_residual is not False:
45
  if verbose > 1:
46
- warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.')
 
 
 
47
  if isinstance(module, nn.Linear):
48
- if hasattr(module, '_fused'):
49
  fused_init_helper_(module, init_fn_)
50
  else:
51
  init_fn_(module.weight)
52
  if module.bias is not None:
53
  torch.nn.init.zeros_(module.bias)
54
- if init_div_is_residual is not False and getattr(module, '_is_residual', False):
55
  with torch.no_grad():
56
  module.weight.div_(div_is_residual)
57
  elif isinstance(module, nn.Embedding):
58
  if emb_init_std is not None:
59
  std = emb_init_std
60
  if std == 0:
61
- warnings.warn(f'Embedding layer initialized to 0.')
62
  emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
63
  if verbose > 1:
64
- warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
 
 
65
  elif emb_init_uniform_lim is not None:
66
  lim = emb_init_uniform_lim
67
  if isinstance(lim, Sequence):
68
  if len(lim) > 2:
69
- raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.')
 
 
70
  if lim[0] == lim[1]:
71
- warnings.warn(f'Embedding layer initialized to {lim[0]}.')
72
  else:
73
  if lim == 0:
74
- warnings.warn(f'Embedding layer initialized to 0.')
75
  lim = [-lim, lim]
76
  (a, b) = lim
77
  emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
78
  if verbose > 1:
79
- warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
 
 
80
  else:
81
  emb_init_fn_ = init_fn_
82
  emb_init_fn_(module.weight)
83
  elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
84
  if verbose > 1:
85
- warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')
86
- if hasattr(module, 'weight') and module.weight is not None:
 
 
87
  torch.nn.init.ones_(module.weight)
88
- if hasattr(module, 'bias') and module.bias is not None:
89
  torch.nn.init.zeros_(module.bias)
90
  elif isinstance(module, nn.MultiheadAttention):
91
  if module._qkv_same_embed_dim:
92
  assert module.in_proj_weight is not None
93
- assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
 
 
 
 
94
  assert d_model is not None
95
  _d = d_model
96
  splits = (0, _d, 2 * _d, 3 * _d)
97
- for (s, e) in zip(splits[:-1], splits[1:]):
98
  init_fn_(module.in_proj_weight[s:e])
99
  else:
100
- assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
 
 
 
 
101
  assert module.in_proj_weight is None
102
  init_fn_(module.q_proj_weight)
103
  init_fn_(module.k_proj_weight)
@@ -109,37 +145,112 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
109
  if module.bias_v is not None:
110
  torch.nn.init.zeros_(module.bias_v)
111
  init_fn_(module.out_proj.weight)
112
- if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False):
 
 
113
  with torch.no_grad():
114
  module.out_proj.weight.div_(div_is_residual)
115
  if module.out_proj.bias is not None:
116
  torch.nn.init.zeros_(module.out_proj.bias)
117
  else:
118
  for _ in module.parameters(recurse=False):
119
- raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
 
 
 
120
 
121
  def _normal_init_(std, mean=0.0):
122
  return partial(torch.nn.init.normal_, mean=mean, std=std)
123
 
124
- def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
125
  del kwargs
126
  init_fn_ = _normal_init_(std=std)
127
  if verbose > 1:
128
- warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
129
- generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
 
 
 
 
 
 
 
 
 
 
130
 
131
- def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
 
132
  del kwargs
133
  if init_std is None:
134
- raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
135
- _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
- def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
138
  del kwargs
139
  std = math.sqrt(2 / (5 * d_model))
140
- _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
 
 
 
 
 
 
 
 
 
141
 
142
- def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
143
  """From section 2.3.1 of GPT-NeoX-20B:
144
 
145
  An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)
@@ -149,33 +260,158 @@ def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init
149
  del kwargs
150
  residual_div = n_layers / math.sqrt(10)
151
  if verbose > 1:
152
- warnings.warn(f'setting init_div_is_residual to {residual_div}')
153
- small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
 
 
 
 
 
 
 
 
 
154
 
155
- def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
156
  del kwargs
157
  if verbose > 1:
158
- warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
159
- kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
160
- generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
- def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  del kwargs
164
  if verbose > 1:
165
- warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
166
- kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
167
- generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
- def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
 
170
  del kwargs
171
  xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
172
  if verbose > 1:
173
- warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}')
174
- generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
- def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
177
  xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
178
  if verbose > 1:
179
- warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}')
180
- generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
181
- MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from torch import nn
8
  from .norm import NORM_CLASS_REGISTRY
9
 
10
+
11
+ def torch_default_param_init_fn_(module: nn.Module, verbose: int = 0, **kwargs):
12
  del kwargs
13
  if verbose > 1:
14
  warnings.warn(f"Initializing network using module's reset_parameters attribute")
15
+ if hasattr(module, "reset_parameters"):
16
  module.reset_parameters()
17
 
18
+
19
  def fused_init_helper_(module: nn.Module, init_fn_):
20
+ _fused = getattr(module, "_fused", None)
21
  if _fused is None:
22
+ raise RuntimeError(f"Internal logic error")
23
  (dim, splits) = _fused
24
  splits = (0, *splits, module.weight.size(dim))
25
+ for s, e in zip(splits[:-1], splits[1:]):
26
  slice_indices = [slice(None)] * module.weight.ndim
27
  slice_indices[dim] = slice(s, e)
28
  init_fn_(module.weight[slice_indices])
29
 
30
+
31
+ def generic_param_init_fn_(
32
+ module: nn.Module,
33
+ init_fn_,
34
+ n_layers: int,
35
+ d_model: Optional[int] = None,
36
+ init_div_is_residual: Union[int, float, str, bool] = True,
37
+ emb_init_std: Optional[float] = None,
38
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
39
+ verbose: int = 0,
40
+ **kwargs,
41
+ ):
42
  del kwargs
43
  if verbose > 1:
44
+ warnings.warn(f"If model has bias parameters they are initialized to 0.")
45
  init_div_is_residual = init_div_is_residual
46
  if init_div_is_residual is False:
47
  div_is_residual = 1.0
48
  elif init_div_is_residual is True:
49
  div_is_residual = math.sqrt(2 * n_layers)
50
+ elif isinstance(init_div_is_residual, float) or isinstance(
51
+ init_div_is_residual, int
52
+ ):
53
  div_is_residual = init_div_is_residual
54
  elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
55
  div_is_residual = float(init_div_is_residual)
56
  else:
57
  div_is_residual = 1.0
58
+ raise ValueError(
59
+ f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}"
60
+ )
61
  if init_div_is_residual is not False:
62
  if verbose > 1:
63
+ warnings.warn(
64
+ f"Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. "
65
+ + f"Set `init_div_is_residual: false` in init config to disable this."
66
+ )
67
  if isinstance(module, nn.Linear):
68
+ if hasattr(module, "_fused"):
69
  fused_init_helper_(module, init_fn_)
70
  else:
71
  init_fn_(module.weight)
72
  if module.bias is not None:
73
  torch.nn.init.zeros_(module.bias)
74
+ if init_div_is_residual is not False and getattr(module, "_is_residual", False):
75
  with torch.no_grad():
76
  module.weight.div_(div_is_residual)
77
  elif isinstance(module, nn.Embedding):
78
  if emb_init_std is not None:
79
  std = emb_init_std
80
  if std == 0:
81
+ warnings.warn(f"Embedding layer initialized to 0.")
82
  emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
83
  if verbose > 1:
84
+ warnings.warn(
85
+ f"Embedding layer initialized using normal distribution with mean=0 and std={std!r}."
86
+ )
87
  elif emb_init_uniform_lim is not None:
88
  lim = emb_init_uniform_lim
89
  if isinstance(lim, Sequence):
90
  if len(lim) > 2:
91
+ raise ValueError(
92
+ f"Uniform init requires a min and a max limit. User input: {lim}."
93
+ )
94
  if lim[0] == lim[1]:
95
+ warnings.warn(f"Embedding layer initialized to {lim[0]}.")
96
  else:
97
  if lim == 0:
98
+ warnings.warn(f"Embedding layer initialized to 0.")
99
  lim = [-lim, lim]
100
  (a, b) = lim
101
  emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
102
  if verbose > 1:
103
+ warnings.warn(
104
+ f"Embedding layer initialized using uniform distribution in range {lim}."
105
+ )
106
  else:
107
  emb_init_fn_ = init_fn_
108
  emb_init_fn_(module.weight)
109
  elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
110
  if verbose > 1:
111
+ warnings.warn(
112
+ f"Norm weights are set to 1. If norm layer has a bias it is initialized to 0."
113
+ )
114
+ if hasattr(module, "weight") and module.weight is not None:
115
  torch.nn.init.ones_(module.weight)
116
+ if hasattr(module, "bias") and module.bias is not None:
117
  torch.nn.init.zeros_(module.bias)
118
  elif isinstance(module, nn.MultiheadAttention):
119
  if module._qkv_same_embed_dim:
120
  assert module.in_proj_weight is not None
121
+ assert (
122
+ module.q_proj_weight is None
123
+ and module.k_proj_weight is None
124
+ and (module.v_proj_weight is None)
125
+ )
126
  assert d_model is not None
127
  _d = d_model
128
  splits = (0, _d, 2 * _d, 3 * _d)
129
+ for s, e in zip(splits[:-1], splits[1:]):
130
  init_fn_(module.in_proj_weight[s:e])
131
  else:
132
+ assert (
133
+ module.q_proj_weight is not None
134
+ and module.k_proj_weight is not None
135
+ and (module.v_proj_weight is not None)
136
+ )
137
  assert module.in_proj_weight is None
138
  init_fn_(module.q_proj_weight)
139
  init_fn_(module.k_proj_weight)
 
145
  if module.bias_v is not None:
146
  torch.nn.init.zeros_(module.bias_v)
147
  init_fn_(module.out_proj.weight)
148
+ if init_div_is_residual is not False and getattr(
149
+ module.out_proj, "_is_residual", False
150
+ ):
151
  with torch.no_grad():
152
  module.out_proj.weight.div_(div_is_residual)
153
  if module.out_proj.bias is not None:
154
  torch.nn.init.zeros_(module.out_proj.bias)
155
  else:
156
  for _ in module.parameters(recurse=False):
157
+ raise NotImplementedError(
158
+ f"{module.__class__.__name__} parameters are not initialized by param_init_fn."
159
+ )
160
+
161
 
162
  def _normal_init_(std, mean=0.0):
163
  return partial(torch.nn.init.normal_, mean=mean, std=std)
164
 
165
+
166
+ def _normal_param_init_fn_(
167
+ module: nn.Module,
168
+ std: float,
169
+ n_layers: int,
170
+ d_model: Optional[int] = None,
171
+ init_div_is_residual: Union[int, float, str, bool] = True,
172
+ emb_init_std: Optional[float] = None,
173
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
174
+ verbose: int = 0,
175
+ **kwargs,
176
+ ):
177
  del kwargs
178
  init_fn_ = _normal_init_(std=std)
179
  if verbose > 1:
180
+ warnings.warn(f"Using torch.nn.init.normal_ init fn mean=0.0, std={std}")
181
+ generic_param_init_fn_(
182
+ module=module,
183
+ init_fn_=init_fn_,
184
+ d_model=d_model,
185
+ n_layers=n_layers,
186
+ init_div_is_residual=init_div_is_residual,
187
+ emb_init_std=emb_init_std,
188
+ emb_init_uniform_lim=emb_init_uniform_lim,
189
+ verbose=verbose,
190
+ )
191
+
192
 
193
+ def baseline_param_init_fn_(
194
+ module: nn.Module,
195
+ init_std: float,
196
+ n_layers: int,
197
+ d_model: Optional[int] = None,
198
+ init_div_is_residual: Union[int, float, str, bool] = True,
199
+ emb_init_std: Optional[float] = None,
200
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
201
+ verbose: int = 0,
202
+ **kwargs,
203
+ ):
204
  del kwargs
205
  if init_std is None:
206
+ raise ValueError(
207
+ "You must set model.init_config['init_std'] to a float value to use the default initialization scheme."
208
+ )
209
+ _normal_param_init_fn_(
210
+ module=module,
211
+ std=init_std,
212
+ d_model=d_model,
213
+ n_layers=n_layers,
214
+ init_div_is_residual=init_div_is_residual,
215
+ emb_init_std=emb_init_std,
216
+ emb_init_uniform_lim=emb_init_uniform_lim,
217
+ verbose=verbose,
218
+ )
219
+
220
 
221
+ def small_param_init_fn_(
222
+ module: nn.Module,
223
+ n_layers: int,
224
+ d_model: int,
225
+ init_div_is_residual: Union[int, float, str, bool] = True,
226
+ emb_init_std: Optional[float] = None,
227
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
228
+ verbose: int = 0,
229
+ **kwargs,
230
+ ):
231
  del kwargs
232
  std = math.sqrt(2 / (5 * d_model))
233
+ _normal_param_init_fn_(
234
+ module=module,
235
+ std=std,
236
+ d_model=d_model,
237
+ n_layers=n_layers,
238
+ init_div_is_residual=init_div_is_residual,
239
+ emb_init_std=emb_init_std,
240
+ emb_init_uniform_lim=emb_init_uniform_lim,
241
+ verbose=verbose,
242
+ )
243
 
244
+
245
+ def neox_param_init_fn_(
246
+ module: nn.Module,
247
+ n_layers: int,
248
+ d_model: int,
249
+ emb_init_std: Optional[float] = None,
250
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
251
+ verbose: int = 0,
252
+ **kwargs,
253
+ ):
254
  """From section 2.3.1 of GPT-NeoX-20B:
255
 
256
  An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)
 
260
  del kwargs
261
  residual_div = n_layers / math.sqrt(10)
262
  if verbose > 1:
263
+ warnings.warn(f"setting init_div_is_residual to {residual_div}")
264
+ small_param_init_fn_(
265
+ module=module,
266
+ d_model=d_model,
267
+ n_layers=n_layers,
268
+ init_div_is_residual=residual_div,
269
+ emb_init_std=emb_init_std,
270
+ emb_init_uniform_lim=emb_init_uniform_lim,
271
+ verbose=verbose,
272
+ )
273
+
274
 
275
+ def kaiming_uniform_param_init_fn_(
276
+ module: nn.Module,
277
+ n_layers: int,
278
+ d_model: Optional[int] = None,
279
+ init_div_is_residual: Union[int, float, str, bool] = True,
280
+ emb_init_std: Optional[float] = None,
281
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
282
+ init_gain: float = 0,
283
+ fan_mode: str = "fan_in",
284
+ init_nonlinearity: str = "leaky_relu",
285
+ verbose: int = 0,
286
+ **kwargs,
287
+ ):
288
  del kwargs
289
  if verbose > 1:
290
+ warnings.warn(
291
+ f"Using nn.init.kaiming_uniform_ init fn with parameters: "
292
+ + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}"
293
+ )
294
+ kaiming_uniform_ = partial(
295
+ nn.init.kaiming_uniform_,
296
+ a=init_gain,
297
+ mode=fan_mode,
298
+ nonlinearity=init_nonlinearity,
299
+ )
300
+ generic_param_init_fn_(
301
+ module=module,
302
+ init_fn_=kaiming_uniform_,
303
+ d_model=d_model,
304
+ n_layers=n_layers,
305
+ init_div_is_residual=init_div_is_residual,
306
+ emb_init_std=emb_init_std,
307
+ emb_init_uniform_lim=emb_init_uniform_lim,
308
+ verbose=verbose,
309
+ )
310
 
311
+
312
+ def kaiming_normal_param_init_fn_(
313
+ module: nn.Module,
314
+ n_layers: int,
315
+ d_model: Optional[int] = None,
316
+ init_div_is_residual: Union[int, float, str, bool] = True,
317
+ emb_init_std: Optional[float] = None,
318
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
319
+ init_gain: float = 0,
320
+ fan_mode: str = "fan_in",
321
+ init_nonlinearity: str = "leaky_relu",
322
+ verbose: int = 0,
323
+ **kwargs,
324
+ ):
325
  del kwargs
326
  if verbose > 1:
327
+ warnings.warn(
328
+ f"Using nn.init.kaiming_normal_ init fn with parameters: "
329
+ + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}"
330
+ )
331
+ kaiming_normal_ = partial(
332
+ torch.nn.init.kaiming_normal_,
333
+ a=init_gain,
334
+ mode=fan_mode,
335
+ nonlinearity=init_nonlinearity,
336
+ )
337
+ generic_param_init_fn_(
338
+ module=module,
339
+ init_fn_=kaiming_normal_,
340
+ d_model=d_model,
341
+ n_layers=n_layers,
342
+ init_div_is_residual=init_div_is_residual,
343
+ emb_init_std=emb_init_std,
344
+ emb_init_uniform_lim=emb_init_uniform_lim,
345
+ verbose=verbose,
346
+ )
347
+
348
 
349
+ def xavier_uniform_param_init_fn_(
350
+ module: nn.Module,
351
+ n_layers: int,
352
+ d_model: Optional[int] = None,
353
+ init_div_is_residual: Union[int, float, str, bool] = True,
354
+ emb_init_std: Optional[float] = None,
355
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
356
+ init_gain: float = 0,
357
+ verbose: int = 0,
358
+ **kwargs,
359
+ ):
360
  del kwargs
361
  xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
362
  if verbose > 1:
363
+ warnings.warn(
364
+ f"Using torch.nn.init.xavier_uniform_ init fn with parameters: "
365
+ + f"gain={init_gain}"
366
+ )
367
+ generic_param_init_fn_(
368
+ module=module,
369
+ init_fn_=xavier_uniform_,
370
+ d_model=d_model,
371
+ n_layers=n_layers,
372
+ init_div_is_residual=init_div_is_residual,
373
+ emb_init_std=emb_init_std,
374
+ emb_init_uniform_lim=emb_init_uniform_lim,
375
+ verbose=verbose,
376
+ )
377
 
378
+
379
+ def xavier_normal_param_init_fn_(
380
+ module: nn.Module,
381
+ n_layers: int,
382
+ d_model: Optional[int] = None,
383
+ init_div_is_residual: Union[int, float, str, bool] = True,
384
+ emb_init_std: Optional[float] = None,
385
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
386
+ init_gain: float = 0,
387
+ verbose: int = 0,
388
+ **kwargs,
389
+ ):
390
  xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
391
  if verbose > 1:
392
+ warnings.warn(
393
+ f"Using torch.nn.init.xavier_normal_ init fn with parameters: "
394
+ + f"gain={init_gain}"
395
+ )
396
+ generic_param_init_fn_(
397
+ module=module,
398
+ init_fn_=xavier_normal_,
399
+ d_model=d_model,
400
+ n_layers=n_layers,
401
+ init_div_is_residual=init_div_is_residual,
402
+ emb_init_std=emb_init_std,
403
+ emb_init_uniform_lim=emb_init_uniform_lim,
404
+ verbose=verbose,
405
+ )
406
+
407
+
408
+ MODEL_INIT_REGISTRY = {
409
+ "default_": torch_default_param_init_fn_,
410
+ "baseline_": baseline_param_init_fn_,
411
+ "kaiming_uniform_": kaiming_uniform_param_init_fn_,
412
+ "kaiming_normal_": kaiming_normal_param_init_fn_,
413
+ "neox_init_": neox_param_init_fn_,
414
+ "small_init_": small_param_init_fn_,
415
+ "xavier_uniform_": xavier_uniform_param_init_fn_,
416
+ "xavier_normal_": xavier_normal_param_init_fn_,
417
+ }