asharsha30 committed on
Commit 0153fea · verified
1 Parent(s): 8aa0630

Upload Train_GPT2_diff.txt

Files changed (1)
  1. Train_GPT2_diff.txt +298 -0
Train_GPT2_diff.txt ADDED
@@ -0,0 +1,298 @@
+ 35a36
+ > from hellaswag import render_example, iterate_examples
+ 36a38,39
+ > import torch._dynamo
+ > torch._dynamo.config.suppress_errors = True
+ 48c51,54
+ < class CausalSelfAttention(nn.Module):
+ ---
+ > class NewGELU(nn.Module):
+ >     """Careful there are a few versions of GeLU, this one is the exact one used by OpenAI"""
+ >     def forward(self, input):
+ >         return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
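For reference, the NewGELU added above is the tanh approximation of GELU; a minimal sanity check (a sketch, not part of the training script, assuming PyTorch 1.12+ where F.gelu accepts approximate="tanh"):

    import math
    import torch
    import torch.nn.functional as F

    def new_gelu(x):
        # same formula as the NewGELU module added in the diff
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

    x = torch.randn(4, 8)
    print(torch.allclose(new_gelu(x), F.gelu(x, approximate="tanh"), atol=1e-6))  # expected: True (within float tolerance)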
+ 49a56,79
+ > # Rotary Position Embedding
+ > def apply_rotary_pos_emb(q, k, sin, cos):
+ >     q_embed = (q * cos) + rotate_half(q) * sin
+ >     k_embed = (k * cos) + rotate_half(k) * sin
+ >     return q_embed, k_embed
+ >
+ > def rotate_half(x):
+ >     x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+ >     return torch.cat((-x2, x1), dim=-1)
+ >
+ > class RotaryEmbedding(nn.Module):
+ >     def __init__(self, dim):
+ >         super().__init__()
+ >         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ >         self.register_buffer("inv_freq", inv_freq)
+ >
+ >     def forward(self, seq_len, device):
+ >         t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
+ >         freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
+ >         emb = torch.cat((freqs, freqs), dim=-1)
+ >         sin, cos = emb.sin(), emb.cos()
+ >         return sin, cos
+ >
+ > class CausalSelfAttention(nn.Module):
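A quick shape check for the rotary-embedding helpers added above (a sketch, assuming RotaryEmbedding, rotate_half, and apply_rotary_pos_emb from the diff are in scope; the sizes are illustrative):

    import torch

    B, H, T, head_dim = 2, 4, 16, 64
    rope = RotaryEmbedding(head_dim)
    sin, cos = rope(T, device="cpu")              # each (T, head_dim); broadcasts over batch and heads
    q = torch.randn(B, H, T, head_dim)
    k = torch.randn(B, H, T, head_dim)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, sin, cos)
    print(q_rot.shape, k_rot.shape)               # torch.Size([2, 4, 16, 64]) for both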
+ 53,58d82
+ <         # key, query, value projections for all heads, but in a batch
+ <         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+ <         # output projection
+ <         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+ <         self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
+ <         # regularization
+ 61c85,92
+ <         # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
+ ---
+ >         self.grouped_heads = config.grouped_heads
+ >         self.head_dim = config.n_embd // config.n_head
+ >
+ >         self.c_attn = nn.Linear(config.n_embd, (2 * config.n_head + self.grouped_heads) * self.head_dim)
+ >         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+ >
+ >         self.rotary_embedding = RotaryEmbedding(self.head_dim)
+ >
+ 63c94
+ < .view(1, 1, config.block_size, config.block_size))
+ ---
+ > .view(1, 1, config.block_size, config.block_size))
+ 66,67c97
+ <         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+ <         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ ---
+ >         B, T, C = x.size()
+ 69,84c99,120
+ <         q, k, v = qkv.split(self.n_embd, dim=2)
+ <         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+ <         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+ <         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+ <         if FLASH:
+ <             # flashattention
+ <             y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+ <         else:
+ <             # manual implementation of attention
+ <             # this materializes the large (T,T) matrix for all the queries and keys
+ <             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+ <             att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
+ <             att = F.softmax(att, dim=-1)
+ <             y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+ <         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+ <         # output projection
+ ---
+ >         q, k, v = torch.split(qkv, [self.grouped_heads * self.head_dim, self.n_head * self.head_dim, self.n_head * self.head_dim], dim=2)
+ >
+ >         # Reshape for multi-head attention
+ >         q = q.view(B, T, self.grouped_heads, self.head_dim).transpose(1, 2)
+ >         k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+ >         v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+ >
+ >         # Apply RoPE
+ >         sin, cos = self.rotary_embedding(T, x.device)
+ >         q, k = apply_rotary_pos_emb(q, k, sin, cos)
+ >
+ >         # Expand q to match the number of key/value heads
+ >         q = q.repeat_interleave(self.n_head // self.grouped_heads, dim=1)
+ >
+ >         # Attention computation
+ >         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
+ >         att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
+ >         att = F.softmax(att, dim=-1)
+ >         y = att @ v
+ >
+ >         # Reshape output
+ >         y = y.transpose(1, 2).contiguous().view(B, T, C)
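The projection and head bookkeeping of the new forward pass can be isolated as follows (a standalone sketch with illustrative sizes; it mirrors the diff's convention, where the query projection uses grouped_heads and is expanded with repeat_interleave to match the key/value heads):

    import torch

    B, T, n_embd, n_head, grouped_heads = 2, 8, 1024, 16, 4
    head_dim = n_embd // n_head                                   # 64
    c_attn = torch.nn.Linear(n_embd, (2 * n_head + grouped_heads) * head_dim)  # 1024 -> 2304

    x = torch.randn(B, T, n_embd)
    qkv = c_attn(x)
    q, k, v = torch.split(qkv, [grouped_heads * head_dim, n_head * head_dim, n_head * head_dim], dim=2)
    q = q.view(B, T, grouped_heads, head_dim).transpose(1, 2)     # (B, 4, T, 64)
    k = k.view(B, T, n_head, head_dim).transpose(1, 2)            # (B, 16, T, 64)
    v = v.view(B, T, n_head, head_dim).transpose(1, 2)            # (B, 16, T, 64)
    q = q.repeat_interleave(n_head // grouped_heads, dim=1)       # (B, 16, T, 64), now matches k and v
    print(q.shape, k.shape, v.shape)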
+ 87c123,124
+ <
+ ---
+ >
+ >
+ 89d125
+ <
+ 92,95c128,130
+ <         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
+ <         self.gelu = NewGELU()
+ <         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
+ <         self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
+ ---
+ >         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
+ >         self.gelu = NewGELU() # Using GeLU activation
+ >         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
+ 99c134
+ <         x = self.gelu(x)
+ ---
+ >         x = self.gelu(x) # GeLU activation
+ 104d138
+ <
+ 117,119d150
+ < # -----------------------------------------------------------------------------
+ < # The main GPT-2 model
+ <
+ 122c153
+ <     block_size: int = 1024
+ ---
+ >     block_size: int = 2048
+ 124,126c155,158
+ <     n_layer: int = 12
+ <     n_head: int = 12
+ <     n_embd: int = 768
+ ---
+ >     n_layer: int = 16
+ >     n_head: int = 16
+ >     grouped_heads: int = 4 # Number of grouped heads for GQA
+ >     n_embd: int = 1024
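Put together, the default configuration after this hunk looks roughly like the sketch below (the field order and the unchanged vocab_size are assumptions reconstructed from the surrounding hunks):

    from dataclasses import dataclass

    @dataclass
    class GPTConfig:
        block_size: int = 2048
        vocab_size: int = 50257        # unchanged by the diff (see the d12 hunk further down)
        n_layer: int = 16
        n_head: int = 16
        grouped_heads: int = 4         # number of grouped heads for GQA
        n_embd: int = 1024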
+ 129d160
+ <
+ 135,136c166
+ <             wte = nn.Embedding(config.vocab_size, config.n_embd),
+ <             wpe = nn.Embedding(config.block_size, config.n_embd),
+ ---
+ >             wte = nn.Embedding(config.vocab_size, config.n_embd), # Token embedding only, no wpe
+ 140,142d169
+ <         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+ <         self.lm_head.LLMC_SKIP_INIT = 1 # don't init this one, we will tie weights
+ <         self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
+ 144,160c171,173
+ <         # init all weights, use a torch rng object to be very careful
+ <         self.init_rng = torch.Generator()
+ <         self.init_rng.manual_seed(42)
+ <         self.apply(self._init_weights)
+ <
+ <     def _init_weights(self, module):
+ <         if isinstance(module, nn.Linear):
+ <             # apply special scaled init to the residual projections, per GPT-2 paper
+ <             std = 0.02 if not hasattr(module, 'LLMC_RESIDUAL_SCALE_FLAG') else 0.02/math.sqrt(2 * self.config.n_layer)
+ <             # we want to skip initializing lm_head, which shares parameters with wte
+ <             # and wte was already initialized down below during the Embedding init
+ <             if not hasattr(module, 'LLMC_SKIP_INIT'):
+ <                 torch.nn.init.normal_(module.weight, mean=0.0, std=std, generator=self.init_rng)
+ <             if module.bias is not None:
+ <                 torch.nn.init.zeros_(module.bias)
+ <         elif isinstance(module, nn.Embedding):
+ <             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02, generator=self.init_rng)
+ ---
+ >         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+ >         self.lm_head.LLMC_SKIP_INIT = 1 # Weight tying
+ >         self.transformer.wte.weight = self.lm_head.weight
+ 166d178
+ <         pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
+ 168,171c180,182
+ <         # forward the GPT model itself
+ <         tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
+ <         pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
+ <         x = tok_emb + pos_emb
+ ---
+ >         # forward GPT model
+ >         tok_emb = self.transformer.wte(idx) # token embeddings
+ >         x = tok_emb
+ 176a188,189
+ >         logits = self.lm_head(x)
+ >
+ 178,179d190
+ <             # if we are given some desired targets also calculate the loss
+ <             logits = self.lm_head(x)
+ 182,183d192
+ <             # inference-time mini-optimization: only forward the lm_head on the very last position
+ <             logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
+ 186c195
+ <         # there are performance reasons why not returning logits is prudent, if not needed
+ ---
+ >         # if return_logits is False, return only the loss (used for training)
+ 188c197
+ <             logits = None
+ ---
+ >             return None, loss
+ 189a199
+ >         # return logits and optionally the loss (used for inference and training)
+ 201,204c211,214
+ <             'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
+ <             'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
+ <             'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
+ <             'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
+ ---
+ >             'gpt2': dict(n_layer=12, n_head=12, grouped_heads=4, n_embd=768), # 124M params
+ >             'gpt2-medium': dict(n_layer=24, n_head=16, grouped_heads=8, n_embd=1024), # 350M params
+ >             'gpt2-large': dict(n_layer=36, n_head=20, grouped_heads=10, n_embd=1280), # 774M params
+ >             'gpt2-xl': dict(n_layer=48, n_head=25, grouped_heads=12, n_embd=1600), # 1558M params
+ 298a309
+ >
+ 378a390,407
+ > def get_most_likely_row(tokens, mask, logits):
+ >     # evaluate the autoregressive loss at all positions
+ >     shift_logits = (logits[..., :-1, :]).contiguous()
+ >     shift_tokens = (tokens[..., 1:]).contiguous()
+ >     flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
+ >     flat_shift_tokens = shift_tokens.view(-1)
+ >     shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
+ >     shift_losses = shift_losses.view(tokens.size(0), -1)
+ >     # now get the average loss just for the completion region (where mask == 1), in each row
+ >     shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token
+ >     masked_shift_losses = shift_losses * shift_mask
+ >     # sum and divide by the number of 1s in the mask
+ >     sum_loss = masked_shift_losses.sum(dim=1)
+ >     avg_loss = sum_loss / shift_mask.sum(dim=1)
+ >     # now we have a loss for each of the 4 completions
+ >     # the one with the lowest loss should be the most likely
+ >     pred_norm = avg_loss.argmin().item()
+ >     return pred_norm
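Usage of get_most_likely_row on HellaSwag-style inputs, shown with toy tensors (a sketch; it assumes the function above and its F.cross_entropy dependency are in scope, and the mask boundary is illustrative):

    import torch

    num_completions, seq_len, vocab = 4, 10, 50257
    tokens = torch.randint(0, vocab, (num_completions, seq_len))     # prompt + completion tokens per row
    mask = torch.zeros(num_completions, seq_len, dtype=torch.long)
    mask[:, 6:] = 1                                                  # completion region only
    logits = torch.randn(num_completions, seq_len, vocab)            # stand-in for model(tokens) output
    pred = get_most_likely_row(tokens, mask, logits)
    print(pred)                                                      # index of the completion with the lowest average loss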
+ 655c684
+ <         "d12": GPTConfig(block_size=1024, vocab_size=50257, n_layer=12, n_head=12, n_embd=768),
+ ---
+ >         "d12": GPTConfig(block_size=2024, vocab_size=50257, n_layer=12, n_head=12, n_embd=1024),
+ 702,705d730
+ <     # -------------------------------------------------------------------------
+ <     # main training loop
+ <
+ <     # here we wrap model into DDP container
+ 738a764,765
+ >
+ >
+ 758c785,786
+ <                 _, loss = model(x, y, return_logits=False)
+ ---
+ >                 logits, loss = model(x, y)
+ >                 print(logits.shape)
+ 782a811,853
+ >
+ >         if step in [50,5000,10000,15000,19560]:
+ >             save_path = f"{args.output_dir}/model_checkpoint_{step}.bin"
+ >             torch.save(model.state_dict(), save_path)
+ >             print0(f"Model saved at step {step} to {save_path}")
+ >
+ >         if (step % 250 == 0 or last_step or step == 10): #and (not use_compile):
+ >             num_correct_norm = 0
+ >             num_total = 0
+ >             for i, example in enumerate(iterate_examples("val")):
+ >                 # only process examples where i % ddp_world_size == ddp_rank
+ >                 if i % ddp_world_size != ddp_rank:
+ >                     continue
+ >                 # render the example into tokens and labels
+ >                 _, tokens, mask, label = render_example(example)
+ >                 tokens = tokens.to(device)
+ >                 mask = mask.to(device)
+ >                 # get the logits
+ >                 with torch.no_grad():
+ >
+ >                     with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
+ >                         logits, loss = model(tokens)
+ >
+ >                 # print(f"Step {step}:")
+ >                 # print(f"tokens shape: {tokens.shape}")
+ >                 # print(f"mask shape: {mask.shape}")
+ >                 # print(f"logits shape: {logits.shape}")
+ >                 pred_norm = get_most_likely_row(tokens, mask, logits)
+ >                 num_total += 1
+ >                 num_correct_norm += int(pred_norm == label)
+ >             # reduce the stats across all processes
+ >             if ddp:
+ >                 num_total = torch.tensor(num_total, dtype=torch.long, device=device)
+ >                 num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
+ >                 dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
+ >                 dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
+ >                 num_total = num_total.item()
+ >                 num_correct_norm = num_correct_norm.item()
+ >             acc_norm = num_correct_norm / num_total
+ >             if master_process:
+ >                 print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
+ >                 with open(logfile, "a") as f:
+ >                     f.write(f"{step} hella {acc_norm:.4f}\n")
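The checkpoints written with torch.save above can be reloaded later along these lines (a sketch; the GPT class name and the possible "module."/"_orig_mod." key prefixes from DDP or torch.compile are assumptions, not shown in the diff):

    import torch

    config = GPTConfig(block_size=2048, vocab_size=50257, n_layer=16, n_head=16, grouped_heads=4, n_embd=1024)
    model = GPT(config)                                   # assumed model class from the training script
    state_dict = torch.load("model_checkpoint_5000.bin", map_location="cpu")
    # assumption: keys may carry a DDP or torch.compile prefix that has to be stripped before loading
    state_dict = {k.replace("module.", "").replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.eval()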