saintyboy committed
Commit
5374f24
1 Parent(s): 4028f2b

Create model.py

Files changed (1)
  1. model.py +331 -0
model.py ADDED
@@ -0,0 +1,331 @@
import math
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 100265 # vocabulary size of the tokenizer used here; GPT-2's vocab_size is 50257, often padded up to the nearest multiple of 64 for efficiency
    n_layer: int = 24
    n_head: int = 16
    n_embd: int = 1024
    dropout: float = 0.0
    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    def crop_block_size(self, block_size):
        # model surgery to decrease the block size if necessary
        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
        # but want to use a smaller block size for some smaller, simpler model
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        override_args = override_args or {} # default to empty dict
        # only dropout can be overridden, see more notes below
        assert all(k == 'dropout' for k in override_args)
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
            'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        print("forcing vocab_size=50257, block_size=1024, bias=True")
        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
        config_args['bias'] = True # always True for GPT model checkpoints
        # we can override the dropout rate, if desired
        if 'dropout' in override_args:
            print(f"overriding dropout rate to {override_args['dropout']}")
            config_args['dropout'] = override_args['dropout']
        # create a from-scratch initialized minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt) # per second
        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, top_p=0.95, repetition_penalty=1.2, eor_token_id=None):
        """
        Take a conditioning sequence of token indices idx (LongTensor of shape (b, t)) and
        complete it up to max_new_tokens times, applying temperature scaling, top-k filtering,
        top-p (nucleus) filtering and a repetition penalty. Note: the top-p, repetition-penalty
        and eor_token_id handling below assumes a batch size of 1.
        """
        generated = idx
        for _ in range(max_new_tokens):
            # if the running sequence is too long, crop it to the last block_size tokens
            idx_cond = generated if generated.size(1) <= self.config.block_size else generated[:, -self.config.block_size:]
            # forward the model and scale the logits at the final step by the temperature
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            # top-k filtering: keep only the k most likely tokens
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            # top-p (nucleus) filtering: drop tokens beyond the cumulative probability threshold
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                if sorted_indices_to_remove[:, 1:].sum().item() > 0:
                    # shift the mask right so the first token above the threshold is kept
                    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                    sorted_indices_to_remove[:, 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = -float('Inf')

            probs = F.softmax(logits, dim=-1)

            # repetition penalty: down-weight tokens that already appear in the sequence;
            # torch.multinomial accepts unnormalized weights, so no re-normalization is needed
            if repetition_penalty != 1.0:
                for i in range(generated.size(1)):
                    token_id = generated[0, i]
                    probs[0, token_id] /= repetition_penalty

            # sample the next token and append it to the running sequence
            idx_next = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, idx_next), dim=1)

            # optionally stop early once the end-of-response token is produced
            if eor_token_id is not None and idx_next.item() == eor_token_id:
                break

        return generated
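
Below is a minimal usage sketch (not part of the commit) showing how model.py might be exercised end to end: one training step with the AdamW optimizer returned by configure_optimizers, followed by sampling with generate. The tiny config, the random token ids, and all hyperparameter values are illustrative placeholders; real inputs would come from whichever tokenizer corresponds to vocab_size=100265.

# usage_sketch.py (illustrative only; hypothetical file name)
import torch

from model import GPT, GPTConfig

# a deliberately small config so the sketch runs quickly on CPU;
# the file's defaults are block_size=1024, n_layer=24, n_head=16, n_embd=1024
config = GPTConfig(block_size=128, n_layer=2, n_head=2, n_embd=64, vocab_size=100265, dropout=0.0)
model = GPT(config)

# one illustrative training step on random token ids (stand-ins for real data)
optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=3e-4,
                                       betas=(0.9, 0.95), device_type='cpu')
x = torch.randint(0, config.vocab_size, (2, 16), dtype=torch.long)  # input tokens
y = torch.randint(0, config.vocab_size, (2, 16), dtype=torch.long)  # next-token targets
logits, loss = model(x, y)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

# sampling; note that generate() as written assumes a batch of size 1
model.eval()
prompt = torch.randint(0, config.vocab_size, (1, 8), dtype=torch.long)
out = model.generate(prompt, max_new_tokens=20, temperature=0.8, top_k=50, top_p=0.95)
print(out.shape)  # (1, 28), since no eor_token_id is supplied to stop generation early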