README.md CHANGED
@@ -1,3 +1,104 @@
1
  ---
2
  license: cc-by-sa-4.0
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: cc-by-sa-4.0
3
+ datasets:
4
+ - bigcode/the-stack-dedup
5
  ---
6
+
7
+
8
+ # replit-code-v1-3b
9
+
10
+ `replit-code-v1-3b` is a 2.7B model. It is trained on the Stack Dedup v1.2 dataset.
11
+
12
+
13
+
14
+ ## Model
15
+
16
+
17
+ ```python
18
+ from transformers import AutoModelForCausalLM
19
+
20
+ # load model
21
+ model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
22
+ ```
23
+
24
+ To use the optimized Triton implementation of FlashAttention on GPUs with BF16 precision, move the model to `bfloat16` and use it as follows:
25
+
26
+ ```python
27
+ from transformers import AutoModelForCausalLM
28
+
29
+ # load model
30
+ model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True, attn_impl='triton')
31
+ model.to(device='cuda:0', dtype=torch.bfloat16)
32
+
33
+ # forward pass
34
+ x = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
35
+ x = x.to(device='cuda:0', dtype=torch.bfloat16)
36
+ y = model(x)
37
+
38
+ ```
39
+
40
+ Note that `trust_remote_code=True` is passed to the `from_pretrained` method because ReplitLM is not a class in the
41
+ [Transformers](https://huggingface.co/docs/transformers/index) library.
42
+
43
+ ## Tokenizer
44
+
45
+ We have trained a custom SentencePiece Unigram tokenizer optimized with a vocabulary specifically for code of 32768 tokens.
46
+
47
+ Note that using this requires the `sentencepiece` library to be installed.
48
+
49
+ The tokenizer can be used as follows:
50
+
51
+ ```python
52
+ from transformers import AutoTokenizer
53
+
54
+ # load tokenizer
55
+ tokenizer = AutoTokenizer.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
56
+
57
+ # single input encoding + generation
58
+ x = tokenizer.encode('def hello():\n print("hello world")\n', return_tensors='pt')
59
+ y = model.generate(x)
60
+
61
+ # decoding, clean_up_tokenization_spaces=False to ensure syntactical correctness
62
+ generated_code = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
63
+ print(generated_code)
64
+ ```
65
+
66
+ Note that:
67
+ - `trust_remote_code=True` is passed to the `from_pretrained` method because ReplitLM is not a class in the [Transformers](https://huggingface.co/docs/transformers/index) library.
68
+ - `clean_up_tokenization_spaces=False` is meant to avoid removing spaces in the output, because that would affect the syntactical correctness of the generated code.
69
+
70
+
71
+ ## Generation
72
+
73
+ You can generate code using the `transformers` library as follows:
74
+
75
+ ```python
76
+ tokenizer = transformers.AutoTokenizer.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
77
+ model = transformers.AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
78
+
79
+ x = tokenizer.encode('def fibonacci(n): ', return_tensors='pt')
80
+ y = model.generate(x, max_length=100, do_sample=True, top_p=0.95, top_k=4, temperature=0.2, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
81
+
82
+ # decoding, clean_up_tokenization_spaces=False to ensure syntactical correctness
83
+ generated_code = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
84
+ print(generated_code)
85
+ ```
86
+
87
+ Experiment with different decoding methods and parameters to get the best results for your use case.
88
+
89
+ ## Post Processing
90
+
91
+ Note that as with all code generation models, post-processing of the generated code is important. In particular, the following post-processing steps are recommended:
92
+ - stop generation when the EOS token is encountered
93
+ - remove trailing whitespaces
94
+ - set `max_tokens` to a reasonable value based on your completion use case
95
+ - truncate generation to stop words such as `return`, `def`, "```", "`\n\n\n`" to avoid generating incomplete code when `max_tokens` is larger than the length of the expected generated code.
96
+
97
+ ## Inference
98
+ Coming soon.
99
+
100
+ ## Evaluation
101
+ Coming soon.
102
+
103
+ ## Model Hash
104
+ 5bc28ce32c6f9aec935ead7b60ea1c46
attention.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Attention layers."""
5
+
6
+ import math
7
+ import warnings
8
+ from typing import Optional
9
+
10
+ import torch
11
+ from einops import rearrange
12
+ from torch import nn
13
+
14
+ from .low_precision_layernorm import LPLayerNorm
15
+
16
+
17
+ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
18
+ original_is_causal: bool):
19
+ if original_is_causal and num_query_tokens != num_key_tokens:
20
+ if num_query_tokens != 1:
21
+ raise NotImplementedError(
22
+ 'ReplitLM does not support query and key with different number of tokens, unless number of query tokens is 1.'
23
+ )
24
+ else:
25
+ return False
26
+ return original_is_causal
27
+
28
+
29
+ def scaled_multihead_dot_product_attention(
30
+ query,
31
+ key,
32
+ value,
33
+ n_heads,
34
+ softmax_scale=None,
35
+ attn_bias=None,
36
+ key_padding_mask=None,
37
+ is_causal=False,
38
+ dropout_p=0.0,
39
+ training=False,
40
+ needs_weights=False,
41
+ ):
42
+
43
+ q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
44
+ k = rearrange(key, 'b s (h d) -> b h d s', h=n_heads) # includes key.t()
45
+ v = rearrange(value, 'b s (h d) -> b h s d', h=n_heads)
46
+
47
+ min_val = torch.finfo(q.dtype).min
48
+
49
+ b, _, s_q, d = q.shape
50
+ s_k = k.size(-1)
51
+
52
+ if softmax_scale is None:
53
+ softmax_scale = 1 / math.sqrt(d)
54
+
55
+ attn_weight = q.matmul(k) * softmax_scale
56
+
57
+ if attn_bias is not None:
58
+ if (attn_bias.size(-1) != 1 and
59
+ attn_bias.size(-1) != s_k) or (attn_bias.size(-2) != 1 and
60
+ attn_bias.size(-2) != s_q):
61
+ raise RuntimeError(
62
+ f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.'
63
+ )
64
+ attn_weight = attn_weight + attn_bias
65
+
66
+ if key_padding_mask is not None:
67
+ if attn_bias is not None:
68
+ warnings.warn(
69
+ 'Propogating key_padding_mask to the attention module ' +
70
+ 'and applying it within the attention module can cause ' +
71
+ 'unneccessary computation/memory usage. Consider integrating ' +
72
+ 'into attn_bias once and passing that to each attention ' +
73
+ 'module instead.'
74
+ )
75
+ attn_weight = attn_weight.masked_fill(
76
+ ~key_padding_mask.view((b, 1, 1, s_k)), min_val)
77
+
78
+ if is_causal:
79
+ s = max(s_q, s_k)
80
+ causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
81
+ causal_mask = causal_mask.tril()
82
+ causal_mask = causal_mask.to(torch.bool)
83
+ causal_mask = ~causal_mask
84
+ causal_mask = causal_mask[-s_q:, -s_k:]
85
+ attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k),
86
+ min_val)
87
+
88
+ attn_weight = torch.softmax(attn_weight, dim=-1)
89
+
90
+ if dropout_p:
91
+ attn_weight = torch.nn.functional.dropout(attn_weight,
92
+ p=dropout_p,
93
+ training=training,
94
+ inplace=True)
95
+
96
+ out = attn_weight.matmul(v)
97
+ out = rearrange(out, 'b h s d -> b s (h d)')
98
+
99
+ if needs_weights:
100
+ return out, attn_weight
101
+ return out, None
102
+
103
+
104
+ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
105
+ for tensor in tensors:
106
+ if tensor.dtype not in valid_dtypes:
107
+ raise TypeError(f'{tensor.dtype=} must be in {valid_dtypes=}.')
108
+ if not tensor.is_cuda:
109
+ raise TypeError(
110
+ f'Inputs must be cuda tensors ({tensor.is_cuda=}).')
111
+
112
+
113
+ def flash_attn_fn(
114
+ query,
115
+ key,
116
+ value,
117
+ n_heads,
118
+ softmax_scale=None,
119
+ attn_bias=None,
120
+ key_padding_mask=None,
121
+ is_causal=False,
122
+ dropout_p=0.0,
123
+ training=False,
124
+ needs_weights=False,
125
+ ):
126
+ try:
127
+ from flash_attn import bert_padding, flash_attn_interface
128
+ except:
129
+ raise RuntimeError('Please install flash_attn==0.2.8')
130
+
131
+ check_valid_inputs(query, key, value)
132
+
133
+ if attn_bias is not None:
134
+ raise NotImplementedError(f'attn_bias not implemented for flash attn.')
135
+
136
+ batch_size, seqlen = query.shape[:2]
137
+
138
+ if key_padding_mask is None:
139
+ key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
140
+ query_padding_mask = key_padding_mask[:, -query.size(1):]
141
+
142
+ query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
143
+ query, query_padding_mask)
144
+ query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
145
+
146
+ key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
147
+ key, key_padding_mask)
148
+ key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
149
+
150
+ value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
151
+ value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
152
+
153
+ dropout_p = dropout_p if training else 0.0
154
+
155
+ reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
156
+
157
+ output_unpad = flash_attn_interface.flash_attn_unpadded_func(
158
+ query_unpad,
159
+ key_unpad,
160
+ value_unpad,
161
+ cu_seqlens_q,
162
+ cu_seqlens_k,
163
+ max_seqlen_q,
164
+ max_seqlen_k,
165
+ dropout_p,
166
+ softmax_scale=softmax_scale,
167
+ causal=reset_is_causal,
168
+ return_attn_probs=needs_weights)
169
+
170
+ output = bert_padding.pad_input(
171
+ rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
172
+ seqlen)
173
+ return output, None
174
+
175
+
176
+ def triton_flash_attn_fn(
177
+ query,
178
+ key,
179
+ value,
180
+ n_heads,
181
+ softmax_scale=None,
182
+ attn_bias=None,
183
+ key_padding_mask=None,
184
+ is_causal=False,
185
+ dropout_p=0.0,
186
+ training=False,
187
+ needs_weights=False,
188
+ ):
189
+ try:
190
+ from flash_attn import flash_attn_triton # type: ignore
191
+ except:
192
+ raise RuntimeError(
193
+ 'Please install flash_attn==0.2.8 and triton==2.0.0.dev20221202.')
194
+
195
+ check_valid_inputs(query, key, value)
196
+
197
+ if dropout_p:
198
+ raise NotImplementedError(
199
+ f'Dropout not implemented for attn_impl: triton.')
200
+
201
+ if needs_weights:
202
+ raise NotImplementedError(
203
+ f'attn_impl: triton cannot return attn weights.')
204
+
205
+ if key_padding_mask is not None:
206
+ warnings.warn(
207
+ 'Propagating key_padding_mask to the attention module ' +
208
+ 'and applying it within the attention module can cause ' +
209
+ 'unnecessary computation/memory usage. Consider integrating ' +
210
+ 'into attn_bias once and passing that to each attention ' +
211
+ 'module instead.'
212
+ )
213
+ b_size, s_k = key_padding_mask.shape[:2]
214
+
215
+ if attn_bias is None:
216
+ attn_bias = query.new_zeros(b_size, 1, 1, s_k)
217
+
218
+ attn_bias = attn_bias.masked_fill(
219
+ ~key_padding_mask.view((b_size, 1, 1, s_k)),
220
+ torch.finfo(query.dtype).min)
221
+
222
+ query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
223
+ key = rearrange(key, 'b s (h d) -> b s h d', h=n_heads)
224
+ value = rearrange(value, 'b s (h d) -> b s h d', h=n_heads)
225
+
226
+ reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
227
+ attn_output = flash_attn_triton.flash_attn_func(query, key, value,
228
+ attn_bias, reset_is_causal,
229
+ softmax_scale)
230
+
231
+ output = attn_output.view(*attn_output.shape[:2], -1)
232
+
233
+ return output, None
234
+
235
+
236
+ class MultiheadAttention(nn.Module):
237
+ """Multi-head self attention.
238
+
239
+ Using torch or triton attention implemetation enables user to also use
240
+ additive bias.
241
+ """
242
+
243
+ def __init__(
244
+ self,
245
+ d_model: int,
246
+ n_heads: int,
247
+ attn_impl: str = 'triton',
248
+ attn_clip_qkv: Optional[float] = None,
249
+ attn_qk_ln: bool = False,
250
+ softmax_scale: Optional[float] = None,
251
+ attn_pdrop: float = 0.0,
252
+ low_precision_layernorm: bool = False,
253
+ device: Optional[str] = None,
254
+ ):
255
+ super().__init__()
256
+
257
+ self.attn_impl = attn_impl
258
+ self.clip_qkv = attn_clip_qkv
259
+ self.attn_qk_ln = attn_qk_ln
260
+
261
+ self.d_model = d_model
262
+ self.n_heads = n_heads
263
+ self.softmax_scale = softmax_scale
264
+ if self.softmax_scale is None:
265
+ self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
266
+ self.attn_dropout_p = attn_pdrop
267
+
268
+ self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
269
+ # for param init fn; enables shape based init of fused layers
270
+ fuse_splits = (d_model, 2 * d_model)
271
+ self.Wqkv._fused = (0, fuse_splits) # type: ignore
272
+
273
+ if self.attn_qk_ln:
274
+ layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
275
+ self.q_ln = layernorm_class(self.d_model, device=device)
276
+ self.k_ln = layernorm_class(self.d_model, device=device)
277
+
278
+ if self.attn_impl == 'flash':
279
+ self.attn_fn = flash_attn_fn
280
+ elif self.attn_impl == 'triton':
281
+ self.attn_fn = triton_flash_attn_fn
282
+ warnings.warn(
283
+ 'While `attn_impl: triton` can be faster than `attn_impl: flash` ' +
284
+ 'it uses more memory. When training larger models this can trigger ' +
285
+ 'alloc retries which hurts performance. If encountered, we recommend ' +
286
+ 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
287
+ elif self.attn_impl == 'torch':
288
+ self.attn_fn = scaled_multihead_dot_product_attention
289
+ if torch.cuda.is_available():
290
+ warnings.warn(
291
+ 'Using `attn_impl: torch`. If your model does not use `alibi` or ' +
292
+ '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' +
293
+ 'we recommend using `attn_impl: triton`.'
294
+ )
295
+ else:
296
+ raise ValueError(f'{attn_impl=} is an invalid setting.')
297
+
298
+ self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
299
+ self.out_proj._is_residual = True # type: ignore
300
+
301
+ def forward(self,
302
+ x,
303
+ past_key_value=None,
304
+ attn_bias=None,
305
+ attention_mask=None,
306
+ is_causal=True,
307
+ needs_weights=False):
308
+ qkv = self.Wqkv(x)
309
+
310
+ if self.clip_qkv:
311
+ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
312
+
313
+ query, key, value = qkv.chunk(3, dim=2)
314
+
315
+ key_padding_mask = attention_mask
316
+
317
+ if self.attn_qk_ln:
318
+ # Applying layernorm to qk
319
+ dtype = query.dtype
320
+ query = self.q_ln(query).to(dtype)
321
+ key = self.k_ln(key).to(dtype)
322
+
323
+ if past_key_value is not None:
324
+ if len(past_key_value) != 0:
325
+ key = torch.cat([past_key_value[0], key], dim=1)
326
+ value = torch.cat([past_key_value[1], value], dim=1)
327
+
328
+ past_key_value = (key, value)
329
+
330
+ if attn_bias is not None:
331
+ attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
332
+
333
+ context, attn_weights = self.attn_fn(
334
+ query,
335
+ key,
336
+ value,
337
+ self.n_heads,
338
+ softmax_scale=self.softmax_scale,
339
+ attn_bias=attn_bias,
340
+ key_padding_mask=key_padding_mask,
341
+ is_causal=is_causal,
342
+ dropout_p=self.attn_dropout_p,
343
+ training=self.training,
344
+ needs_weights=needs_weights,
345
+ )
346
+
347
+ return self.out_proj(context), attn_weights, past_key_value
348
+
349
+
350
+ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal,
351
+ use_sequence_id):
352
+ if attn_impl == 'flash':
353
+ return None
354
+ elif attn_impl in ['torch', 'triton']:
355
+ if alibi:
356
+ if (prefix_lm or not causal) or use_sequence_id:
357
+ return (1, n_heads, seq_len, seq_len)
358
+ return (1, n_heads, 1, seq_len)
359
+ elif prefix_lm or use_sequence_id:
360
+ return (1, 1, seq_len, seq_len)
361
+ return None
362
+ else:
363
+ raise ValueError(f'{attn_impl=} is an invalid setting.')
364
+
365
+
366
+ def attn_bias(attn_impl,
367
+ attn_bias,
368
+ n_heads,
369
+ seq_len,
370
+ causal=False,
371
+ alibi=False,
372
+ alibi_bias_max=8):
373
+ if attn_impl == 'flash':
374
+ return None
375
+ elif attn_impl in ['torch', 'triton']:
376
+ if alibi:
377
+ # in place add alibi to attn bias
378
+ device, dtype = attn_bias.device, attn_bias.dtype
379
+ attn_bias = attn_bias.add(
380
+ alibi_bias(n_heads,
381
+ seq_len,
382
+ full=not causal,
383
+ alibi_bias_max=alibi_bias_max,
384
+ device=device,
385
+ dtype=dtype))
386
+ return attn_bias
387
+ else:
388
+ raise ValueError(f'{attn_impl=} is an invalid setting.')
389
+
390
+
391
+ def alibi_bias(n_heads,
392
+ seq_len,
393
+ full=False,
394
+ alibi_bias_max=8,
395
+ device=None,
396
+ dtype=None):
397
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=dtype,
398
+ device=device).view(1, 1, 1, seq_len)
399
+ if full:
400
+ # generate 1 x Heads x SeqLen x SeqLen alibi bias mask
401
+ # otherwise the mask is 1 x Heads x 1 x SeqLen (which is broadcast to the appropriate size)
402
+ alibi_bias = alibi_bias - torch.arange(
403
+ 1 - seq_len, 1, dtype=dtype, device=device).view(1, 1, seq_len, 1)
404
+ alibi_bias = alibi_bias.abs().mul(-1)
405
+
406
+ m = torch.arange(1, n_heads + 1, dtype=dtype, device=device)
407
+ m = m.mul(alibi_bias_max / n_heads)
408
+ alibi_bias = alibi_bias * (1. / (2**m.view(1, n_heads, 1, 1)))
409
+ return alibi_bias
config.json ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "replit/replit-code-v1-3b",
3
+ "alibi": true,
4
+ "alibi_bias_max": 8,
5
+ "architectures": [
6
+ "ReplitLM"
7
+ ],
8
+ "attn_clip_qkv": null,
9
+ "attn_impl": "torch",
10
+ "attn_pdrop": 0,
11
+ "attn_qk_ln": false,
12
+ "attn_uses_sequence_id": false,
13
+ "auto_map": {
14
+ "AutoConfig": "configuration_replit_lm.ReplitLMConfig",
15
+ "AutoModelForCausalLM": "replit_lm.ReplitLM"
16
+ },
17
+ "d_model": 2560,
18
+ "emb_init_std": null,
19
+ "emb_init_uniform_lim": null,
20
+ "emb_pdrop": 0,
21
+ "embedding_fraction": 1.0,
22
+ "fan_mode": "fan_in",
23
+ "init_device": "cpu",
24
+ "init_div_is_residual": true,
25
+ "init_gain": 0,
26
+ "init_nonlinearity": "relu",
27
+ "init_std": 0.02,
28
+ "logit_scale": null,
29
+ "low_precision_layernorm": true,
30
+ "max_seq_len": 2048,
31
+ "mlp_ratio": 4,
32
+ "model_type": "replit_lm",
33
+ "n_heads": 32,
34
+ "n_layers": 32,
35
+ "no_bias": true,
36
+ "param_init_fn": "kaiming_normal_",
37
+ "prefix_lm": false,
38
+ "resid_pdrop": 0,
39
+ "softmax_scale": null,
40
+ "tokenizer_name": "replit/replit-code-v1-3b",
41
+ "torch_dtype": "float32",
42
+ "transformers_version": "4.26.1",
43
+ "use_cache": false,
44
+ "verbose": 0,
45
+ "vocab_size": 32768
46
+ }
configuration_replit_lm.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Forked for ReplitLM"""
5
+
6
+ """A HuggingFace-style model configuration."""
7
+
8
+
9
+ from typing import Optional, Tuple, Union
10
+ from transformers import PretrainedConfig
11
+ class ReplitLMConfig(PretrainedConfig):
12
+ model_type = 'replit_lm'
13
+
14
+ def __init__(
15
+ self,
16
+ d_model: int = 2048,
17
+ n_heads: int = 16,
18
+ n_layers: int = 24,
19
+ mlp_ratio: int = 4,
20
+ max_seq_len: int = 2048,
21
+ vocab_size: int = 50368,
22
+ attn_pdrop: float = 0.0,
23
+ resid_pdrop: float = 0.0,
24
+ emb_pdrop: float = 0.0,
25
+ attn_impl: str = 'triton',
26
+ attn_qk_ln: bool = False,
27
+ attn_clip_qkv: Optional[float] = None,
28
+ softmax_scale: Optional[float] = None,
29
+ prefix_lm: Optional[bool] = False,
30
+ attn_uses_sequence_id: Optional[bool] = False,
31
+ alibi: bool = False,
32
+ alibi_bias_max: int = 8,
33
+ init_device: str = 'cpu',
34
+ logit_scale: Optional[Union[float, str]] = None,
35
+ no_bias: bool = False,
36
+ verbose: int = 0,
37
+ param_init_fn: str = 'kaiming_normal_',
38
+ init_div_is_residual: Union[int, float, str, bool] = True,
39
+ init_std: float = 0.02,
40
+ emb_init_std: Optional[float] = None,
41
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float],
42
+ float]] = None,
43
+ init_gain: float = 0,
44
+ fan_mode: str = 'fan_in',
45
+ init_nonlinearity: str = 'relu',
46
+ embedding_fraction: float = 1.0,
47
+ low_precision_layernorm: bool = True,
48
+ use_cache: bool = False,
49
+ **kwargs,
50
+ ):
51
+ """The ReplitLM configuration class.
52
+
53
+ Args:
54
+ d_model (int): The size of the embedding dimension of the model.
55
+ n_heads (int): The number of attention heads.
56
+ n_layers (int): The number of layers in the model.
57
+ mlp_ratio (int): The ratio of the up/down scale in the MLP.
58
+ max_seq_len (int): The maximum sequence length of the model.
59
+ vocab_size (int): The size of the vocabulary.
60
+ attn_pdrop (float): The dropout probability for the attention layers.
61
+ resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
62
+ emb_pdrop (float): The dropout probability for the embedding layer.
63
+ attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
64
+ attn_qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
65
+ attn_clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
66
+ this value.
67
+ softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
68
+ use the default scale of ``1/sqrt(d_keys)``.
69
+ prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
70
+ extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
71
+ can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
72
+ attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
73
+ When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
74
+ which sub-sequence each token belongs to.
75
+ Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
76
+ alibi (bool): Whether to use the alibi bias instead of position embeddings.
77
+ alibi_bias_max (int): The maximum value of the alibi bias.
78
+ init_device (str): The device to use for parameter initialization.
79
+ logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
80
+ no_bias (bool): Whether to use bias in all layers.
81
+ verbose (int): The verbosity level. 0 is silent.
82
+ param_init_fn (str): The parameter initialization scheme to use. One of 'default_', 'baseline_', 'kaiming_uniform_',
83
+ 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or 'xavier_normal_'.
84
+ init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
85
+ init_std (float): The standard deviation of the normal distribution used to initialize the model,
86
+ if using the baseline_ parameter initialization scheme.
87
+ emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
88
+ emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
89
+ used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
90
+ init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
91
+ fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
92
+ init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
93
+ embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
94
+ low_precision_layernorm (bool): Whether to use low precision layer normalization.
95
+ use_cache (bool): Whether or not the model should return the last key/values attentions
96
+ """
97
+ self.d_model = d_model
98
+ self.n_heads = n_heads
99
+ self.n_layers = n_layers
100
+ self.mlp_ratio = mlp_ratio
101
+ self.max_seq_len = max_seq_len
102
+ self.vocab_size = vocab_size
103
+ self.attn_pdrop = attn_pdrop
104
+ self.resid_pdrop = resid_pdrop
105
+ self.emb_pdrop = emb_pdrop
106
+ self.attn_impl = attn_impl
107
+ self.attn_qk_ln = attn_qk_ln
108
+ self.attn_clip_qkv = attn_clip_qkv
109
+ self.softmax_scale = softmax_scale
110
+ self.prefix_lm = prefix_lm
111
+ self.attn_uses_sequence_id = attn_uses_sequence_id
112
+ self.alibi = alibi
113
+ self.alibi_bias_max = alibi_bias_max
114
+ self.init_device = init_device
115
+ self.logit_scale = logit_scale
116
+ self.no_bias = no_bias
117
+ self.verbose = verbose
118
+ self.param_init_fn = param_init_fn
119
+ self.init_div_is_residual = init_div_is_residual
120
+ self.init_std = init_std
121
+ self.emb_init_std = emb_init_std
122
+ self.emb_init_uniform_lim = emb_init_uniform_lim
123
+ self.init_std = init_std
124
+ self.init_gain = init_gain
125
+ self.fan_mode = fan_mode
126
+ self.init_nonlinearity = init_nonlinearity
127
+ self.embedding_fraction = embedding_fraction
128
+ self.low_precision_layernorm = low_precision_layernorm
129
+ self.use_cache = use_cache
130
+ if 'name' in kwargs:
131
+ del kwargs['name']
132
+ if 'loss_fn' in kwargs:
133
+ del kwargs['loss_fn']
134
+ super().__init__(**kwargs)
135
+
136
+ self._validate_config()
137
+
138
+ def _validate_config(self):
139
+ if self.d_model % self.n_heads != 0:
140
+ raise ValueError('d_model must be divisible by n_heads')
141
+ if any(prob < 0 or prob > 1
142
+ for prob in [self.attn_pdrop, self.resid_pdrop, self.emb_pdrop]):
143
+ raise ValueError(
144
+ 'attn_pdrop, resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1'
145
+ )
146
+ if self.attn_impl not in ['torch', 'flash', 'triton']:
147
+ raise ValueError(f'Unknown attn_impl={self.attn_impl}')
148
+ if self.prefix_lm and self.attn_impl not in ['torch', 'triton']:
149
+ raise NotImplementedError(
150
+ 'prefix_lm only implemented with torch and triton attention.')
151
+ if self.alibi and self.attn_impl not in ['torch', 'triton']:
152
+ raise NotImplementedError(
153
+ 'alibi only implemented with torch and triton attention.')
154
+ if self.attn_uses_sequence_id and self.attn_impl not in [
155
+ 'torch', 'triton'
156
+ ]:
157
+ raise NotImplementedError(
158
+ 'attn_uses_sequence_id only implemented with torch and triton attention.'
159
+ )
160
+ if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
161
+ raise ValueError(
162
+ 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'
163
+ )
164
+ if isinstance(self.logit_scale,
165
+ str) and self.logit_scale != 'inv_sqrt_d_model':
166
+ raise ValueError(
167
+ f"{self.logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
168
+ )
generation_config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.26.1",
4
+ "use_cache": false
5
+ }
gpt_blocks.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """GPT Blocks used for the GPT Model."""
5
+
6
+ from typing import Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from .attention import MultiheadAttention
12
+ from .low_precision_layernorm import LPLayerNorm
13
+
14
+
15
+ class GPTMLP(nn.Module):
16
+
17
+ def __init__(self,
18
+ d_model: int,
19
+ mlp_ratio: int,
20
+ device: Optional[str] = None):
21
+ super().__init__()
22
+ self.mlp_up = nn.Linear(d_model, mlp_ratio * d_model, device=device)
23
+ self.mlp_act = nn.GELU(approximate='none')
24
+ self.mlp_down = nn.Linear(mlp_ratio * d_model, d_model, device=device)
25
+ self.mlp_down._is_residual = True # type: ignore
26
+
27
+ def forward(self, x):
28
+ return self.mlp_down(self.mlp_act(self.mlp_up(x)))
29
+
30
+
31
+ class GPTBlock(nn.Module):
32
+
33
+ def __init__(self,
34
+ attn_impl: str,
35
+ d_model: int,
36
+ n_heads: int,
37
+ mlp_ratio: int,
38
+ attn_clip_qkv: Optional[float] = None,
39
+ attn_qk_ln: bool = False,
40
+ softmax_scale: Optional[float] = None,
41
+ attn_pdrop: float = 0.0,
42
+ alibi: bool = False,
43
+ resid_pdrop: float = 0.0,
44
+ low_precision_layernorm: bool = False,
45
+ device: Optional[str] = None,
46
+ **kwargs):
47
+ del kwargs # unused, just to capture any extra args from the config
48
+ super().__init__()
49
+
50
+ layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
51
+
52
+ self.ln_1 = layernorm_class(d_model, device=device)
53
+ self.attn = MultiheadAttention(
54
+ attn_impl=attn_impl,
55
+ attn_clip_qkv=attn_clip_qkv,
56
+ attn_qk_ln=attn_qk_ln,
57
+ softmax_scale=softmax_scale,
58
+ attn_pdrop=attn_pdrop,
59
+ d_model=d_model,
60
+ n_heads=n_heads,
61
+ device=device,
62
+ )
63
+ self.ln_2 = layernorm_class(d_model, device=device)
64
+ self.mlp = GPTMLP(
65
+ d_model=d_model,
66
+ mlp_ratio=mlp_ratio,
67
+ device=device,
68
+ )
69
+ self.resid_attn_dropout = nn.Dropout(resid_pdrop)
70
+ self.resid_mlp_dropout = nn.Dropout(resid_pdrop)
71
+
72
+ def forward(
73
+ self,
74
+ x: torch.Tensor,
75
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
76
+ attn_bias: Optional[torch.Tensor] = None,
77
+ attention_mask: Optional[torch.ByteTensor] = None,
78
+ is_causal: bool = True,
79
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
80
+ a = self.ln_1(x)
81
+ b, _, past_key_value = self.attn(a,
82
+ past_key_value=past_key_value,
83
+ attn_bias=attn_bias,
84
+ attention_mask=attention_mask,
85
+ is_causal=is_causal)
86
+ x = x + self.resid_attn_dropout(b)
87
+ m = self.ln_2(x)
88
+ n = self.mlp(m)
89
+ x = x + self.resid_mlp_dropout(n)
90
+ return x, past_key_value
low_precision_layernorm.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ class LPLayerNorm(torch.nn.LayerNorm):
6
+ def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
7
+ super().__init__(
8
+ normalized_shape=normalized_shape,
9
+ eps=eps,
10
+ elementwise_affine=elementwise_affine,
11
+ device=device,
12
+ dtype=dtype,
13
+ )
14
+
15
+ def forward(self, x):
16
+ module_device = x.device
17
+ downcast_x = _cast_if_autocast_enabled(x)
18
+ downcast_weight = _cast_if_autocast_enabled(
19
+ self.weight) if self.weight is not None else self.weight
20
+ downcast_bias = _cast_if_autocast_enabled(
21
+ self.bias) if self.bias is not None else self.bias
22
+ with torch.autocast(enabled=False, device_type=module_device.type):
23
+ return F.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
24
+
25
+
26
+ def _cast_if_autocast_enabled(tensor):
27
+ if torch.is_autocast_enabled():
28
+ if tensor.device.type == 'cuda':
29
+ dtype = torch.get_autocast_gpu_dtype()
30
+ elif tensor.device.type == 'cpu':
31
+ dtype = torch.get_autocast_cpu_dtype()
32
+ else:
33
+ raise NotImplementedError()
34
+ return tensor.to(dtype=dtype)
35
+ return tensor
param_init_fns.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ import math
4
+ import warnings
5
+ from collections.abc import Sequence
6
+ from functools import partial
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+
13
+ def torch_default_param_init_fn_(
14
+ module: nn.Module,
15
+ verbose: int = 0,
16
+ **kwargs,
17
+ ):
18
+ del kwargs # unused, just to capture any extra args from the config
19
+ if verbose > 1:
20
+ warnings.warn(
21
+ f"Initializing network using module's reset_parameters attribute")
22
+
23
+ if hasattr(module, 'reset_parameters'):
24
+ module.reset_parameters() # type: ignore
25
+
26
+
27
+ def fused_init_helper_(module: nn.Module, init_fn_):
28
+ # parameter initialization is often based on the parameters shape.
29
+ # If a layer is fused, initialization should be based on the shapes
30
+ # of the original tensor instead of the shape of the fused tensor.
31
+ # Layers which are fused should have the _fused attibute defined.
32
+ # The first element of _fused is the dimension along which the tensor is fused.
33
+ # This is followed by an iterable of split indices."
34
+
35
+ _fused = getattr(module, '_fused', None)
36
+
37
+ if _fused is None:
38
+ raise RuntimeError(f'Internal logic error')
39
+
40
+ dim, splits = _fused
41
+ splits = (0, *splits, module.weight.size(dim)) # type: ignore
42
+ for s, e in zip(splits[:-1], splits[1:]):
43
+ slice_indices = [slice(None)] * module.weight.ndim # type: ignore
44
+ slice_indices[dim] = slice(s, e)
45
+ init_fn_(module.weight[slice_indices]) # type: ignore
46
+
47
+
48
+ def generic_param_init_fn_(
49
+ module: nn.Module,
50
+ init_fn_,
51
+ n_layers: int,
52
+ d_model: Optional[int] = None,
53
+ init_div_is_residual: Union[int, float, str, bool] = True,
54
+ emb_init_std: Optional[float] = None,
55
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
56
+ verbose: int = 0,
57
+ **kwargs,
58
+ ):
59
+ del kwargs # unused, just to capture any extra args from the config
60
+ if verbose > 1:
61
+ warnings.warn(
62
+ f'If model has bias parameters they are initialized to 0.')
63
+
64
+ # enable user to divide _is_residual weights by
65
+ # a value which defaults to math.sqrt(2 * cfg.n_layers)
66
+ init_div_is_residual = init_div_is_residual
67
+
68
+ if init_div_is_residual is False:
69
+ # not used, for pyright
70
+ div_is_residual = 1.0
71
+ elif init_div_is_residual is True:
72
+ div_is_residual = math.sqrt(2 * n_layers)
73
+ elif isinstance(init_div_is_residual, float) or isinstance(
74
+ init_div_is_residual, int):
75
+ div_is_residual = init_div_is_residual
76
+ elif isinstance(init_div_is_residual,
77
+ str) and init_div_is_residual.isnumeric():
78
+ # do not trust YAML parsing to always convert numbers to numbers
79
+ div_is_residual = float(init_div_is_residual)
80
+ else:
81
+ # not used, for pyright
82
+ div_is_residual = 1.0
83
+ raise ValueError(
84
+ f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}'
85
+ )
86
+
87
+ if init_div_is_residual is not False:
88
+ if verbose > 1:
89
+ warnings.warn(
90
+ f'Initializing _is_residual layers then dividing them by {div_is_residual}.' +
91
+ f'set `init_div_is_residual: false` in model config to disable this.'
92
+ )
93
+
94
+ if isinstance(module, nn.Linear):
95
+ # Linear
96
+ if hasattr(module, '_fused'):
97
+ fused_init_helper_(module, init_fn_)
98
+ else:
99
+ init_fn_(module.weight)
100
+ if module.bias is not None:
101
+ torch.nn.init.zeros_(module.bias)
102
+
103
+ if init_div_is_residual is not False and getattr(
104
+ module, '_is_residual', False):
105
+ with torch.no_grad():
106
+ module.weight.div_(div_is_residual)
107
+
108
+ elif isinstance(module, nn.Embedding):
109
+ # Embedding
110
+ if emb_init_std is not None:
111
+ std = emb_init_std
112
+ if std == 0:
113
+ warnings.warn(f'Embedding layer initialized to 0.')
114
+ emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
115
+ if verbose > 1:
116
+ warnings.warn(
117
+ f'Embedding layer initialized using normal distribution with mean=0 and {std=}.'
118
+ )
119
+ elif emb_init_uniform_lim is not None:
120
+ lim = emb_init_uniform_lim
121
+ if isinstance(lim, Sequence):
122
+ if len(lim) > 2:
123
+ raise ValueError(
124
+ f'Uniform init requires a min and a max limit. User input: {lim}.'
125
+ )
126
+ if lim[0] == lim[1]:
127
+ warnings.warn(f'Embedding layer initialized to {lim[0]}.')
128
+ else:
129
+ if lim == 0:
130
+ warnings.warn(f'Embedding layer initialized to 0.')
131
+ lim = [-lim, lim]
132
+ a, b = lim
133
+ emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
134
+ if verbose > 1:
135
+ warnings.warn(
136
+ f'Embedding layer initialized using uniform distribution in range {lim}.'
137
+ )
138
+ else:
139
+ emb_init_fn_ = init_fn_
140
+
141
+ emb_init_fn_(module.weight)
142
+
143
+ elif isinstance(module, nn.LayerNorm):
144
+ # LayerNorm
145
+ if verbose > 1:
146
+ warnings.warn(
147
+ f'LayerNorm gamma weights are set to 1. If the layer has a bias it is initialized to 0.'
148
+ )
149
+ torch.nn.init.ones_(module.weight)
150
+ if module.bias is not None:
151
+ torch.nn.init.zeros_(module.bias)
152
+
153
+ elif isinstance(module, nn.MultiheadAttention):
154
+ # torch's MultiheadAttention
155
+ if module._qkv_same_embed_dim:
156
+ assert module.in_proj_weight is not None
157
+ assert module.q_proj_weight is None and module.k_proj_weight is None and module.v_proj_weight is None
158
+ assert d_model is not None
159
+ # in_proj_weight is actually 3 layers and should be split up for width based init
160
+ _d = d_model
161
+ splits = (0, _d, 2 * _d, 3 * _d)
162
+ for s, e in zip(splits[:-1], splits[1:]):
163
+ init_fn_(module.in_proj_weight[s:e])
164
+ else:
165
+ assert module.q_proj_weight is not None and module.k_proj_weight is not None and module.v_proj_weight is not None
166
+ assert module.in_proj_weight is None
167
+ init_fn_(module.q_proj_weight)
168
+ init_fn_(module.k_proj_weight)
169
+ init_fn_(module.v_proj_weight)
170
+
171
+ # bias
172
+ if module.in_proj_bias is not None:
173
+ torch.nn.init.zeros_(module.in_proj_bias)
174
+ if module.bias_k is not None:
175
+ torch.nn.init.zeros_(module.bias_k)
176
+ if module.bias_v is not None:
177
+ torch.nn.init.zeros_(module.bias_v)
178
+
179
+ # out proj
180
+ init_fn_(module.out_proj.weight)
181
+ if init_div_is_residual is not False and getattr(
182
+ module.out_proj, '_is_residual', False):
183
+ with torch.no_grad():
184
+ module.out_proj.weight.div_(div_is_residual)
185
+ if module.out_proj.bias is not None:
186
+ torch.nn.init.zeros_(module.out_proj.bias)
187
+
188
+ else:
189
+ for _ in module.parameters(recurse=False):
190
+ # raise error if uninitialized module has any parameters
191
+ raise NotImplementedError(
192
+ f'{module.__class__.__name__} parameters are not initialized by param_init_fn.'
193
+ )
194
+
195
+
196
+ def _normal_init_(std, mean=0.0):
197
+ return partial(torch.nn.init.normal_, mean=mean, std=std)
198
+
199
+
200
+ def _normal_param_init_fn_(
201
+ module: nn.Module,
202
+ std: float,
203
+ n_layers: int,
204
+ d_model: Optional[int] = None,
205
+ init_div_is_residual: Union[int, float, str, bool] = True,
206
+ emb_init_std: Optional[float] = None,
207
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
208
+ verbose: int = 0,
209
+ **kwargs,
210
+ ):
211
+ del kwargs # unused, just to capture any extra args from the config
212
+ init_fn_ = _normal_init_(std=std)
213
+
214
+ if verbose > 1:
215
+ warnings.warn(
216
+ f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
217
+
218
+ generic_param_init_fn_(
219
+ module=module,
220
+ init_fn_=init_fn_,
221
+ d_model=d_model,
222
+ n_layers=n_layers,
223
+ init_div_is_residual=init_div_is_residual,
224
+ emb_init_std=emb_init_std,
225
+ emb_init_uniform_lim=emb_init_uniform_lim,
226
+ verbose=verbose,
227
+ )
228
+
229
+
230
+ def baseline_param_init_fn_(
231
+ module: nn.Module,
232
+ init_std: float,
233
+ n_layers: int,
234
+ d_model: Optional[int] = None,
235
+ init_div_is_residual: Union[int, float, str, bool] = True,
236
+ emb_init_std: Optional[float] = None,
237
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
238
+ verbose: int = 0,
239
+ **kwargs,
240
+ ):
241
+ del kwargs # unused, just to capture any extra args from the config
242
+ if init_std is None:
243
+ raise ValueError(
244
+ 'You must set model.init_std to a float value to use the default initialization scheme.'
245
+ )
246
+ _normal_param_init_fn_(
247
+ module=module,
248
+ std=init_std,
249
+ d_model=d_model,
250
+ n_layers=n_layers,
251
+ init_div_is_residual=init_div_is_residual,
252
+ emb_init_std=emb_init_std,
253
+ emb_init_uniform_lim=emb_init_uniform_lim,
254
+ verbose=verbose,
255
+ )
256
+
257
+
258
+ def small_param_init_fn_(
259
+ module: nn.Module,
260
+ n_layers: int,
261
+ d_model: int,
262
+ init_div_is_residual: Union[int, float, str, bool] = True,
263
+ emb_init_std: Optional[float] = None,
264
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
265
+ verbose: int = 0,
266
+ **kwargs,
267
+ ):
268
+ del kwargs # unused, just to capture any extra args from the config
269
+ # very close to kaiming normal
270
+ # from Transformers without Tears (2019) - Nguyen & Salazar
271
+ std = math.sqrt(2 / (5 * d_model))
272
+ _normal_param_init_fn_(
273
+ module=module,
274
+ std=std,
275
+ d_model=d_model,
276
+ n_layers=n_layers,
277
+ init_div_is_residual=init_div_is_residual,
278
+ emb_init_std=emb_init_std,
279
+ emb_init_uniform_lim=emb_init_uniform_lim,
280
+ verbose=verbose,
281
+ )
282
+
283
+
284
+ def neox_param_init_fn_(
285
+ module: nn.Module,
286
+ n_layers: int,
287
+ d_model: int,
288
+ emb_init_std: Optional[float] = None,
289
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
290
+ verbose: int = 0,
291
+ **kwargs,
292
+ ):
293
+ """From section 2.3.1 of GPT-NeoX-20B:
294
+
295
+ An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)
296
+ see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
297
+ and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
298
+ """
299
+ del kwargs # unused, just to capture any extra args from the config
300
+ residual_div = n_layers / math.sqrt(10) # small std / wang std
301
+
302
+ if verbose > 1:
303
+ warnings.warn(f'setting init_div_is_residual to {residual_div}')
304
+
305
+ small_param_init_fn_(
306
+ module=module,
307
+ d_model=d_model,
308
+ n_layers=n_layers,
309
+ init_div_is_residual=residual_div,
310
+ emb_init_std=emb_init_std,
311
+ emb_init_uniform_lim=emb_init_uniform_lim,
312
+ verbose=verbose,
313
+ )
314
+
315
+
316
+ def kaiming_uniform_param_init_fn_(
317
+ module: nn.Module,
318
+ n_layers: int,
319
+ d_model: Optional[int] = None,
320
+ init_div_is_residual: Union[int, float, str, bool] = True,
321
+ emb_init_std: Optional[float] = None,
322
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
323
+ init_gain: float = 0,
324
+ fan_mode: str = 'fan_in',
325
+ init_nonlinearity: str = 'leaky_relu',
326
+ verbose: int = 0,
327
+ **kwargs,
328
+ ):
329
+ del kwargs # unused, just to capture any extra args from the config
330
+
331
+ if verbose > 1:
332
+ warnings.warn(
333
+ f'Using nn.init.kaiming_uniform_ init fn with parameters: ' +
334
+ f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
335
+ )
336
+
337
+ kaiming_uniform_ = partial(nn.init.kaiming_uniform_,
338
+ a=init_gain,
339
+ mode=fan_mode,
340
+ nonlinearity=init_nonlinearity)
341
+
342
+ generic_param_init_fn_(
343
+ module=module,
344
+ init_fn_=kaiming_uniform_,
345
+ d_model=d_model,
346
+ n_layers=n_layers,
347
+ init_div_is_residual=init_div_is_residual,
348
+ emb_init_std=emb_init_std,
349
+ emb_init_uniform_lim=emb_init_uniform_lim,
350
+ verbose=verbose,
351
+ )
352
+
353
+
354
+ def kaiming_normal_param_init_fn_(
355
+ module: nn.Module,
356
+ n_layers: int,
357
+ d_model: Optional[int] = None,
358
+ init_div_is_residual: Union[int, float, str, bool] = True,
359
+ emb_init_std: Optional[float] = None,
360
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
361
+ init_gain: float = 0,
362
+ fan_mode: str = 'fan_in',
363
+ init_nonlinearity: str = 'leaky_relu',
364
+ verbose: int = 0,
365
+ **kwargs,
366
+ ):
367
+ del kwargs # unused, just to capture any extra args from the config
368
+
369
+ if verbose > 1:
370
+ warnings.warn(
371
+ f'Using nn.init.kaiming_normal_ init fn with parameters: ' +
372
+ f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
373
+ )
374
+
375
+ kaiming_normal_ = partial(torch.nn.init.kaiming_normal_,
376
+ a=init_gain,
377
+ mode=fan_mode,
378
+ nonlinearity=init_nonlinearity)
379
+
380
+ generic_param_init_fn_(
381
+ module=module,
382
+ init_fn_=kaiming_normal_,
383
+ d_model=d_model,
384
+ n_layers=n_layers,
385
+ init_div_is_residual=init_div_is_residual,
386
+ emb_init_std=emb_init_std,
387
+ emb_init_uniform_lim=emb_init_uniform_lim,
388
+ verbose=verbose,
389
+ )
390
+
391
+
392
+ def xavier_uniform_param_init_fn_(
393
+ module: nn.Module,
394
+ n_layers: int,
395
+ d_model: Optional[int] = None,
396
+ init_div_is_residual: Union[int, float, str, bool] = True,
397
+ emb_init_std: Optional[float] = None,
398
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
399
+ init_gain: float = 0,
400
+ verbose: int = 0,
401
+ **kwargs,
402
+ ):
403
+ del kwargs # unused, just to capture any extra args from the config
404
+ xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
405
+
406
+ if verbose > 1:
407
+ warnings.warn(
408
+ f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' +
409
+ f'gain={init_gain}'
410
+ )
411
+
412
+ generic_param_init_fn_(
413
+ module=module,
414
+ init_fn_=xavier_uniform_,
415
+ d_model=d_model,
416
+ n_layers=n_layers,
417
+ init_div_is_residual=init_div_is_residual,
418
+ emb_init_std=emb_init_std,
419
+ emb_init_uniform_lim=emb_init_uniform_lim,
420
+ verbose=verbose,
421
+ )
422
+
423
+
424
+ def xavier_normal_param_init_fn_(
425
+ module: nn.Module,
426
+ n_layers: int,
427
+ d_model: Optional[int] = None,
428
+ init_div_is_residual: Union[int, float, str, bool] = True,
429
+ emb_init_std: Optional[float] = None,
430
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
431
+ init_gain: float = 0,
432
+ verbose: int = 0,
433
+ **kwargs,
434
+ ):
435
+ xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
436
+
437
+ if verbose > 1:
438
+ warnings.warn(
439
+ f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' +
440
+ f'gain={init_gain}'
441
+ )
442
+
443
+ generic_param_init_fn_(
444
+ module=module,
445
+ init_fn_=xavier_normal_,
446
+ d_model=d_model,
447
+ n_layers=n_layers,
448
+ init_div_is_residual=init_div_is_residual,
449
+ emb_init_std=emb_init_std,
450
+ emb_init_uniform_lim=emb_init_uniform_lim,
451
+ verbose=verbose,
452
+ )
453
+
454
+
455
+ MODEL_INIT_REGISTRY = {
456
+ 'default_': torch_default_param_init_fn_,
457
+ 'baseline_': baseline_param_init_fn_,
458
+ 'kaiming_uniform_': kaiming_uniform_param_init_fn_,
459
+ 'kaiming_normal_': kaiming_normal_param_init_fn_,
460
+ 'neox_init_': neox_param_init_fn_,
461
+ 'small_init_': small_param_init_fn_,
462
+ 'xavier_uniform_': xavier_uniform_param_init_fn_,
463
+ 'xavier_normal_': xavier_normal_param_init_fn_,
464
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6516d02ef00bc903aad7d05dc35607cff7e4c7335d4f1bf424cdcb6695cd3e86
3
+ size 10402658381
replit_lm.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Forked from the MosaicGPT model class from the Mosaic Examples codebase of date May 1st, 2023.
5
+ Permalink: https://github.com/mosaicml/examples/blob/52cd4fef69497f225a034fcd10692f8613732d10/examples/llm/src/models/mosaic_gpt/mosaic_gpt.py
6
+ """
7
+
8
+ """A simple, flexible implementation of a GPT model.
9
+
10
+ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
11
+ """
12
+
13
+ import math
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import warnings
18
+
19
+ from transformers import PreTrainedModel
20
+ from transformers.modeling_outputs import CausalLMOutputWithPast
21
+ from typing import List, Optional, Tuple
22
+
23
+ from .attention import attn_bias as module_attn_bias, attn_bias_shape as module_attn_bias_shape
24
+ from .gpt_blocks import GPTBlock
25
+ from .configuration_replit_lm import \
26
+ ReplitLMConfig
27
+ from .param_init_fns import MODEL_INIT_REGISTRY
28
+ from .low_precision_layernorm import LPLayerNorm
29
+
30
+
31
+ class ReplitLM(PreTrainedModel):
32
+ config_class = ReplitLMConfig
33
+ base_model_prefix = 'replit_lm'
34
+
35
+ def __init__(self, config: ReplitLMConfig):
36
+ super().__init__(config)
37
+
38
+ if config.attn_impl == 'flash' and config.alibi:
39
+ raise RuntimeError("ALiBi is not supported with flash attention. Please use triton or torch.")
40
+
41
+ self.attn_impl = config.attn_impl
42
+ self.prefix_lm = config.prefix_lm
43
+ self.attn_uses_sequence_id = config.attn_uses_sequence_id
44
+ self.alibi = config.alibi
45
+ self.alibi_bias_max = config.alibi_bias_max
46
+
47
+ layernorm_class = LPLayerNorm if config.low_precision_layernorm else nn.LayerNorm
48
+
49
+ # CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414)
50
+ # both report this helping with stabilizing training
51
+ self.embedding_fraction = config.embedding_fraction
52
+
53
+ self.transformer = nn.ModuleDict({
54
+ 'wte':
55
+ nn.Embedding(config.vocab_size,
56
+ config.d_model,
57
+ device=config.init_device)
58
+ })
59
+ if not self.alibi:
60
+ self.transformer.update({
61
+ 'wpe':
62
+ nn.Embedding(config.max_seq_len,
63
+ config.d_model,
64
+ device=config.init_device)
65
+ })
66
+ self.transformer.update({'emb_drop': nn.Dropout(config.emb_pdrop)})
67
+ self.transformer.update({
68
+ 'blocks':
69
+ nn.ModuleList([
70
+ GPTBlock(device=config.init_device,
71
+ **config.to_dict())
72
+ for _ in range(config.n_layers)
73
+ ])
74
+ })
75
+ self.transformer.update({
76
+ 'ln_f': layernorm_class(config.d_model, device=config.init_device)
77
+ })
78
+
79
+ # enables scaling output logits; similar to a softmax "temperature"
80
+ # PaLM paper uses scale 1/sqrt(config.d_model)
81
+ self.logit_scale = None
82
+ if config.logit_scale is not None:
83
+ logit_scale = config.logit_scale
84
+ if isinstance(logit_scale, str):
85
+ if logit_scale == 'inv_sqrt_d_model':
86
+ logit_scale = 1 / math.sqrt(config.d_model)
87
+ else:
88
+ raise ValueError(
89
+ f"{logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
90
+ )
91
+ self.logit_scale = logit_scale
92
+
93
+ if config.init_device != 'meta':
94
+ print(
95
+ f'You are using {config.init_device=}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
96
+ )
97
+ self.apply(self.param_init_fn)
98
+
99
+ self.is_causal = not self.prefix_lm
100
+
101
+ # define attn mask
102
+ self._attn_bias_initialized = False
103
+ self.attn_bias = None
104
+ self.attn_bias_shape = module_attn_bias_shape(
105
+ self.attn_impl,
106
+ config.n_heads,
107
+ config.max_seq_len,
108
+ self.alibi,
109
+ prefix_lm=self.prefix_lm,
110
+ causal=self.is_causal,
111
+ use_sequence_id=self.attn_uses_sequence_id)
112
+
113
+ if config.no_bias:
114
+ for module in self.modules():
115
+ if hasattr(module, 'bias') and isinstance(
116
+ module.bias, nn.Parameter):
117
+ if config.verbose:
118
+ print(f'Removing bias ({module.bias}) from {module}.')
119
+ module.register_parameter('bias', None)
120
+
121
+ if config.verbose and config.verbose > 2:
122
+ print(self)
123
+
124
+ @torch.no_grad()
125
+ def _attn_bias(self,
126
+ device,
127
+ dtype,
128
+ attention_mask: Optional[torch.ByteTensor] = None,
129
+ prefix_mask: Optional[torch.ByteTensor] = None,
130
+ sequence_id: Optional[torch.LongTensor] = None):
131
+ if not self._attn_bias_initialized:
132
+ if self.attn_bias_shape:
133
+ self.attn_bias = torch.zeros(self.attn_bias_shape,
134
+ device=device,
135
+ dtype=dtype)
136
+ self.attn_bias = module_attn_bias(
137
+ self.attn_impl,
138
+ self.attn_bias,
139
+ self.config.n_heads,
140
+ self.config.max_seq_len,
141
+ causal=self.is_causal,
142
+ alibi=self.alibi,
143
+ alibi_bias_max=self.alibi_bias_max)
144
+ self._attn_bias_initialized = True
145
+
146
+ # flash does not support prefix_lm and will incorporate any
147
+ # attention_mask inside the attention module
148
+ if self.attn_impl == 'flash':
149
+ return self.attn_bias, attention_mask
150
+
151
+ attn_bias = self.attn_bias
152
+
153
+ # If using torch or triton, we incorporate the prefix_mask (if appropriate)
154
+ if self.prefix_lm:
155
+ assert isinstance(attn_bias, torch.Tensor) # pyright
156
+ assert isinstance(prefix_mask, torch.Tensor) # pyright
157
+ attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
158
+
159
+ # If using torch or triton, we incorporate sequence_id (if appropriate)
160
+ if self.attn_uses_sequence_id and sequence_id is not None:
161
+ assert isinstance(attn_bias, torch.Tensor) # pyright
162
+ attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
163
+
164
+ # If using torch or triton, we incorporate attention_mask. This will output
165
+ # None in place of attention_mask since it will not be further needed in the
166
+ # attention modules.
167
+ if attention_mask is not None:
168
+ s_k = attention_mask.shape[-1]
169
+ if attn_bias is None:
170
+ attn_bias = torch.zeros((1, 1, 1, s_k),
171
+ device=device,
172
+ dtype=dtype)
173
+ else:
174
+ attn_bias = attn_bias[:, :, :, -s_k:]
175
+ if prefix_mask is not None and (attention_mask.shape !=
176
+ prefix_mask.shape):
177
+ raise ValueError(
178
+ f'attention_mask shape={attention_mask.shape} ' +\
179
+ f'and prefix_mask shape={prefix_mask.shape} are not equal.'
180
+ )
181
+ min_val = torch.finfo(attn_bias.dtype).min
182
+ attn_bias = attn_bias.masked_fill(
183
+ ~attention_mask.view(-1, 1, 1, s_k), min_val)
184
+
185
+ return attn_bias, None
186
+
187
+ def _apply_prefix_mask(self, attn_bias: torch.Tensor,
188
+ prefix_mask: torch.Tensor):
189
+ s_k, s_q = attn_bias.shape[-2:]
190
+ if (s_k != self.config.max_seq_len) or (s_q != self.config.max_seq_len):
191
+ raise ValueError(
192
+ 'attn_bias does not match the expected shape. ' +\
193
+ f'The last two dimensions should both be {self.config.max_length} ' +\
194
+ f'but are {s_k} and {s_q}.'
195
+ )
196
+ seq_len = prefix_mask.shape[-1]
197
+ if seq_len > self.config.max_seq_len:
198
+ raise ValueError(
199
+ f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
200
+ )
201
+
202
+ # select seq_len subset of attn mask
203
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
204
+
205
+ # Mix the causal max and the bidirectional mask to get the full
206
+ # allowable attention (i.e. full = not accounting for padding yet)
207
+ causal = torch.tril(
208
+ torch.ones((seq_len, seq_len),
209
+ dtype=torch.bool,
210
+ device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
211
+ prefix = prefix_mask.view(-1, 1, 1, seq_len)
212
+ cannot_attend = ~torch.logical_or(causal, prefix.bool())
213
+
214
+ min_val = torch.finfo(attn_bias.dtype).min
215
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
216
+
217
+ return attn_bias
218
+
219
+ def _apply_sequence_id(self, attn_bias: torch.Tensor,
220
+ sequence_id: torch.LongTensor):
221
+ seq_len = sequence_id.shape[-1]
222
+ if seq_len > self.config.max_seq_len:
223
+ raise ValueError(
224
+ f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
225
+ )
226
+
227
+ # select seq_len subset of attn mask
228
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
229
+
230
+ # Restrict attention to tokens that share the same value
231
+ # in sequence_id
232
+ cannot_attend = torch.logical_not(
233
+ torch.eq(sequence_id.view(-1, seq_len, 1),
234
+ sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
235
+ min_val = torch.finfo(attn_bias.dtype).min
236
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
237
+
238
+ return attn_bias
239
+
240
+ def forward(
241
+ self,
242
+ input_ids: torch.LongTensor,
243
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
244
+ attention_mask: Optional[torch.ByteTensor] = None,
245
+ prefix_mask: Optional[torch.ByteTensor] = None,
246
+ sequence_id: Optional[torch.LongTensor] = None,
247
+ return_dict: Optional[bool] = None,
248
+ output_attentions: Optional[bool] = None,
249
+ output_hidden_states: Optional[bool] = None,
250
+ use_cache: Optional[bool] = None):
251
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
252
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
253
+
254
+ # These args are passed in by keyword in huggingface's generate function
255
+ # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
256
+ # but have not yet been fully implemented in ReplitLM
257
+ if not return_dict:
258
+ raise NotImplementedError(
259
+ 'return_dict False is not implemented yet for ReplitLM')
260
+ if output_attentions:
261
+ raise NotImplementedError(
262
+ 'output_attentions is not implemented yet for ReplitLM')
263
+
264
+ if attention_mask is not None and attention_mask[:, 0].sum(
265
+ ) != attention_mask.shape[0] and self.training:
266
+ raise NotImplementedError(
267
+ 'ReplitLM does not support training with left padding.')
268
+
269
+ if self.prefix_lm and prefix_mask is None:
270
+ raise ValueError(
271
+ 'prefix_mask is a required argument when ReplitLM is configured with prefix_lm=True.'
272
+ )
273
+
274
+ if self.training:
275
+ if self.attn_uses_sequence_id and sequence_id is None:
276
+ raise ValueError(
277
+ 'sequence_id is a required argument when ReplitLM is configured with attn_uses_sequence_id=True ' +\
278
+ 'and the model is in train mode.'
279
+ )
280
+ elif (self.attn_uses_sequence_id is False) and (sequence_id
281
+ is not None):
282
+ warnings.warn(
283
+ 'ReplitLM received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' +\
284
+ 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.'
285
+ )
286
+
287
+ S = input_ids.size(1)
288
+
289
+ assert (
290
+ S <= self.config.max_seq_len
291
+ ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
292
+
293
+ tok_emb = self.transformer.wte(input_ids) # type: ignore
294
+ if self.alibi:
295
+ x = tok_emb
296
+ else:
297
+ past_position = 0
298
+ if past_key_values is not None:
299
+ if len(past_key_values) != self.config.n_layers:
300
+ raise ValueError(
301
+ f'past_key_values must provide a past_key_value for each attention ' +\
302
+ f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
303
+ )
304
+ # get the key tensor whose spec should be (batch, seq, dim), and
305
+ # collect the `seq`, so that the position embedding is shifted
306
+ past_position = past_key_values[0][0].size(1)
307
+
308
+ if S + past_position > self.config.max_seq_len:
309
+ raise ValueError(
310
+ f'Cannot forward input with past sequence length {past_position} and current sequence length '
311
+ f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.'
312
+ )
313
+ pos = torch.arange(past_position,
314
+ S + past_position,
315
+ dtype=torch.long,
316
+ device=input_ids.device).unsqueeze(0)
317
+ if attention_mask is not None:
318
+ # adjust the position indices to account for padding tokens
319
+ pos = torch.clamp(pos - torch.cumsum(
320
+ (~attention_mask).to(torch.int32), dim=1)[:,
321
+ past_position:],
322
+ min=0)
323
+
324
+ pos_emb = self.transformer.wpe(pos) # type: ignore
325
+ x = tok_emb + pos_emb
326
+
327
+ if self.embedding_fraction == 1:
328
+ x = self.transformer.emb_drop(x) # type: ignore
329
+ else:
330
+ # this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414
331
+ x_shrunk = (x * self.embedding_fraction) + (
332
+ x.detach() * (1 - self.embedding_fraction))
333
+ assert isinstance(self.transformer.emb_drop, nn.Module) # pyright
334
+ x = self.transformer.emb_drop(x_shrunk)
335
+
336
+ attn_bias, attention_mask = self._attn_bias(
337
+ device=x.device,
338
+ dtype=x.dtype,
339
+ attention_mask=attention_mask,
340
+ prefix_mask=prefix_mask,
341
+ sequence_id=sequence_id)
342
+
343
+ # initialize the past key values cache if it should be used
344
+ if use_cache and past_key_values is None:
345
+ past_key_values = [() for _ in range(self.config.n_layers)
346
+ ] # type: ignore
347
+
348
+ all_hidden_states = () if output_hidden_states else None
349
+ for b_idx, block in enumerate(self.transformer.blocks): # type: ignore
350
+ if output_hidden_states:
351
+ assert all_hidden_states is not None # pyright
352
+ all_hidden_states = all_hidden_states + (x,)
353
+ past_key_value = past_key_values[
354
+ b_idx] if past_key_values is not None else None
355
+ x, past_key_value = block(x,
356
+ past_key_value=past_key_value,
357
+ attn_bias=attn_bias,
358
+ attention_mask=attention_mask,
359
+ is_causal=self.is_causal)
360
+ if past_key_values is not None:
361
+ past_key_values[b_idx] = past_key_value
362
+
363
+ x = self.transformer.ln_f(x) # type: ignore
364
+
365
+ # output embedding weight tied to input embedding
366
+ assert isinstance(self.transformer.wte, nn.Module) # pyright
367
+ assert isinstance(self.transformer.wte.weight, torch.Tensor) # pyright
368
+ logits = F.linear(x, self.transformer.wte.weight, None)
369
+
370
+ if self.logit_scale is not None:
371
+ if self.logit_scale == 0:
372
+ warnings.warn(
373
+ f'Multiplying logits by {self.logit_scale=}. This will produce uniform (uninformative) outputs.'
374
+ )
375
+ logits *= self.logit_scale
376
+
377
+ return CausalLMOutputWithPast(logits=logits,
378
+ past_key_values=past_key_values,
379
+ hidden_states=all_hidden_states)
380
+
381
+ # Param Initialization, needed for device='meta' fast initialization
382
+ def param_init_fn(self, module):
383
+ init_fn_name = self.config.param_init_fn
384
+ if self.config.verbose > 1:
385
+ warnings.warn(f'Using {init_fn_name} initialization.')
386
+ MODEL_INIT_REGISTRY[init_fn_name](module=module,
387
+ **self.config.to_dict())
388
+
389
+ # FSDP Wrap function
390
+ def fsdp_wrap_fn(self, module):
391
+ return isinstance(module, GPTBlock)
392
+
393
+ # Activation Checkpointing
394
+ def activation_checkpointing_fn(self, module):
395
+ return isinstance(module, GPTBlock)
396
+
397
+ def prepare_inputs_for_generation(self,
398
+ input_ids,
399
+ past_key_values=None,
400
+ inputs_embeds=None,
401
+ **kwargs):
402
+ if inputs_embeds is not None:
403
+ raise NotImplementedError(
404
+ 'inputs_embeds is not implemented for ReplitLM yet')
405
+
406
+ attention_mask = kwargs['attention_mask'].bool()
407
+ if attention_mask[:, -1].sum() != attention_mask.shape[0]:
408
+ raise NotImplementedError(
409
+ 'ReplitLM does not support generation with right padding.')
410
+
411
+ if self.attn_uses_sequence_id and self.training:
412
+ sequence_id = torch.zeros_like(input_ids[:1])
413
+ else:
414
+ sequence_id = None
415
+
416
+ if past_key_values is not None:
417
+ input_ids = input_ids[:, -1].unsqueeze(-1)
418
+
419
+ if self.prefix_lm:
420
+ # Leverage a convenience of sequential generation!
421
+ prefix_mask = torch.ones_like(attention_mask)
422
+ # This requires that we're using the cache
423
+ if kwargs.get('use_cache') == False:
424
+ raise NotImplementedError(
425
+ 'ReplitLM with prefix_lm=True does not support use_cache=False.'
426
+ )
427
+ else:
428
+ prefix_mask = None
429
+
430
+ return {
431
+ 'input_ids': input_ids,
432
+ 'attention_mask': attention_mask,
433
+ 'prefix_mask': prefix_mask,
434
+ 'sequence_id': sequence_id,
435
+ 'past_key_values': past_key_values,
436
+ 'use_cache': kwargs.get('use_cache', True),
437
+ }
438
+
439
+ @staticmethod
440
+ def _reorder_cache(past_key_values, beam_idx):
441
+ """Used by HuggingFace generate when using beam search with kv-caching.
442
+
443
+ See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
444
+ for an example in transformers.
445
+ """
446
+ reordered_past = []
447
+ for layer_past in past_key_values:
448
+ reordered_past += [
449
+ tuple(
450
+ past_state.index_select(0, beam_idx)
451
+ for past_state in layer_past)
452
+ ]
453
+ return reordered_past
replit_lm_tokenizer.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Forked from the file src/transformers/models/bert_generation/tokenization_bert_generation.py from the HuggingFace Transformers library.
17
+ Permalink: https://github.com/huggingface/transformers/blob/04ab5605fbb4ef207b10bf2772d88c53fc242e83/src/transformers/models/bert_generation/tokenization_bert_generation.py
18
+
19
+ Class is modified for compatibility with custom vocabulary and to achieve desired encode/decode behavior for Replit Code v1.3b model.
20
+ """
21
+
22
+ """ Tokenizer class for ReplitLM"""
23
+
24
+
25
+ import os
26
+ import sentencepiece as spm
27
+ from shutil import copyfile
28
+ from transformers import PreTrainedTokenizer
29
+ from typing import Any, Dict, List, Optional, Tuple
30
+ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
31
+
32
+
33
+ class ReplitLMTokenizer(PreTrainedTokenizer):
34
+ """
35
+ Construct a ReplitLMTokenizer tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
36
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods.
37
+
38
+ Args:
39
+ vocab_file (`str`):
40
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
41
+ contains the vocabulary necessary to instantiate a tokenizer.
42
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
43
+ The end of sequence token.
44
+ bos_token (`str`, *optional*, defaults to `None`):
45
+ The begin of sequence token.
46
+ unk_token (`str`, *optional*, defaults to `"<|unk|>"`):
47
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
48
+ token instead.
49
+ pad_token (`str`, *optional*, defaults to `"<|pad|>"`):
50
+ The token used for padding, for example when batching sequences of different lengths.
51
+ sp_model_kwargs (`dict`, *optional*):
52
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
53
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
54
+ to set:
55
+ - `enable_sampling`: Enable subword regularization.
56
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
57
+ - `nbest_size = {0,1}`: No sampling is performed.
58
+ - `nbest_size > 1`: samples from the nbest_size results.
59
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
60
+ using forward-filtering-and-backward-sampling algorithm.
61
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
62
+ BPE-dropout.
63
+ """
64
+
65
+ vocab_files_names = VOCAB_FILES_NAMES
66
+ prefix_tokens: List[int] = []
67
+ model_input_names = ["input_ids", "attention_mask"]
68
+
69
+ def __init__(
70
+ self,
71
+ vocab_file,
72
+ bos_token=None,
73
+ eos_token="<|endoftext|>",
74
+ unk_token="<|unk|>",
75
+ pad_token="<|pad|>",
76
+ sep_token=None,
77
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
78
+ **kwargs,
79
+ ) -> None:
80
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
81
+
82
+ # Add extra_ids to the special token list
83
+ super().__init__(
84
+ bos_token=bos_token,
85
+ eos_token=eos_token,
86
+ unk_token=unk_token,
87
+ pad_token=pad_token,
88
+ sep_token=sep_token,
89
+ sp_model_kwargs=self.sp_model_kwargs,
90
+ **kwargs,
91
+ )
92
+
93
+ self.vocab_file = vocab_file
94
+
95
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
96
+ self.sp_model.Load(vocab_file)
97
+
98
+ @property
99
+ def vocab_size(self):
100
+ return self.sp_model.get_piece_size()
101
+
102
+ def get_vocab(self):
103
+ vocab = {self.convert_ids_to_tokens(
104
+ i): i for i in range(self.vocab_size)}
105
+ vocab.update(self.added_tokens_encoder)
106
+ return vocab
107
+
108
+ def __getstate__(self):
109
+ state = self.__dict__.copy()
110
+ state["sp_model"] = None
111
+ return state
112
+
113
+ def __setstate__(self, d):
114
+ self.__dict__ = d
115
+
116
+ # for backward compatibility
117
+ if not hasattr(self, "sp_model_kwargs"):
118
+ self.sp_model_kwargs = {}
119
+
120
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
121
+ self.sp_model.load(self.vocab_file)
122
+
123
+ def _tokenize(self, text: str) -> List[str]:
124
+ """Take as input a string and return a list of strings (tokens) for words/sub-words"""
125
+ return self.sp_model.encode(text, out_type=str)
126
+
127
+ def _convert_token_to_id(self, token):
128
+ """Converts a token (str) in an id using the vocab."""
129
+ return self.sp_model.piece_to_id(token)
130
+
131
+ def _convert_id_to_token(self, index):
132
+ """Converts an index (integer) in a token (str) using the vocab."""
133
+ token = self.sp_model.id_to_piece(index)
134
+ return token
135
+
136
+ def convert_tokens_to_string(self, tokens):
137
+ """Converts a sequence of tokens (string) in a single string."""
138
+ return self.sp_model.decode(tokens)
139
+
140
+ def save_vocabulary(self,
141
+ save_directory: str,
142
+ filename_prefix: Optional[str] = None) -> Tuple[str]:
143
+
144
+ if not os.path.isdir(save_directory):
145
+ raise ValueError(
146
+ f"Vocabulary path ({save_directory}) should be a directory")
147
+
148
+ out_vocab_file = os.path.join(
149
+ save_directory, (filename_prefix + "-" if filename_prefix else "") +
150
+ VOCAB_FILES_NAMES["vocab_file"])
151
+
152
+ if os.path.abspath(
153
+ self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(
154
+ self.vocab_file):
155
+ copyfile(self.vocab_file, out_vocab_file)
156
+ elif not os.path.isfile(self.vocab_file):
157
+ with open(out_vocab_file, "wb") as fi:
158
+ content_spiece_model = self.sp_model.serialized_model_proto()
159
+ fi.write(content_spiece_model)
160
+
161
+ return (out_vocab_file, )
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "eos_token": "<|endoftext|>",
3
+ "pad_token": "<|pad|>",
4
+ "unk_token": "<|unk|>"
5
+ }
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e1ba8b7df0701723d2d901c7a42182fe77bf0045173f2cdb474ca6ea3eb1c02
3
+ size 707660
tokenizer_config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoTokenizer": [
4
+ "replit_lm_tokenizer.ReplitLMTokenizer",
5
+ null
6
+ ]
7
+ },
8
+ "bos_token": null,
9
+ "clean_up_tokenization_spaces": false,
10
+ "eos_token": "<|endoftext|>",
11
+ "model_max_length": 2048,
12
+ "pad_token": "<|pad|>",
13
+ "padding_side": "right",
14
+ "sep_token": null,
15
+ "sp_model_kwargs": {},
16
+ "tokenizer_class": "ReplitLMTokenizer",
17
+ "unk_token": "<|unk|>"
18
+ }