pirroh madhavatreplit committed on
Commit
ce1f658
1 Parent(s): 1076fcf

Add files for release (#1)


- Add files for release (9ed598fad7885d062e25157bbd3f549afd0964cc)


Co-authored-by: Madhav <madhavatreplit@users.noreply.huggingface.co>

README.md CHANGED
@@ -1,3 +1,104 @@
  ---
  license: cc-by-sa-4.0
+ datasets:
+ - bigcode/the-stack-dedup
  ---
+
+
+ # replit-code-v1-3b
+
+ `replit-code-v1-3b` is a 2.7B-parameter causal language model for code, trained on the Stack Dedup v1.2 dataset.
+
+
+
+ ## Model
+
+
+ ```python
+ from transformers import AutoModelForCausalLM
+
+ # load model
+ model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
+ ```
+
+ To use the optimized Triton implementation of FlashAttention on GPUs with BF16 precision, move the model to `bfloat16` and use it as follows:
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM
+
+ # load model
+ model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True, attn_impl='triton')
+ model.to(device='cuda:0', dtype=torch.bfloat16)
+
+ # forward pass
+ x = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
+ x = x.to(device='cuda:0')  # token ids stay integer; only the model is cast to bfloat16
+ y = model(x)
+
+ ```
+
+ Note that `trust_remote_code=True` is passed to the `from_pretrained` method because ReplitLM is not a class in the
+ [Transformers](https://huggingface.co/docs/transformers/index) library.
+
+ ## Tokenizer
+
+ We trained a custom SentencePiece Unigram tokenizer with a 32768-token vocabulary optimized specifically for code.
+
+ Note that using this tokenizer requires the `sentencepiece` library to be installed.
+
+ The tokenizer can be used as follows:
+
+ ```python
+ from transformers import AutoTokenizer
+
+ # load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
+
+ # single input encoding + generation
+ x = tokenizer.encode('def hello():\n print("hello world")\n', return_tensors='pt')
+ y = model.generate(x)
+
+ # decoding, clean_up_tokenization_spaces=False to ensure syntactical correctness
+ generated_code = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+ print(generated_code)
+ ```
+
+ Note that:
+ - `trust_remote_code=True` is passed to the `from_pretrained` method because ReplitLM is not a class in the [Transformers](https://huggingface.co/docs/transformers/index) library.
+ - `clean_up_tokenization_spaces=False` is meant to avoid removing spaces in the output, because that would affect the syntactical correctness of the generated code.
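Reviewer note: the second bullet can be illustrated without loading the tokenizer. The sketch below uses a hypothetical helper, `naive_cleanup`, as a simplified stand-in for prose-oriented cleanup that removes the space before punctuation. It is an assumption made for illustration, not the library's code, and the exact rules applied by `transformers` may differ.

```python
def naive_cleanup(text: str) -> str:
    """Simplified stand-in for prose cleanup: drop the space before punctuation."""
    for spaced, tight in ((" .", "."), (" ?", "?"), (" !", "!"), (" ,", ",")):
        text = text.replace(spaced, tight)
    return text


# A generated line whose string literal intentionally contains " , ".
snippet = 'msg = "hello , world"'
cleaned = naive_cleanup(snippet)

# The cleanup silently rewrites the string literal, so the program changes.
print(snippet)
print(cleaned)
```

Cleanup of this kind is harmless for natural-language text, but in generated code it can alter literals or spacing that carries meaning, which is why the README turns it off.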
+
+
+ ## Generation
+
+ You can generate code using the `transformers` library as follows:
+
+ ```python
+ import transformers
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
+ model = transformers.AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
+
+ x = tokenizer.encode('def fibonacci(n): ', return_tensors='pt')
+ y = model.generate(x, max_length=100, do_sample=True, top_p=0.95, top_k=4, temperature=0.2, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
+
+ # decoding, clean_up_tokenization_spaces=False to ensure syntactical correctness
+ generated_code = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+ print(generated_code)
+ ```
+
+ Experiment with different decoding methods and parameters to get the best results for your use case.
+
+ ## Post Processing
+
+ Note that as with all code generation models, post-processing of the generated code is important. In particular, the following post-processing steps are recommended:
+ - stop generation when the EOS token is encountered
+ - remove trailing whitespace
+ - set `max_tokens` to a reasonable value based on your completion use case
+ - truncate the generation at stop sequences such as `return`, `def`, "```", "`\n\n\n`" to avoid generating incomplete code when `max_tokens` is larger than the length of the expected generated code.
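Reviewer note: the truncation and whitespace steps listed above can be sketched in a few lines of plain Python. The function name `postprocess` and the particular stop sequences in `STOP_SEQUENCES` are assumptions chosen for illustration; the README does not prescribe an implementation, and the right stop list depends on the completion use case.

```python
# Hypothetical stop sequences; adjust them to the completion use case.
STOP_SEQUENCES = ("\ndef ", "```", "\n\n\n")


def postprocess(completion: str, stop_sequences=STOP_SEQUENCES) -> str:
    """Cut the completion at the earliest stop sequence, then strip trailing whitespace."""
    cut = len(completion)
    for stop in stop_sequences:
        idx = completion.find(stop)
        if idx != -1:
            cut = min(cut, idx)
    return completion[:cut].rstrip()


raw = "    return a + b\n\n\ndef unrelated():\n    pass\n"
print(repr(postprocess(raw)))
```

Cutting at the earliest match keeps the result independent of the order in which stop sequences are listed.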
+
+ ## Inference
+ Coming soon.
+
+ ## Evaluation
+ Coming soon.
+
+ ## Model Hash
+ 5bc28ce32c6f9aec935ead7b60ea1c46
attention.py ADDED
@@ -0,0 +1,409 @@
+ # Copyright 2022 MosaicML Examples authors
+ # SPDX-License-Identifier: Apache-2.0
+
+ """Attention layers."""
+
+ import math
+ import warnings
+ from typing import Optional
+
+ import torch
+ from einops import rearrange
+ from torch import nn
+
+ from .low_precision_layernorm import LPLayerNorm
+
+
+ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
+                      original_is_causal: bool):
+     if original_is_causal and num_query_tokens != num_key_tokens:
+         if num_query_tokens != 1:
+             raise NotImplementedError(
+                 'ReplitLM does not support query and key with different number of tokens, unless number of query tokens is 1.'
+             )
+         else:
+             return False
+     return original_is_causal
+
+
+ def scaled_multihead_dot_product_attention(
+     query,
+     key,
+     value,
+     n_heads,
+     softmax_scale=None,
+     attn_bias=None,
+     key_padding_mask=None,
+     is_causal=False,
+     dropout_p=0.0,
+     training=False,
+     needs_weights=False,
+ ):
+
+     q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
+     k = rearrange(key, 'b s (h d) -> b h d s', h=n_heads)  # includes key.t()
+     v = rearrange(value, 'b s (h d) -> b h s d', h=n_heads)
+
+     min_val = torch.finfo(q.dtype).min
+
+     b, _, s_q, d = q.shape
+     s_k = k.size(-1)
+
+     if softmax_scale is None:
+         softmax_scale = 1 / math.sqrt(d)
+
+     attn_weight = q.matmul(k) * softmax_scale
+
+     if attn_bias is not None:
+         if (attn_bias.size(-1) != 1 and
+                 attn_bias.size(-1) != s_k) or (attn_bias.size(-2) != 1 and
+                                                attn_bias.size(-2) != s_q):
+             raise RuntimeError(
+                 f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.'
+             )
+         attn_weight = attn_weight + attn_bias
+
+     if key_padding_mask is not None:
+         if attn_bias is not None:
+             warnings.warn(
+                 'Propagating key_padding_mask to the attention module ' +
+                 'and applying it within the attention module can cause ' +
+                 'unnecessary computation/memory usage. Consider integrating ' +
+                 'into attn_bias once and passing that to each attention ' +
+                 'module instead.'
+             )
+         attn_weight = attn_weight.masked_fill(
+             ~key_padding_mask.view((b, 1, 1, s_k)), min_val)
+
+     if is_causal:
+         s = max(s_q, s_k)
+         causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
+         causal_mask = causal_mask.tril()
+         causal_mask = causal_mask.to(torch.bool)
+         causal_mask = ~causal_mask
+         causal_mask = causal_mask[-s_q:, -s_k:]
+         attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k),
+                                               min_val)
+
+     attn_weight = torch.softmax(attn_weight, dim=-1)
+
+     if dropout_p:
+         attn_weight = torch.nn.functional.dropout(attn_weight,
+                                                   p=dropout_p,
+                                                   training=training,
+                                                   inplace=True)
+
+     out = attn_weight.matmul(v)
+     out = rearrange(out, 'b h s d -> b s (h d)')
+
+     if needs_weights:
+         return out, attn_weight
+     return out, None
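Reviewer note: the causal branch of this function builds a square lower-triangular mask and then keeps only its last `s_q` rows and last `s_k` columns. A minimal pure-Python sketch of that slicing follows; the helper name `causal_mask` is made up for illustration and it works on nested lists rather than tensors. It shows why a single query token attending to a longer key cache ends up with nothing masked.

```python
def causal_mask(s_q: int, s_k: int):
    """Return an s_q x s_k grid where True marks a position that must be masked out."""
    s = max(s_q, s_k)
    # Square mask: query i may attend to key j only when j <= i.
    full = [[j > i for j in range(s)] for i in range(s)]
    # Keep the last s_q rows and the last s_k columns, like the slice [-s_q:, -s_k:].
    return [row[s - s_k:] for row in full[s - s_q:]]


# Full-sequence case: the strictly upper triangle is masked.
print(causal_mask(3, 3))
# Incremental decoding: one new query token sees every cached key.
print(causal_mask(1, 4))
```

Taking the bottom rows of the square mask aligns the newest query with the newest key, which is the behavior one would expect when keys and values are carried over from earlier steps.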
+
+
+ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
+     for tensor in tensors:
+         if tensor.dtype not in valid_dtypes:
+             raise TypeError(f'{tensor.dtype=} must be in {valid_dtypes=}.')
+         if not tensor.is_cuda:
+             raise TypeError(
+                 f'Inputs must be cuda tensors ({tensor.is_cuda=}).')
+
+
+ def flash_attn_fn(
+     query,
+     key,
+     value,
+     n_heads,
+     softmax_scale=None,
+     attn_bias=None,
+     key_padding_mask=None,
+     is_causal=False,
+     dropout_p=0.0,
+     training=False,
+     needs_weights=False,
+ ):
+     try:
+         from flash_attn import bert_padding, flash_attn_interface
+     except ImportError as e:
+         raise RuntimeError('Please install flash_attn==0.2.8') from e
+
+     check_valid_inputs(query, key, value)
+
+     if attn_bias is not None:
+         raise NotImplementedError('attn_bias not implemented for flash attn.')
+
+     batch_size, seqlen = query.shape[:2]
+
+     if key_padding_mask is None:
+         key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
+     query_padding_mask = key_padding_mask[:, -query.size(1):]
+
+     query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
+         query, query_padding_mask)
+     query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
+
+     key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
+         key, key_padding_mask)
+     key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
+
+     value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
+     value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
+
+     dropout_p = dropout_p if training else 0.0
+
+     reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
+
+     output_unpad = flash_attn_interface.flash_attn_unpadded_func(
+         query_unpad,
+         key_unpad,
+         value_unpad,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         max_seqlen_q,
+         max_seqlen_k,
+         dropout_p,
+         softmax_scale=softmax_scale,
+         causal=reset_is_causal,
+         return_attn_probs=needs_weights)
+
+     output = bert_padding.pad_input(
+         rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
+         seqlen)
+     return output, None
+
+
+ def triton_flash_attn_fn(
+     query,
+     key,
+     value,
+     n_heads,
+     softmax_scale=None,
+     attn_bias=None,
+     key_padding_mask=None,
+     is_causal=False,
+     dropout_p=0.0,
+     training=False,
+     needs_weights=False,
+ ):
+     try:
+         from flash_attn import flash_attn_triton  # type: ignore
+     except ImportError as e:
+         raise RuntimeError(
+             'Please install flash_attn==0.2.8 and triton==2.0.0.dev20221202.'
+         ) from e
+
+     check_valid_inputs(query, key, value)
+
+     if dropout_p:
+         raise NotImplementedError(
+             'Dropout not implemented for attn_impl: triton.')
+
+     if needs_weights:
+         raise NotImplementedError(
+             'attn_impl: triton cannot return attn weights.')
+
+     if key_padding_mask is not None:
+         warnings.warn(
+             'Propagating key_padding_mask to the attention module ' +
+             'and applying it within the attention module can cause ' +
+             'unnecessary computation/memory usage. Consider integrating ' +
+             'into attn_bias once and passing that to each attention ' +
+             'module instead.'
+         )
+         b_size, s_k = key_padding_mask.shape[:2]
+
+         if attn_bias is None:
+             attn_bias = query.new_zeros(b_size, 1, 1, s_k)
+
+         attn_bias = attn_bias.masked_fill(
+             ~key_padding_mask.view((b_size, 1, 1, s_k)),
+             torch.finfo(query.dtype).min)
+
+     query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
+     key = rearrange(key, 'b s (h d) -> b s h d', h=n_heads)
+     value = rearrange(value, 'b s (h d) -> b s h d', h=n_heads)
+
+     reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
+     attn_output = flash_attn_triton.flash_attn_func(query, key, value,
+                                                     attn_bias, reset_is_causal,
+                                                     softmax_scale)
+
+     output = attn_output.view(*attn_output.shape[:2], -1)
+
+     return output, None
+
+
+ class MultiheadAttention(nn.Module):
+     """Multi-head self attention.
+
+     Using the torch or triton attention implementation enables the user to also use
+     additive bias.
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         n_heads: int,
+         attn_impl: str = 'triton',
+         attn_clip_qkv: Optional[float] = None,
+         attn_qk_ln: bool = False,
+         softmax_scale: Optional[float] = None,
+         attn_pdrop: float = 0.0,
+         low_precision_layernorm: bool = False,
+         device: Optional[str] = None,
+     ):
+         super().__init__()
+
+         self.attn_impl = attn_impl
+         self.clip_qkv = attn_clip_qkv
+         self.attn_qk_ln = attn_qk_ln
+
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.softmax_scale = softmax_scale
+         if self.softmax_scale is None:
+             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
+         self.attn_dropout_p = attn_pdrop
+
+         self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
+         # for param init fn; enables shape based init of fused layers
+         fuse_splits = (d_model, 2 * d_model)
+         self.Wqkv._fused = (0, fuse_splits)  # type: ignore
+
+         if self.attn_qk_ln:
+             layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
+             self.q_ln = layernorm_class(self.d_model, device=device)
+             self.k_ln = layernorm_class(self.d_model, device=device)
+
+         if self.attn_impl == 'flash':
+             self.attn_fn = flash_attn_fn
+         elif self.attn_impl == 'triton':
+             self.attn_fn = triton_flash_attn_fn
+             warnings.warn(
+                 'While `attn_impl: triton` can be faster than `attn_impl: flash` ' +
+                 'it uses more memory. When training larger models this can trigger ' +
+                 'alloc retries which hurts performance. If encountered, we recommend ' +
+                 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
+         elif self.attn_impl == 'torch':
+             self.attn_fn = scaled_multihead_dot_product_attention
+             if torch.cuda.is_available():
+                 warnings.warn(
+                     'Using `attn_impl: torch`. If your model does not use `alibi` or ' +
+                     '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' +
+                     'we recommend using `attn_impl: triton`.'
+                 )
+         else:
+             raise ValueError(f'{attn_impl=} is an invalid setting.')
+
+         self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
+         self.out_proj._is_residual = True  # type: ignore
+
+     def forward(self,
+                 x,
+                 past_key_value=None,
+                 attn_bias=None,
+                 attention_mask=None,
+                 is_causal=True,
+                 needs_weights=False):
+         qkv = self.Wqkv(x)
+
+         if self.clip_qkv:
+             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
+
+         query, key, value = qkv.chunk(3, dim=2)
+
+         key_padding_mask = attention_mask
+
+         if self.attn_qk_ln:
+             # Applying layernorm to qk
+             dtype = query.dtype
+             query = self.q_ln(query).to(dtype)
+             key = self.k_ln(key).to(dtype)
+
+         if past_key_value is not None:
+             if len(past_key_value) != 0:
+                 key = torch.cat([past_key_value[0], key], dim=1)
+                 value = torch.cat([past_key_value[1], value], dim=1)
+
+             past_key_value = (key, value)
+
+         if attn_bias is not None:
+             attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
+
+         context, attn_weights = self.attn_fn(
+             query,
+             key,
+             value,
+             self.n_heads,
+             softmax_scale=self.softmax_scale,
+             attn_bias=attn_bias,
+             key_padding_mask=key_padding_mask,
+             is_causal=is_causal,
+             dropout_p=self.attn_dropout_p,
+             training=self.training,
+             needs_weights=needs_weights,
+         )
+
+         return self.out_proj(context), attn_weights, past_key_value
+
+
+ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal,
+                     use_sequence_id):
+     if attn_impl == 'flash':
+         return None
+     elif attn_impl in ['torch', 'triton']:
+         if alibi:
+             if (prefix_lm or not causal) or use_sequence_id:
+                 return (1, n_heads, seq_len, seq_len)
+             return (1, n_heads, 1, seq_len)
+         elif prefix_lm or use_sequence_id:
+             return (1, 1, seq_len, seq_len)
+         return None
+     else:
+         raise ValueError(f'{attn_impl=} is an invalid setting.')
+
+
+ def attn_bias(attn_impl,
+               attn_bias,
+               n_heads,
+               seq_len,
+               causal=False,
+               alibi=False,
+               alibi_bias_max=8):
+     if attn_impl == 'flash':
+         return None
+     elif attn_impl in ['torch', 'triton']:
+         if alibi:
+             # add the alibi bias to attn_bias (Tensor.add returns a new tensor)
+             device, dtype = attn_bias.device, attn_bias.dtype
+             attn_bias = attn_bias.add(
+                 alibi_bias(n_heads,
+                            seq_len,
+                            full=not causal,
+                            alibi_bias_max=alibi_bias_max,
+                            device=device,
+                            dtype=dtype))
+         return attn_bias
+     else:
+         raise ValueError(f'{attn_impl=} is an invalid setting.')
+
+
+ def alibi_bias(n_heads,
+                seq_len,
+                full=False,
+                alibi_bias_max=8,
+                device=None,
+                dtype=None):
+     alibi_bias = torch.arange(1 - seq_len, 1, dtype=dtype,
+                               device=device).view(1, 1, 1, seq_len)
+     if full:
+         # generate 1 x Heads x SeqLen x SeqLen alibi bias mask
+         # otherwise the mask is 1 x Heads x 1 x SeqLen (which is broadcast to the appropriate size)
+         alibi_bias = alibi_bias - torch.arange(
+             1 - seq_len, 1, dtype=dtype, device=device).view(1, 1, seq_len, 1)
+         alibi_bias = alibi_bias.abs().mul(-1)
+
+     m = torch.arange(1, n_heads + 1, dtype=dtype, device=device)
+     m = m.mul(alibi_bias_max / n_heads)
+     alibi_bias = alibi_bias * (1. / (2**m.view(1, n_heads, 1, 1)))
+     return alibi_bias
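Reviewer note: in the non-`full` case, `alibi_bias` multiplies the relative positions `1 - seq_len, ..., 0` by a per-head slope of `2 ** -(h * alibi_bias_max / n_heads)` for `h = 1..n_heads`. The pure-Python sketch below reproduces that arithmetic with lists instead of tensors so the numbers can be inspected by hand. The helper names `alibi_slopes` and `alibi_rows` are made up for illustration and are not part of the commit.

```python
def alibi_slopes(n_heads: int, alibi_bias_max: int = 8):
    """Per-head slopes: 2 ** -(h * alibi_bias_max / n_heads) for h = 1..n_heads."""
    return [2.0 ** -(h * alibi_bias_max / n_heads) for h in range(1, n_heads + 1)]


def alibi_rows(n_heads: int, seq_len: int, alibi_bias_max: int = 8):
    """One bias row per head over key positions (the causal, non-full layout)."""
    distances = range(1 - seq_len, 1)  # -(seq_len - 1), ..., -1, 0
    return [[slope * d for d in distances]
            for slope in alibi_slopes(n_heads, alibi_bias_max)]


print(alibi_slopes(8))      # slopes halve from head to head when n_heads == 8
print(alibi_rows(8, 4)[0])  # first head: the most recent key gets bias 0
```

Keys further in the past receive a more negative bias, and heads with larger index use smaller slopes, so they penalize distance less.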
config.json ADDED
@@ -0,0 +1,46 @@
+ {
+   "_name_or_path": "replit/replit-code-v1-3b",
+   "alibi": true,
+   "alibi_bias_max": 8,
+   "architectures": [
+     "ReplitLM"
+   ],
+   "attn_clip_qkv": null,
+   "attn_impl": "torch",
+   "attn_pdrop": 0,
+   "attn_qk_ln": false,
+   "attn_uses_sequence_id": false,
+   "auto_map": {
+     "AutoConfig": "configuration_replit_lm.ReplitLMConfig",
+     "AutoModelForCausalLM": "replit_lm.ReplitLM"
+   },
+   "d_model": 2560,
+   "emb_init_std": null,
+   "emb_init_uniform_lim": null,
+   "emb_pdrop": 0,
+   "embedding_fraction": 1.0,
+   "fan_mode": "fan_in",
+   "init_device": "cpu",
+   "init_div_is_residual": true,
+   "init_gain": 0,
+   "init_nonlinearity": "relu",
+   "init_std": 0.02,
+   "logit_scale": null,
+   "low_precision_layernorm": true,
+   "max_seq_len": 2048,
+   "mlp_ratio": 4,
+   "model_type": "replit_lm",
+   "n_heads": 32,
+   "n_layers": 32,
+   "no_bias": true,
+   "param_init_fn": "kaiming_normal_",
+   "prefix_lm": false,
+   "resid_pdrop": 0,
+   "softmax_scale": null,
+   "tokenizer_name": "replit/replit-code-v1-3b",
+   "torch_dtype": "float32",
+   "transformers_version": "4.26.1",
+   "use_cache": false,
+   "verbose": 0,
+   "vocab_size": 32768
+ }
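Reviewer note: these values allow a back-of-the-envelope weight count. The sketch below counts only the matrix weights whose shapes are visible in this commit (`Wqkv`, `out_proj`, `mlp_up`, `mlp_down`) plus a token embedding table. It assumes that biases are absent because `no_bias` is true, and it ignores layer-norm parameters. Whether the output projection shares the embedding matrix is not visible in this diff, so the two numbers are reported separately. The helper name `weight_counts` is made up for illustration.

```python
# Values copied from config.json in this commit.
config = {
    "d_model": 2560,
    "n_heads": 32,
    "n_layers": 32,
    "mlp_ratio": 4,
    "vocab_size": 32768,
}


def weight_counts(cfg):
    """Count matrix weights only: no biases, no layer-norm parameters."""
    d = cfg["d_model"]
    attn = 3 * d * d + d * d             # Wqkv plus out_proj
    mlp = 2 * cfg["mlp_ratio"] * d * d   # mlp_up plus mlp_down
    blocks = cfg["n_layers"] * (attn + mlp)
    embedding = cfg["vocab_size"] * d    # token embedding table
    return blocks, embedding


blocks, embedding = weight_counts(config)
head_dim = config["d_model"] // config["n_heads"]
print(f"blocks={blocks:,} embedding={embedding:,} head_dim={head_dim}")
```

Under these assumptions the transformer blocks hold about 2.5 billion weights and the embedding table about 84 million, which is in the same range as the 2.7B figure quoted in the README.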
configuration_replit_lm.py ADDED
@@ -0,0 +1,168 @@
+ # Copyright 2022 MosaicML Examples authors
+ # SPDX-License-Identifier: Apache-2.0
+
+ """Forked for ReplitLM"""
+
+ """A HuggingFace-style model configuration."""
+
+
+ from typing import Optional, Tuple, Union
+ from transformers import PretrainedConfig
+
+
+ class ReplitLMConfig(PretrainedConfig):
+     model_type = 'replit_lm'
+
+     def __init__(
+         self,
+         d_model: int = 2048,
+         n_heads: int = 16,
+         n_layers: int = 24,
+         mlp_ratio: int = 4,
+         max_seq_len: int = 2048,
+         vocab_size: int = 50368,
+         attn_pdrop: float = 0.0,
+         resid_pdrop: float = 0.0,
+         emb_pdrop: float = 0.0,
+         attn_impl: str = 'triton',
+         attn_qk_ln: bool = False,
+         attn_clip_qkv: Optional[float] = None,
+         softmax_scale: Optional[float] = None,
+         prefix_lm: Optional[bool] = False,
+         attn_uses_sequence_id: Optional[bool] = False,
+         alibi: bool = False,
+         alibi_bias_max: int = 8,
+         init_device: str = 'cpu',
+         logit_scale: Optional[Union[float, str]] = None,
+         no_bias: bool = False,
+         verbose: int = 0,
+         param_init_fn: str = 'kaiming_normal_',
+         init_div_is_residual: Union[int, float, str, bool] = True,
+         init_std: float = 0.02,
+         emb_init_std: Optional[float] = None,
+         emb_init_uniform_lim: Optional[Union[Tuple[float, float],
+                                              float]] = None,
+         init_gain: float = 0,
+         fan_mode: str = 'fan_in',
+         init_nonlinearity: str = 'relu',
+         embedding_fraction: float = 1.0,
+         low_precision_layernorm: bool = True,
+         use_cache: bool = False,
+         **kwargs,
+     ):
+         """The ReplitLM configuration class.
+
+         Args:
+             d_model (int): The size of the embedding dimension of the model.
+             n_heads (int): The number of attention heads.
+             n_layers (int): The number of layers in the model.
+             mlp_ratio (int): The ratio of the up/down scale in the MLP.
+             max_seq_len (int): The maximum sequence length of the model.
+             vocab_size (int): The size of the vocabulary.
+             attn_pdrop (float): The dropout probability for the attention layers.
+             resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
+             emb_pdrop (float): The dropout probability for the embedding layer.
+             attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
+             attn_qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
+             attn_clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
+                 this value.
+             softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
+                 use the default scale of ``1/sqrt(d_keys)``.
+             prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
+                 extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
+                 can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
+             attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
+                 When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
+                 which sub-sequence each token belongs to.
+                 Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
+             alibi (bool): Whether to use the alibi bias instead of position embeddings.
+             alibi_bias_max (int): The maximum value of the alibi bias.
+             init_device (str): The device to use for parameter initialization.
+             logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
+             no_bias (bool): Whether to remove the bias term from all layers.
+             verbose (int): The verbosity level. 0 is silent.
+             param_init_fn (str): The parameter initialization scheme to use. One of 'default_', 'baseline_', 'kaiming_uniform_',
+                 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or 'xavier_normal_'.
+             init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
+             init_std (float): The standard deviation of the normal distribution used to initialize the model,
+                 if using the baseline_ parameter initialization scheme.
+             emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
+             emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
+                 used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
+             init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
+             fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
+             init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
+             embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
+             low_precision_layernorm (bool): Whether to use low precision layer normalization.
+             use_cache (bool): Whether or not the model should return the last key/values attentions.
+         """
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.mlp_ratio = mlp_ratio
+         self.max_seq_len = max_seq_len
+         self.vocab_size = vocab_size
+         self.attn_pdrop = attn_pdrop
+         self.resid_pdrop = resid_pdrop
+         self.emb_pdrop = emb_pdrop
+         self.attn_impl = attn_impl
+         self.attn_qk_ln = attn_qk_ln
+         self.attn_clip_qkv = attn_clip_qkv
+         self.softmax_scale = softmax_scale
+         self.prefix_lm = prefix_lm
+         self.attn_uses_sequence_id = attn_uses_sequence_id
+         self.alibi = alibi
+         self.alibi_bias_max = alibi_bias_max
+         self.init_device = init_device
+         self.logit_scale = logit_scale
+         self.no_bias = no_bias
+         self.verbose = verbose
+         self.param_init_fn = param_init_fn
+         self.init_div_is_residual = init_div_is_residual
+         self.init_std = init_std
+         self.emb_init_std = emb_init_std
+         self.emb_init_uniform_lim = emb_init_uniform_lim
+         self.init_gain = init_gain
+         self.fan_mode = fan_mode
+         self.init_nonlinearity = init_nonlinearity
+         self.embedding_fraction = embedding_fraction
+         self.low_precision_layernorm = low_precision_layernorm
+         self.use_cache = use_cache
+         if 'name' in kwargs:
+             del kwargs['name']
+         if 'loss_fn' in kwargs:
+             del kwargs['loss_fn']
+         super().__init__(**kwargs)
+
+         self._validate_config()
+
+     def _validate_config(self):
+         if self.d_model % self.n_heads != 0:
+             raise ValueError('d_model must be divisible by n_heads')
+         if any(prob < 0 or prob > 1
+                for prob in [self.attn_pdrop, self.resid_pdrop, self.emb_pdrop]):
+             raise ValueError(
+                 'attn_pdrop, resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1'
+             )
+         if self.attn_impl not in ['torch', 'flash', 'triton']:
+             raise ValueError(f'Unknown attn_impl={self.attn_impl}')
+         if self.prefix_lm and self.attn_impl not in ['torch', 'triton']:
+             raise NotImplementedError(
+                 'prefix_lm only implemented with torch and triton attention.')
+         if self.alibi and self.attn_impl not in ['torch', 'triton']:
+             raise NotImplementedError(
+                 'alibi only implemented with torch and triton attention.')
+         if self.attn_uses_sequence_id and self.attn_impl not in [
+                 'torch', 'triton'
+         ]:
+             raise NotImplementedError(
+                 'attn_uses_sequence_id only implemented with torch and triton attention.'
+             )
+         if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
+             raise ValueError(
+                 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'
+             )
+         if isinstance(self.logit_scale,
+                       str) and self.logit_scale != 'inv_sqrt_d_model':
+             raise ValueError(
+                 f"{self.logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
+             )
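Reviewer note: a few of the checks in `_validate_config` can be exercised without `transformers` installed. The sketch below re-states three of them as a standalone function. The name `validate` and its return value (the per-head dimension) are assumptions made for illustration and are not part of `ReplitLMConfig`.

```python
def validate(d_model: int, n_heads: int, attn_impl: str,
             alibi: bool = False, prefix_lm: bool = False) -> int:
    """Mirror a subset of the config checks and return the per-head dimension."""
    if d_model % n_heads != 0:
        raise ValueError('d_model must be divisible by n_heads')
    if attn_impl not in ('torch', 'flash', 'triton'):
        raise ValueError(f'Unknown attn_impl={attn_impl}')
    if (alibi or prefix_lm) and attn_impl == 'flash':
        raise NotImplementedError(
            'alibi and prefix_lm are only implemented with torch and triton attention.')
    return d_model // n_heads


# The released config (d_model=2560, n_heads=32, alibi=true, attn_impl="torch") passes.
print(validate(2560, 32, 'torch', alibi=True))

# Switching the same config to flash attention is rejected because of alibi.
try:
    validate(2560, 32, 'flash', alibi=True)
    rejected = False
except NotImplementedError:
    rejected = True
print(rejected)
```

This matches the README's choice of `attn_impl='triton'` rather than `'flash'` for the accelerated path, since the released model uses alibi.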
generation_config.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "_from_model_config": true,
+   "transformers_version": "4.26.1",
+   "use_cache": false
+ }
gpt_blocks.py ADDED
@@ -0,0 +1,90 @@
+ # Copyright 2022 MosaicML Examples authors
+ # SPDX-License-Identifier: Apache-2.0
+
+ """GPT Blocks used for the GPT Model."""
+
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from .attention import MultiheadAttention
+ from .low_precision_layernorm import LPLayerNorm
+
+
+ class GPTMLP(nn.Module):
+
+     def __init__(self,
+                  d_model: int,
+                  mlp_ratio: int,
+                  device: Optional[str] = None):
+         super().__init__()
+         self.mlp_up = nn.Linear(d_model, mlp_ratio * d_model, device=device)
+         self.mlp_act = nn.GELU(approximate='none')
+         self.mlp_down = nn.Linear(mlp_ratio * d_model, d_model, device=device)
+         self.mlp_down._is_residual = True  # type: ignore
+
+     def forward(self, x):
+         return self.mlp_down(self.mlp_act(self.mlp_up(x)))
+
+
+ class GPTBlock(nn.Module):
+
+     def __init__(self,
+                  attn_impl: str,
+                  d_model: int,
+                  n_heads: int,
+                  mlp_ratio: int,
+                  attn_clip_qkv: Optional[float] = None,
+                  attn_qk_ln: bool = False,
+                  softmax_scale: Optional[float] = None,
+                  attn_pdrop: float = 0.0,
+                  alibi: bool = False,
+                  resid_pdrop: float = 0.0,
+                  low_precision_layernorm: bool = False,
+                  device: Optional[str] = None,
+                  **kwargs):
+         del kwargs  # unused, just to capture any extra args from the config
+         super().__init__()
+
+         layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
+
+         self.ln_1 = layernorm_class(d_model, device=device)
+         self.attn = MultiheadAttention(
+             attn_impl=attn_impl,
+             attn_clip_qkv=attn_clip_qkv,
+             attn_qk_ln=attn_qk_ln,
+             softmax_scale=softmax_scale,
+             attn_pdrop=attn_pdrop,
+             d_model=d_model,
+             n_heads=n_heads,
+             device=device,
+         )
+         self.ln_2 = layernorm_class(d_model, device=device)
+         self.mlp = GPTMLP(
+             d_model=d_model,
+             mlp_ratio=mlp_ratio,
+             device=device,
+         )
+         self.resid_attn_dropout = nn.Dropout(resid_pdrop)
+         self.resid_mlp_dropout = nn.Dropout(resid_pdrop)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         attn_bias: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.ByteTensor] = None,
+         is_causal: bool = True,
+     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
+         a = self.ln_1(x)
+         b, _, past_key_value = self.attn(a,
+                                          past_key_value=past_key_value,
+                                          attn_bias=attn_bias,
+                                          attention_mask=attention_mask,
+                                          is_causal=is_causal)
+         x = x + self.resid_attn_dropout(b)
+         m = self.ln_2(x)
+         n = self.mlp(m)
+         x = x + self.resid_mlp_dropout(n)
+         return x, past_key_value
low_precision_layernorm.py ADDED
@@ -0,0 +1,35 @@
+ import torch
+ import torch.nn.functional as F
+
+
+ class LPLayerNorm(torch.nn.LayerNorm):
+     def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
+         super().__init__(
+             normalized_shape=normalized_shape,
+             eps=eps,
+             elementwise_affine=elementwise_affine,
+             device=device,
+             dtype=dtype,
+         )
+
+     def forward(self, x):
+         module_device = x.device
+         downcast_x = _cast_if_autocast_enabled(x)
+         downcast_weight = _cast_if_autocast_enabled(
+             self.weight) if self.weight is not None else self.weight
+         downcast_bias = _cast_if_autocast_enabled(
+             self.bias) if self.bias is not None else self.bias
+         with torch.autocast(enabled=False, device_type=module_device.type):
+             return F.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
+
+
+ def _cast_if_autocast_enabled(tensor):
+     if torch.is_autocast_enabled():
+         if tensor.device.type == 'cuda':
+             dtype = torch.get_autocast_gpu_dtype()
+         elif tensor.device.type == 'cpu':
+             dtype = torch.get_autocast_cpu_dtype()
+         else:
+             raise NotImplementedError()
+         return tensor.to(dtype=dtype)
+     return tensor
param_init_fns.py ADDED
@@ -0,0 +1,464 @@
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ import math
4
+ import warnings
5
+ from collections.abc import Sequence
6
+ from functools import partial
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+
13
+ def torch_default_param_init_fn_(
14
+ module: nn.Module,
15
+ verbose: int = 0,
16
+ **kwargs,
17
+ ):
18
+ del kwargs # unused, just to capture any extra args from the config
19
+ if verbose > 1:
20
+ warnings.warn(
21
+ f"Initializing network using module's reset_parameters attribute")
22
+
23
+ if hasattr(module, 'reset_parameters'):
24
+ module.reset_parameters() # type: ignore
25
+
26
+
27
+ def fused_init_helper_(module: nn.Module, init_fn_):
28
+ # Parameter initialization is often based on the parameter's shape.
29
+ # If a layer is fused, initialization should be based on the shapes
30
+ # of the original tensor instead of the shape of the fused tensor.
31
+ # Layers which are fused should have the _fused attribute defined.
32
+ # The first element of _fused is the dimension along which the tensor is fused.
33
+ # This is followed by an iterable of split indices.
34
+
35
+ _fused = getattr(module, '_fused', None)
36
+
37
+ if _fused is None:
38
+ raise RuntimeError('Internal logic error: module has no _fused attribute.')
39
+
40
+ dim, splits = _fused
41
+ splits = (0, *splits, module.weight.size(dim)) # type: ignore
42
+ for s, e in zip(splits[:-1], splits[1:]):
43
+ slice_indices = [slice(None)] * module.weight.ndim # type: ignore
44
+ slice_indices[dim] = slice(s, e)
45
+ init_fn_(module.weight[tuple(slice_indices)]) # type: ignore
46
+
47
+
48
+ def generic_param_init_fn_(
49
+ module: nn.Module,
50
+ init_fn_,
51
+ n_layers: int,
52
+ d_model: Optional[int] = None,
53
+ init_div_is_residual: Union[int, float, str, bool] = True,
54
+ emb_init_std: Optional[float] = None,
55
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
56
+ verbose: int = 0,
57
+ **kwargs,
58
+ ):
59
+ del kwargs # unused, just to capture any extra args from the config
60
+ if verbose > 1:
61
+ warnings.warn(
62
+ f'If model has bias parameters they are initialized to 0.')
63
+
64
+ # enable user to divide _is_residual weights by
65
+ # a value which defaults to math.sqrt(2 * cfg.n_layers)
66
+ init_div_is_residual = init_div_is_residual
67
+
68
+ if init_div_is_residual is False:
69
+ # not used, for pyright
70
+ div_is_residual = 1.0
71
+ elif init_div_is_residual is True:
72
+ div_is_residual = math.sqrt(2 * n_layers)
73
+ elif isinstance(init_div_is_residual, float) or isinstance(
74
+ init_div_is_residual, int):
75
+ div_is_residual = init_div_is_residual
76
+ elif isinstance(init_div_is_residual,
77
+ str) and init_div_is_residual.isnumeric():
78
+ # do not trust YAML parsing to always convert numbers to numbers
79
+ div_is_residual = float(init_div_is_residual)
80
+ else:
81
+ # not used, for pyright
82
+ div_is_residual = 1.0
83
+ raise ValueError(
84
+ f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}'
85
+ )
86
+
87
+ if init_div_is_residual is not False:
88
+ if verbose > 1:
89
+ warnings.warn(
90
+ f'Initializing _is_residual layers then dividing them by {div_is_residual}. ' +
91
+ f'Set `init_div_is_residual: false` in the model config to disable this.'
92
+ )
93
+
94
+ if isinstance(module, nn.Linear):
95
+ # Linear
96
+ if hasattr(module, '_fused'):
97
+ fused_init_helper_(module, init_fn_)
98
+ else:
99
+ init_fn_(module.weight)
100
+ if module.bias is not None:
101
+ torch.nn.init.zeros_(module.bias)
102
+
103
+ if init_div_is_residual is not False and getattr(
104
+ module, '_is_residual', False):
105
+ with torch.no_grad():
106
+ module.weight.div_(div_is_residual)
107
+
108
+ elif isinstance(module, nn.Embedding):
109
+ # Embedding
110
+ if emb_init_std is not None:
111
+ std = emb_init_std
112
+ if std == 0:
113
+ warnings.warn(f'Embedding layer initialized to 0.')
114
+ emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
115
+ if verbose > 1:
116
+ warnings.warn(
117
+ f'Embedding layer initialized using normal distribution with mean=0 and {std=}.'
118
+ )
119
+ elif emb_init_uniform_lim is not None:
120
+ lim = emb_init_uniform_lim
121
+ if isinstance(lim, Sequence):
122
+ if len(lim) != 2:
123
+ raise ValueError(
124
+ f'Uniform init requires a min and a max limit. User input: {lim}.'
125
+ )
126
+ if lim[0] == lim[1]:
127
+ warnings.warn(f'Embedding layer initialized to {lim[0]}.')
128
+ else:
129
+ if lim == 0:
130
+ warnings.warn(f'Embedding layer initialized to 0.')
131
+ lim = [-lim, lim]
132
+ a, b = lim
133
+ emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
134
+ if verbose > 1:
135
+ warnings.warn(
136
+ f'Embedding layer initialized using uniform distribution in range {lim}.'
137
+ )
138
+ else:
139
+ emb_init_fn_ = init_fn_
140
+
141
+ emb_init_fn_(module.weight)
142
+
143
+ elif isinstance(module, nn.LayerNorm):
144
+ # LayerNorm
145
+ if verbose > 1:
146
+ warnings.warn(
147
+ f'LayerNorm gamma weights are set to 1. If the layer has a bias it is initialized to 0.'
148
+ )
149
+ torch.nn.init.ones_(module.weight)
150
+ if module.bias is not None:
151
+ torch.nn.init.zeros_(module.bias)
152
+
153
+ elif isinstance(module, nn.MultiheadAttention):
154
+ # torch's MultiheadAttention
155
+ if module._qkv_same_embed_dim:
156
+ assert module.in_proj_weight is not None
157
+ assert module.q_proj_weight is None and module.k_proj_weight is None and module.v_proj_weight is None
158
+ assert d_model is not None
159
+ # in_proj_weight is actually 3 layers and should be split up for width based init
160
+ _d = d_model
161
+ splits = (0, _d, 2 * _d, 3 * _d)
162
+ for s, e in zip(splits[:-1], splits[1:]):
163
+ init_fn_(module.in_proj_weight[s:e])
164
+ else:
165
+ assert module.q_proj_weight is not None and module.k_proj_weight is not None and module.v_proj_weight is not None
166
+ assert module.in_proj_weight is None
167
+ init_fn_(module.q_proj_weight)
168
+ init_fn_(module.k_proj_weight)
169
+ init_fn_(module.v_proj_weight)
170
+
171
+ # bias
172
+ if module.in_proj_bias is not None:
173
+ torch.nn.init.zeros_(module.in_proj_bias)
174
+ if module.bias_k is not None:
175
+ torch.nn.init.zeros_(module.bias_k)
176
+ if module.bias_v is not None:
177
+ torch.nn.init.zeros_(module.bias_v)
178
+
179
+ # out proj
180
+ init_fn_(module.out_proj.weight)
181
+ if init_div_is_residual is not False and getattr(
182
+ module.out_proj, '_is_residual', False):
183
+ with torch.no_grad():
184
+ module.out_proj.weight.div_(div_is_residual)
185
+ if module.out_proj.bias is not None:
186
+ torch.nn.init.zeros_(module.out_proj.bias)
187
+
188
+ else:
189
+ for _ in module.parameters(recurse=False):
190
+ # raise error if uninitialized module has any parameters
191
+ raise NotImplementedError(
192
+ f'{module.__class__.__name__} parameters are not initialized by param_init_fn.'
193
+ )
194
+
195
+
196
+ def _normal_init_(std, mean=0.0):
197
+ return partial(torch.nn.init.normal_, mean=mean, std=std)
198
+
199
+
200
+ def _normal_param_init_fn_(
201
+ module: nn.Module,
202
+ std: float,
203
+ n_layers: int,
204
+ d_model: Optional[int] = None,
205
+ init_div_is_residual: Union[int, float, str, bool] = True,
206
+ emb_init_std: Optional[float] = None,
207
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
208
+ verbose: int = 0,
209
+ **kwargs,
210
+ ):
211
+ del kwargs # unused, just to capture any extra args from the config
212
+ init_fn_ = _normal_init_(std=std)
213
+
214
+ if verbose > 1:
215
+ warnings.warn(
216
+ f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
217
+
218
+ generic_param_init_fn_(
219
+ module=module,
220
+ init_fn_=init_fn_,
221
+ d_model=d_model,
222
+ n_layers=n_layers,
223
+ init_div_is_residual=init_div_is_residual,
224
+ emb_init_std=emb_init_std,
225
+ emb_init_uniform_lim=emb_init_uniform_lim,
226
+ verbose=verbose,
227
+ )
228
+
229
+
230
+ def baseline_param_init_fn_(
231
+ module: nn.Module,
232
+ init_std: float,
233
+ n_layers: int,
234
+ d_model: Optional[int] = None,
235
+ init_div_is_residual: Union[int, float, str, bool] = True,
236
+ emb_init_std: Optional[float] = None,
237
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
238
+ verbose: int = 0,
239
+ **kwargs,
240
+ ):
241
+ del kwargs # unused, just to capture any extra args from the config
242
+ if init_std is None:
243
+ raise ValueError(
244
+ 'You must set model.init_std to a float value to use the default initialization scheme.'
245
+ )
246
+ _normal_param_init_fn_(
247
+ module=module,
248
+ std=init_std,
249
+ d_model=d_model,
250
+ n_layers=n_layers,
251
+ init_div_is_residual=init_div_is_residual,
252
+ emb_init_std=emb_init_std,
253
+ emb_init_uniform_lim=emb_init_uniform_lim,
254
+ verbose=verbose,
255
+ )
256
+
257
+
258
+ def small_param_init_fn_(
259
+ module: nn.Module,
260
+ n_layers: int,
261
+ d_model: int,
262
+ init_div_is_residual: Union[int, float, str, bool] = True,
263
+ emb_init_std: Optional[float] = None,
264
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
265
+ verbose: int = 0,
266
+ **kwargs,
267
+ ):
268
+ del kwargs # unused, just to capture any extra args from the config
269
+ # very close to kaiming normal
270
+ # from Transformers without Tears (2019) - Nguyen & Salazar
271
+ std = math.sqrt(2 / (5 * d_model))
272
+ _normal_param_init_fn_(
273
+ module=module,
274
+ std=std,
275
+ d_model=d_model,
276
+ n_layers=n_layers,
277
+ init_div_is_residual=init_div_is_residual,
278
+ emb_init_std=emb_init_std,
279
+ emb_init_uniform_lim=emb_init_uniform_lim,
280
+ verbose=verbose,
281
+ )
282
+
283
+
284
+ def neox_param_init_fn_(
285
+ module: nn.Module,
286
+ n_layers: int,
287
+ d_model: int,
288
+ emb_init_std: Optional[float] = None,
289
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
290
+ verbose: int = 0,
291
+ **kwargs,
292
+ ):
293
+ """From section 2.3.1 of GPT-NeoX-20B:
294
+
295
+ An Open-Source Autoregressive Language Model, Black et al. (2022)
296
+ see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
297
+ and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
298
+ """
299
+ del kwargs # unused, just to capture any extra args from the config
300
+ residual_div = n_layers / math.sqrt(10) # small std / wang std
301
+
302
+ if verbose > 1:
303
+ warnings.warn(f'setting init_div_is_residual to {residual_div}')
304
+
305
+ small_param_init_fn_(
306
+ module=module,
307
+ d_model=d_model,
308
+ n_layers=n_layers,
309
+ init_div_is_residual=residual_div,
310
+ emb_init_std=emb_init_std,
311
+ emb_init_uniform_lim=emb_init_uniform_lim,
312
+ verbose=verbose,
313
+ )
314
+
315
+
316
+ def kaiming_uniform_param_init_fn_(
317
+ module: nn.Module,
318
+ n_layers: int,
319
+ d_model: Optional[int] = None,
320
+ init_div_is_residual: Union[int, float, str, bool] = True,
321
+ emb_init_std: Optional[float] = None,
322
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
323
+ init_gain: float = 0,
324
+ fan_mode: str = 'fan_in',
325
+ init_nonlinearity: str = 'leaky_relu',
326
+ verbose: int = 0,
327
+ **kwargs,
328
+ ):
329
+ del kwargs # unused, just to capture any extra args from the config
330
+
331
+ if verbose > 1:
332
+ warnings.warn(
333
+ f'Using nn.init.kaiming_uniform_ init fn with parameters: ' +
334
+ f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
335
+ )
336
+
337
+ kaiming_uniform_ = partial(nn.init.kaiming_uniform_,
338
+ a=init_gain,
339
+ mode=fan_mode,
340
+ nonlinearity=init_nonlinearity)
341
+
342
+ generic_param_init_fn_(
343
+ module=module,
344
+ init_fn_=kaiming_uniform_,
345
+ d_model=d_model,
346
+ n_layers=n_layers,
347
+ init_div_is_residual=init_div_is_residual,
348
+ emb_init_std=emb_init_std,
349
+ emb_init_uniform_lim=emb_init_uniform_lim,
350
+ verbose=verbose,
351
+ )
352
+
353
+
354
+ def kaiming_normal_param_init_fn_(
355
+ module: nn.Module,
356
+ n_layers: int,
357
+ d_model: Optional[int] = None,
358
+ init_div_is_residual: Union[int, float, str, bool] = True,
359
+ emb_init_std: Optional[float] = None,
360
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
361
+ init_gain: float = 0,
362
+ fan_mode: str = 'fan_in',
363
+ init_nonlinearity: str = 'leaky_relu',
364
+ verbose: int = 0,
365
+ **kwargs,
366
+ ):
367
+ del kwargs # unused, just to capture any extra args from the config
368
+
369
+ if verbose > 1:
370
+ warnings.warn(
371
+ f'Using nn.init.kaiming_normal_ init fn with parameters: ' +
372
+ f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
373
+ )
374
+
375
+ kaiming_normal_ = partial(torch.nn.init.kaiming_normal_,
376
+ a=init_gain,
377
+ mode=fan_mode,
378
+ nonlinearity=init_nonlinearity)
379
+
380
+ generic_param_init_fn_(
381
+ module=module,
382
+ init_fn_=kaiming_normal_,
383
+ d_model=d_model,
384
+ n_layers=n_layers,
385
+ init_div_is_residual=init_div_is_residual,
386
+ emb_init_std=emb_init_std,
387
+ emb_init_uniform_lim=emb_init_uniform_lim,
388
+ verbose=verbose,
389
+ )
390
+
391
+
392
+ def xavier_uniform_param_init_fn_(
393
+ module: nn.Module,
394
+ n_layers: int,
395
+ d_model: Optional[int] = None,
396
+ init_div_is_residual: Union[int, float, str, bool] = True,
397
+ emb_init_std: Optional[float] = None,
398
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
399
+ init_gain: float = 0,
400
+ verbose: int = 0,
401
+ **kwargs,
402
+ ):
403
+ del kwargs # unused, just to capture any extra args from the config
404
+ xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
405
+
406
+ if verbose > 1:
407
+ warnings.warn(
408
+ f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' +
409
+ f'gain={init_gain}'
410
+ )
411
+
412
+ generic_param_init_fn_(
413
+ module=module,
414
+ init_fn_=xavier_uniform_,
415
+ d_model=d_model,
416
+ n_layers=n_layers,
417
+ init_div_is_residual=init_div_is_residual,
418
+ emb_init_std=emb_init_std,
419
+ emb_init_uniform_lim=emb_init_uniform_lim,
420
+ verbose=verbose,
421
+ )
422
+
423
+
424
+ def xavier_normal_param_init_fn_(
425
+ module: nn.Module,
426
+ n_layers: int,
427
+ d_model: Optional[int] = None,
428
+ init_div_is_residual: Union[int, float, str, bool] = True,
429
+ emb_init_std: Optional[float] = None,
430
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
431
+ init_gain: float = 0,
432
+ verbose: int = 0,
433
+ **kwargs,
434
+ ):
435
+ xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
436
+
437
+ if verbose > 1:
438
+ warnings.warn(
439
+ f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' +
440
+ f'gain={init_gain}'
441
+ )
442
+
443
+ generic_param_init_fn_(
444
+ module=module,
445
+ init_fn_=xavier_normal_,
446
+ d_model=d_model,
447
+ n_layers=n_layers,
448
+ init_div_is_residual=init_div_is_residual,
449
+ emb_init_std=emb_init_std,
450
+ emb_init_uniform_lim=emb_init_uniform_lim,
451
+ verbose=verbose,
452
+ )
453
+
454
+
455
+ MODEL_INIT_REGISTRY = {
456
+ 'default_': torch_default_param_init_fn_,
457
+ 'baseline_': baseline_param_init_fn_,
458
+ 'kaiming_uniform_': kaiming_uniform_param_init_fn_,
459
+ 'kaiming_normal_': kaiming_normal_param_init_fn_,
460
+ 'neox_init_': neox_param_init_fn_,
461
+ 'small_init_': small_param_init_fn_,
462
+ 'xavier_uniform_': xavier_uniform_param_init_fn_,
463
+ 'xavier_normal_': xavier_normal_param_init_fn_,
464
+ }
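The scale factors used by `small_param_init_fn_`, `generic_param_init_fn_`, and `neox_param_init_fn_` above can be checked with plain arithmetic. The sketch below reproduces only the formulas that appear in the file. The helper names and the example values `d_model = 2560` and `n_layers = 32` are illustrative assumptions, not values taken from this commit.

```python
import math


def small_init_std(d_model: int) -> float:
    """Std used by small_param_init_fn_: sqrt(2 / (5 * d_model))."""
    return math.sqrt(2 / (5 * d_model))


def default_residual_div(n_layers: int) -> float:
    """Divisor applied to _is_residual weights when init_div_is_residual is True."""
    return math.sqrt(2 * n_layers)


def neox_residual_div(n_layers: int) -> float:
    """Divisor that neox_param_init_fn_ passes as init_div_is_residual."""
    return n_layers / math.sqrt(10)


def neox_residual_std(n_layers: int, d_model: int) -> float:
    """Effective std of residual projections under the NeoX scheme."""
    return small_init_std(d_model) / neox_residual_div(n_layers)


# Illustrative sizes only (assumed for this example).
d_model, n_layers = 2560, 32

print("small init std:       ", small_init_std(d_model))
print("default residual div: ", default_residual_div(n_layers))
print("neox residual div:    ", neox_residual_div(n_layers))
print("neox residual std:    ", neox_residual_std(n_layers, d_model))
print("2 / (L * sqrt(d)):    ", 2 / (n_layers * math.sqrt(d_model)))
```

Dividing the small-init std by `n_layers / sqrt(10)` simplifies algebraically to `2 / (n_layers * sqrt(d_model))`, which is what the `# small std / wang std` comment in `neox_param_init_fn_` refers to. The last two printed values should therefore agree up to floating-point rounding.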
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6516d02ef00bc903aad7d05dc35607cff7e4c7335d4f1bf424cdcb6695cd3e86
3
+ size 10402658381
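The three lines added for `pytorch_model.bin` are a Git LFS pointer rather than the weights themselves: a spec version, a content hash, and the size in bytes of the real file. The following is a minimal sketch of reading those fields, assuming the pointer text is already available as a string. `parse_lfs_pointer` is a hypothetical helper written for this example, not part of Git LFS or of this repository.

```python
POINTER_TEXT = """\
version https://git-lfs.github.com/spec/v1
oid sha256:6516d02ef00bc903aad7d05dc35607cff7e4c7335d4f1bf424cdcb6695cd3e86
size 10402658381
"""


def parse_lfs_pointer(text: str) -> dict:
    """Split each 'key value' line, then separate the hash algorithm from the digest."""
    fields = dict(
        line.split(" ", 1) for line in text.splitlines() if line.strip()
    )
    hash_algo, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "hash_algo": hash_algo,
        "digest": digest,
        "size_bytes": int(fields["size"]),
    }


info = parse_lfs_pointer(POINTER_TEXT)
print(info["hash_algo"], info["digest"][:12] + "...")
print(f"{info['size_bytes'] / 2**30:.2f} GiB")
```

The size field is useful for estimating download time and disk space before fetching the checkpoint.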
replit_lm.py ADDED
@@ -0,0 +1,453 @@